diff --git a/cloudbaseinit/plugins/common/execcmd.py b/cloudbaseinit/plugins/common/execcmd.py index c6660f49..e0f0c56a 100644 --- a/cloudbaseinit/plugins/common/execcmd.py +++ b/cloudbaseinit/plugins/common/execcmd.py @@ -12,11 +12,14 @@ # License for the specific language governing permissions and limitations # under the License. + import functools import os +import re import tempfile import uuid +from cloudbaseinit.openstack.common import log as logging from cloudbaseinit.osutils import factory as osutils_factory @@ -27,8 +30,60 @@ __all__ = ( 'Bash', 'Powershell', 'PowershellSysnative', + 'CommandExecutor', + 'EC2Config', ) +LOG = logging.getLogger(__name__) + +# used with ec2 config files (xmls) +SCRIPT_TAG = 1 +POWERSHELL_TAG = 2 +# regexp and temporary file extension for each tag +TAG_REGEX = { + SCRIPT_TAG: ( + re.compile(br""), + "cmd" + ), + POWERSHELL_TAG: ( + re.compile(br"([\s\S]+?)"), + "ps1" + ) +} + + +def _ec2_find_sections(data): + """An intuitive script generator. + + Is able to detect and extract code between: + - + - ... + tags. Yields data with each specific block of code. + Note that, regardless of data structure, all cmd scripts are + yielded before the rest of powershell scripts. + """ + # extract code blocks between the tags + blocks = { + SCRIPT_TAG: TAG_REGEX[SCRIPT_TAG][0].findall(data), + POWERSHELL_TAG: TAG_REGEX[POWERSHELL_TAG][0].findall(data) + } + # build and yield blocks (preserve order) + for script_type in (SCRIPT_TAG, POWERSHELL_TAG): + for code in blocks[script_type]: + code = code.strip() + if not code: + continue # skip the empty ones + yield code, script_type + + +def _split_sections(multicmd): + for code, stype in _ec2_find_sections(multicmd): + if stype == SCRIPT_TAG: + command = Shell.from_data(code) + else: + command = PowershellSysnative.from_data(code) + yield command + class BaseCommand(object): """Implements logic for executing an user command. @@ -143,3 +198,51 @@ class PowershellSysnative(BaseCommand): class Powershell(PowershellSysnative): sysnative = False + + +class CommandExecutor(object): + + """Execute multiple commands and gather outputs.""" + + SEP = b"\n" # multistring separator + + def __init__(self, commands): + self._commands = commands + + def execute(self): + out_total = [] + err_total = [] + ret_total = 0 + for command in self._commands: + out = err = b"" + ret_val = 0 + try: + out, err, ret_val = command() + except Exception as exc: + LOG.exception( + "An error occurred during part execution: %s", + exc + ) + else: + out_total.append(out) + err_total.append(err) + ret_total += ret_val + return ( + self.SEP.join(out_total), + self.SEP.join(err_total), + ret_total + ) + + __call__ = execute + + +class EC2Config(object): + + @classmethod + def from_data(cls, multicmd): + """Create multiple `CommandExecutor` objects. + + These are created using data chunks + parsed from the given command data. + """ + return CommandExecutor(_split_sections(multicmd)) diff --git a/cloudbaseinit/plugins/windows/userdatautils.py b/cloudbaseinit/plugins/windows/userdatautils.py index 233d95f8..230245a9 100644 --- a/cloudbaseinit/plugins/windows/userdatautils.py +++ b/cloudbaseinit/plugins/windows/userdatautils.py @@ -12,23 +12,26 @@ # License for the specific language governing permissions and limitations # under the License. + import functools import re from cloudbaseinit.openstack.common import log as logging from cloudbaseinit.plugins.common import execcmd + LOG = logging.getLogger(__name__) # Avoid 80+ length by using a local variable, which # is deleted afterwards. _compile = functools.partial(re.compile, flags=re.I) FORMATS = ( - (_compile(br'^rem cmd\s'), execcmd.Shell), - (_compile(br'^#!/usr/bin/env\spython\s'), execcmd.Python), + (_compile(br'^rem\s+cmd\s'), execcmd.Shell), + (_compile(br'^#!\s*/usr/bin/env\s+python\s'), execcmd.Python), (_compile(br'^#!'), execcmd.Bash), (_compile(br'^#(ps1|ps1_sysnative)\s'), execcmd.PowershellSysnative), (_compile(br'^#ps1_x86\s'), execcmd.Powershell), + (_compile(br''), execcmd.EC2Config), ) del _compile @@ -45,15 +48,14 @@ def execute_user_data_script(user_data): out = err = None command = _get_command(user_data) if not command: - # Unsupported LOG.warning('Unsupported user_data format') return ret_val try: out, err, ret_val = command() - except Exception as ex: + except Exception as exc: LOG.warning('An error occurred during user_data execution: \'%s\'', - ex) + exc) else: LOG.debug('User_data stdout:\n%s', out) LOG.debug('User_data stderr:\n%s', err) diff --git a/cloudbaseinit/tests/plugins/common/test_execcmd.py b/cloudbaseinit/tests/plugins/common/test_execcmd.py index 669180e3..409b2fc4 100644 --- a/cloudbaseinit/tests/plugins/common/test_execcmd.py +++ b/cloudbaseinit/tests/plugins/common/test_execcmd.py @@ -12,7 +12,9 @@ # License for the specific language governing permissions and limitations # under the License. + import os +import textwrap import unittest import mock @@ -22,6 +24,8 @@ from cloudbaseinit.tests import testutils def _remove_file(filepath): + if not filepath: + return try: os.remove(filepath) except OSError: @@ -29,7 +33,7 @@ def _remove_file(filepath): @mock.patch('cloudbaseinit.osutils.factory.get_os_utils') -class execcmdTest(unittest.TestCase): +class TestExecCmd(unittest.TestCase): def test_from_data(self, _): command = execcmd.BaseCommand.from_data(b"test") @@ -108,3 +112,51 @@ class execcmdTest(unittest.TestCase): command.execute() cleanup.assert_called_once_with() + + @mock.patch("cloudbaseinit.plugins.common.execcmd.PowershellSysnative") + @mock.patch("cloudbaseinit.plugins.common.execcmd.Shell") + def _test_process_ec2(self, mock_shell, mock_psnative, tag=None): + if tag: + content = textwrap.dedent(""" + <{0}>mocked + + <{0}>second + 1 + <{0}>third + + <{0}> # empty + <{0}>p1 + + + p2 + + """).encode() + + def ident(value): + ident_func = mock.MagicMock() + ident_func.return_value = (value, b"", 0) + return ident_func + + mock_shell.from_data = ident + mock_psnative.from_data = ident + + ec2conf = execcmd.EC2Config.from_data(content) + out, _, _ = ec2conf() + + if tag: + self.assertEqual(b"mocked\nsecond\nthird", out) + else: + self.assertEqual(b"s1\ns2\ns3\np1\np2", out) + + def test_process_ec2_script(self, _): + self._test_process_ec2(tag="script") + + def test_process_ec2_powershell(self, _): + self._test_process_ec2(tag="powershell") + + def test_process_ec2_order(self, _): + self._test_process_ec2() diff --git a/cloudbaseinit/tests/plugins/windows/test_userdatautils.py b/cloudbaseinit/tests/plugins/windows/test_userdatautils.py index 5cccee81..fa7df2cd 100644 --- a/cloudbaseinit/tests/plugins/windows/test_userdatautils.py +++ b/cloudbaseinit/tests/plugins/windows/test_userdatautils.py @@ -12,6 +12,7 @@ # License for the specific language governing permissions and limitations # under the License. + import os import unittest @@ -22,6 +23,8 @@ from cloudbaseinit.plugins.windows import userdatautils def _safe_remove(filepath): + if not filepath: + return try: os.remove(filepath) except OSError: @@ -38,7 +41,7 @@ class UserDataUtilsTest(unittest.TestCase): to remove the underlying target path of the command. """ command = userdatautils._get_command(data) - if command: + if command and not isinstance(command, execcmd.CommandExecutor): self.addCleanup(_safe_remove, command._target_path) return command @@ -58,6 +61,9 @@ class UserDataUtilsTest(unittest.TestCase): command = self._get_command(b'#ps1_x86\n') self.assertIsInstance(command, execcmd.Powershell) + command = self._get_command(b'') + self.assertIsInstance(command, execcmd.CommandExecutor) + command = self._get_command(b'unknown') self.assertIsNone(command) diff --git a/cloudbaseinit/utils/encoding.py b/cloudbaseinit/utils/encoding.py index 7a09e90d..b5585ecb 100644 --- a/cloudbaseinit/utils/encoding.py +++ b/cloudbaseinit/utils/encoding.py @@ -33,3 +33,13 @@ def write_file(target_path, data, mode='wb'): with open(target_path, mode) as f: f.write(data) + + +def read_file(target_path, mode='rb'): + with open(target_path, mode) as f: + data = f.read() + + if 'b' in mode: + data = data.decode() + + return data