diff --git a/cloudbaseinit/plugins/common/execcmd.py b/cloudbaseinit/plugins/common/execcmd.py
index c6660f49..e0f0c56a 100644
--- a/cloudbaseinit/plugins/common/execcmd.py
+++ b/cloudbaseinit/plugins/common/execcmd.py
@@ -12,11 +12,14 @@
# License for the specific language governing permissions and limitations
# under the License.
+
import functools
import os
+import re
import tempfile
import uuid
+from cloudbaseinit.openstack.common import log as logging
from cloudbaseinit.osutils import factory as osutils_factory
@@ -27,8 +30,60 @@ __all__ = (
'Bash',
'Powershell',
'PowershellSysnative',
+ 'CommandExecutor',
+ 'EC2Config',
)
+LOG = logging.getLogger(__name__)
+
+# used with ec2 config files (xmls)
+SCRIPT_TAG = 1
+POWERSHELL_TAG = 2
+# regexp and temporary file extension for each tag
+TAG_REGEX = {
+ SCRIPT_TAG: (
+ re.compile(br""),
+ "cmd"
+ ),
+ POWERSHELL_TAG: (
+ re.compile(br"([\s\S]+?)"),
+ "ps1"
+ )
+}
+
+
+def _ec2_find_sections(data):
+ """An intuitive script generator.
+
+ Is able to detect and extract code between:
+ -
+ - ...
+ tags. Yields data with each specific block of code.
+ Note that, regardless of data structure, all cmd scripts are
+ yielded before the rest of powershell scripts.
+ """
+ # extract code blocks between the tags
+ blocks = {
+ SCRIPT_TAG: TAG_REGEX[SCRIPT_TAG][0].findall(data),
+ POWERSHELL_TAG: TAG_REGEX[POWERSHELL_TAG][0].findall(data)
+ }
+ # build and yield blocks (preserve order)
+ for script_type in (SCRIPT_TAG, POWERSHELL_TAG):
+ for code in blocks[script_type]:
+ code = code.strip()
+ if not code:
+ continue # skip the empty ones
+ yield code, script_type
+
+
+def _split_sections(multicmd):
+ for code, stype in _ec2_find_sections(multicmd):
+ if stype == SCRIPT_TAG:
+ command = Shell.from_data(code)
+ else:
+ command = PowershellSysnative.from_data(code)
+ yield command
+
class BaseCommand(object):
"""Implements logic for executing an user command.
@@ -143,3 +198,51 @@ class PowershellSysnative(BaseCommand):
class Powershell(PowershellSysnative):
sysnative = False
+
+
+class CommandExecutor(object):
+
+ """Execute multiple commands and gather outputs."""
+
+ SEP = b"\n" # multistring separator
+
+ def __init__(self, commands):
+ self._commands = commands
+
+ def execute(self):
+ out_total = []
+ err_total = []
+ ret_total = 0
+ for command in self._commands:
+ out = err = b""
+ ret_val = 0
+ try:
+ out, err, ret_val = command()
+ except Exception as exc:
+ LOG.exception(
+ "An error occurred during part execution: %s",
+ exc
+ )
+ else:
+ out_total.append(out)
+ err_total.append(err)
+ ret_total += ret_val
+ return (
+ self.SEP.join(out_total),
+ self.SEP.join(err_total),
+ ret_total
+ )
+
+ __call__ = execute
+
+
+class EC2Config(object):
+
+ @classmethod
+ def from_data(cls, multicmd):
+ """Create multiple `CommandExecutor` objects.
+
+ These are created using data chunks
+ parsed from the given command data.
+ """
+ return CommandExecutor(_split_sections(multicmd))
diff --git a/cloudbaseinit/plugins/windows/userdatautils.py b/cloudbaseinit/plugins/windows/userdatautils.py
index 233d95f8..230245a9 100644
--- a/cloudbaseinit/plugins/windows/userdatautils.py
+++ b/cloudbaseinit/plugins/windows/userdatautils.py
@@ -12,23 +12,26 @@
# License for the specific language governing permissions and limitations
# under the License.
+
import functools
import re
from cloudbaseinit.openstack.common import log as logging
from cloudbaseinit.plugins.common import execcmd
+
LOG = logging.getLogger(__name__)
# Avoid 80+ length by using a local variable, which
# is deleted afterwards.
_compile = functools.partial(re.compile, flags=re.I)
FORMATS = (
- (_compile(br'^rem cmd\s'), execcmd.Shell),
- (_compile(br'^#!/usr/bin/env\spython\s'), execcmd.Python),
+ (_compile(br'^rem\s+cmd\s'), execcmd.Shell),
+ (_compile(br'^#!\s*/usr/bin/env\s+python\s'), execcmd.Python),
(_compile(br'^#!'), execcmd.Bash),
(_compile(br'^#(ps1|ps1_sysnative)\s'), execcmd.PowershellSysnative),
(_compile(br'^#ps1_x86\s'), execcmd.Powershell),
+ (_compile(br'?(script|powershell)>'), execcmd.EC2Config),
)
del _compile
@@ -45,15 +48,14 @@ def execute_user_data_script(user_data):
out = err = None
command = _get_command(user_data)
if not command:
- # Unsupported
LOG.warning('Unsupported user_data format')
return ret_val
try:
out, err, ret_val = command()
- except Exception as ex:
+ except Exception as exc:
LOG.warning('An error occurred during user_data execution: \'%s\'',
- ex)
+ exc)
else:
LOG.debug('User_data stdout:\n%s', out)
LOG.debug('User_data stderr:\n%s', err)
diff --git a/cloudbaseinit/tests/plugins/common/test_execcmd.py b/cloudbaseinit/tests/plugins/common/test_execcmd.py
index 669180e3..409b2fc4 100644
--- a/cloudbaseinit/tests/plugins/common/test_execcmd.py
+++ b/cloudbaseinit/tests/plugins/common/test_execcmd.py
@@ -12,7 +12,9 @@
# License for the specific language governing permissions and limitations
# under the License.
+
import os
+import textwrap
import unittest
import mock
@@ -22,6 +24,8 @@ from cloudbaseinit.tests import testutils
def _remove_file(filepath):
+ if not filepath:
+ return
try:
os.remove(filepath)
except OSError:
@@ -29,7 +33,7 @@ def _remove_file(filepath):
@mock.patch('cloudbaseinit.osutils.factory.get_os_utils')
-class execcmdTest(unittest.TestCase):
+class TestExecCmd(unittest.TestCase):
def test_from_data(self, _):
command = execcmd.BaseCommand.from_data(b"test")
@@ -108,3 +112,51 @@ class execcmdTest(unittest.TestCase):
command.execute()
cleanup.assert_called_once_with()
+
+ @mock.patch("cloudbaseinit.plugins.common.execcmd.PowershellSysnative")
+ @mock.patch("cloudbaseinit.plugins.common.execcmd.Shell")
+ def _test_process_ec2(self, mock_shell, mock_psnative, tag=None):
+ if tag:
+ content = textwrap.dedent("""
+ <{0}>mocked{0}>
+
+ <{0}>second{0}>
+ 1
+ <{0}>third
+ {0}>
+ <{0}>{0}> # empty
+ <{0}>{0} # invalid
+ """.format(tag)).encode()
+ else:
+ content = textwrap.dedent("""
+ p1
+
+
+ p2
+
+ """).encode()
+
+ def ident(value):
+ ident_func = mock.MagicMock()
+ ident_func.return_value = (value, b"", 0)
+ return ident_func
+
+ mock_shell.from_data = ident
+ mock_psnative.from_data = ident
+
+ ec2conf = execcmd.EC2Config.from_data(content)
+ out, _, _ = ec2conf()
+
+ if tag:
+ self.assertEqual(b"mocked\nsecond\nthird", out)
+ else:
+ self.assertEqual(b"s1\ns2\ns3\np1\np2", out)
+
+ def test_process_ec2_script(self, _):
+ self._test_process_ec2(tag="script")
+
+ def test_process_ec2_powershell(self, _):
+ self._test_process_ec2(tag="powershell")
+
+ def test_process_ec2_order(self, _):
+ self._test_process_ec2()
diff --git a/cloudbaseinit/tests/plugins/windows/test_userdatautils.py b/cloudbaseinit/tests/plugins/windows/test_userdatautils.py
index 5cccee81..fa7df2cd 100644
--- a/cloudbaseinit/tests/plugins/windows/test_userdatautils.py
+++ b/cloudbaseinit/tests/plugins/windows/test_userdatautils.py
@@ -12,6 +12,7 @@
# License for the specific language governing permissions and limitations
# under the License.
+
import os
import unittest
@@ -22,6 +23,8 @@ from cloudbaseinit.plugins.windows import userdatautils
def _safe_remove(filepath):
+ if not filepath:
+ return
try:
os.remove(filepath)
except OSError:
@@ -38,7 +41,7 @@ class UserDataUtilsTest(unittest.TestCase):
to remove the underlying target path of the command.
"""
command = userdatautils._get_command(data)
- if command:
+ if command and not isinstance(command, execcmd.CommandExecutor):
self.addCleanup(_safe_remove, command._target_path)
return command
@@ -58,6 +61,9 @@ class UserDataUtilsTest(unittest.TestCase):
command = self._get_command(b'#ps1_x86\n')
self.assertIsInstance(command, execcmd.Powershell)
+ command = self._get_command(b'')
+ self.assertIsInstance(command, execcmd.CommandExecutor)
+
command = self._get_command(b'unknown')
self.assertIsNone(command)
diff --git a/cloudbaseinit/utils/encoding.py b/cloudbaseinit/utils/encoding.py
index 7a09e90d..b5585ecb 100644
--- a/cloudbaseinit/utils/encoding.py
+++ b/cloudbaseinit/utils/encoding.py
@@ -33,3 +33,13 @@ def write_file(target_path, data, mode='wb'):
with open(target_path, mode) as f:
f.write(data)
+
+
+def read_file(target_path, mode='rb'):
+ with open(target_path, mode) as f:
+ data = f.read()
+
+ if 'b' in mode:
+ data = data.decode()
+
+ return data