Improve MIME multipart userdata handling

The 'exec_file' will now try to check for the headers inside the file, in
case that it doesn't recognize the format or no format has been provided.
Also the processing part of the user data now checks if the 'Content-Type' is
in the file instead of checking if the file starts with the header, in
order to comply with RFC2045.

Change-Id: I53fda9f5c17f35cb35d93a86434ecc4c7c579802
Closes-Bug: #1623393
Closes-Bug: #1672222
This commit is contained in:
Stefan Caraiman 2016-09-21 09:47:33 +03:00
parent 122a5c58bc
commit 8e5fff7016
No known key found for this signature in database
GPG Key ID: F61A336E829641BA
6 changed files with 113 additions and 74 deletions

View File

@ -12,35 +12,23 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
import os
from oslo_log import log as oslo_logging from oslo_log import log as oslo_logging
from cloudbaseinit.plugins.common import execcmd from cloudbaseinit.plugins.common import userdatautils
LOG = oslo_logging.getLogger(__name__) LOG = oslo_logging.getLogger(__name__)
FORMATS = {
"cmd": execcmd.Shell,
"exe": execcmd.Shell,
"sh": execcmd.Bash,
"py": execcmd.Python,
"ps1": execcmd.PowershellSysnative,
}
def exec_file(file_path): def exec_file(file_path):
ret_val = 0 ret_val = 0
ext = os.path.splitext(file_path)[1][1:].lower() command = userdatautils.get_command_from_path(file_path)
command = FORMATS.get(ext)
if not command: if not command:
# Unsupported # File format not provided or not recognized
LOG.warning('Unsupported script file type: %s', ext) LOG.debug('No valid extension or header found in the '
'userdata: %s' % file_path)
return ret_val return ret_val
try: try:
out, err, ret_val = command(file_path).execute() out, err, ret_val = command.execute()
except Exception as ex: except Exception as ex:
LOG.warning('An error occurred during file execution: \'%s\'', ex) LOG.warning('An error occurred during file execution: \'%s\'', ex)
else: else:

View File

@ -83,11 +83,29 @@ class UserDataPlugin(base.BasePlugin):
LOG.debug('User data content:\n%s', user_data_str) LOG.debug('User data content:\n%s', user_data_str)
return email.message_from_string(user_data_str).walk() return email.message_from_string(user_data_str).walk()
@staticmethod
def _get_headers(user_data):
"""Returns the header of the given user data.
:param user_data: Represents the content of the user data.
:rtype: A string chunk containing the header or None.
.. note :: In case the content type is not valid,
None will be returned.
"""
content = encoding.get_as_string(user_data)
if content:
return content.split("\n\n")[0]
else:
raise exception.CloudbaseInitException("No header could be found."
"The user data content is "
"either invalid or empty.")
def _process_user_data(self, user_data): def _process_user_data(self, user_data):
plugin_status = base.PLUGIN_EXECUTION_DONE plugin_status = base.PLUGIN_EXECUTION_DONE
reboot = False reboot = False
LOG.debug("Processing userdata") headers = self._get_headers(user_data)
if user_data.startswith(b'Content-Type: multipart'): if 'Content-Type: multipart' in headers:
LOG.debug("Processing userdata")
user_data_plugins = factory.load_plugins() user_data_plugins = factory.load_plugins()
user_handlers = {} user_handlers = {}

View File

@ -12,8 +12,8 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
import collections
import functools import os
import re import re
from oslo_log import log as oslo_logging from oslo_log import log as oslo_logging
@ -23,37 +23,69 @@ from cloudbaseinit.plugins.common import execcmd
LOG = oslo_logging.getLogger(__name__) LOG = oslo_logging.getLogger(__name__)
# Avoid 80+ length by using a local variable, which _Script = collections.namedtuple('Script', ['extension', 'script_type',
# is deleted afterwards. 'executor'])
_compile = functools.partial(re.compile, flags=re.I) _SCRIPTS = (
FORMATS = ( _Script(extension='cmd', executor=execcmd.Shell,
(_compile(br'^rem\s+cmd\s'), execcmd.Shell), script_type=re.compile(br'^rem\s+cmd\s')),
(_compile(br'^#!\s*/usr/bin/env\s+python\s'), execcmd.Python), _Script(script_type=re.compile(br'^#!\s*/usr/bin/env\s+python\s'),
(_compile(br'^#!'), execcmd.Bash), extension='py', executor=execcmd.Python),
(_compile(br'^#(ps1|ps1_sysnative)\s'), execcmd.PowershellSysnative), _Script(extension='exe', script_type=None, executor=execcmd.Shell),
(_compile(br'^#ps1_x86\s'), execcmd.Powershell), _Script(extension='sh', script_type=re.compile(br'^#!'),
(_compile(br'</?(script|powershell)>'), execcmd.EC2Config), executor=execcmd.Bash),
) _Script(extension='ps1', executor=execcmd.PowershellSysnative,
del _compile script_type=re.compile(br'^#(ps1|ps1_sysnative)\s')),
_Script(extension=None, executor=execcmd.Powershell,
script_type=re.compile(br'^#ps1_x86\s')),
_Script(extension=None, executor=execcmd.EC2Config,
script_type=re.compile(br'</?(script|powershell)>')))
def _get_command(data): def _get_command(data, is_path=False):
# Get the command which should process the given data. """Returns a specific command executor if the data type is found.
for pattern, command_class in FORMATS:
if pattern.search(data): :param data: It can be either a file or content of user_data type.
return command_class.from_data(data) :param is_path: Determines whether :data: is a file path or it
contains the user_data content.
:rtype: An `execcmd` command type or `None`.
.. note :: In case the data doesn't have a valid extension or
header, it will return `None`.
"""
if is_path:
extension = os.path.splitext(data)[1][1:].lower()
for script in _SCRIPTS:
if extension == script.extension:
return script.executor(data)
with open(data, 'rb') as file_handler:
file_handler.seek(0)
user_data = file_handler.read()
else:
user_data = data
for script in _SCRIPTS:
if script.script_type and script.script_type.search(user_data):
return script.executor.from_data(user_data)
return None
def get_command(data):
return _get_command(data)
def get_command_from_path(path):
return _get_command(path, is_path=True)
def execute_user_data_script(user_data): def execute_user_data_script(user_data):
ret_val = 0 ret_val = 0
out = err = None out = err = None
command = _get_command(user_data) command = get_command(user_data)
if not command: if not command:
LOG.warning('Unsupported user_data format') LOG.warning('Unsupported user_data format')
return ret_val return ret_val
try: try:
out, err, ret_val = command() out, err, ret_val = command.execute()
except Exception as exc: except Exception as exc:
LOG.warning('An error occurred during user_data execution: \'%s\'', LOG.warning('An error occurred during user_data execution: \'%s\'',
exc) exc)

View File

@ -19,7 +19,6 @@ try:
except ImportError: except ImportError:
import mock import mock
from cloudbaseinit.plugins.common import execcmd
from cloudbaseinit.plugins.common import fileexecutils from cloudbaseinit.plugins.common import fileexecutils
from cloudbaseinit.tests import testutils from cloudbaseinit.tests import testutils
@ -27,27 +26,25 @@ from cloudbaseinit.tests import testutils
@mock.patch('cloudbaseinit.osutils.factory.get_os_utils') @mock.patch('cloudbaseinit.osutils.factory.get_os_utils')
class TestFileExecutilsPlugin(unittest.TestCase): class TestFileExecutilsPlugin(unittest.TestCase):
def test_exec_file_no_executor(self, _): @mock.patch('cloudbaseinit.plugins.common.userdatautils.'
with testutils.LogSnatcher('cloudbaseinit.plugins.common.' 'get_command_from_path')
'fileexecutils') as snatcher: @mock.patch('cloudbaseinit.plugins.common.userdatautils.'
retval = fileexecutils.exec_file("fake.fake") 'execute_user_data_script')
def test_exec_file_no_executor(self, mock_execute_user_data_script,
mock_get_command, _):
mock_get_command.return_value = None
with testutils.create_tempfile() as temp:
with mock.patch('cloudbaseinit.plugins.common.userdatautils'
'.open', create=True):
with testutils.LogSnatcher('cloudbaseinit.plugins.common.'
'fileexecutils') as snatcher:
retval = fileexecutils.exec_file(temp)
expected_logging = ['Unsupported script file type: fake'] expected_logging = ['No valid extension or header found'
' in the userdata: %s' % temp]
self.assertEqual(0, retval) self.assertEqual(0, retval)
self.assertEqual(expected_logging, snatcher.output) self.assertEqual(expected_logging, snatcher.output)
def test_executors_mapping(self, _):
self.assertEqual(fileexecutils.FORMATS["cmd"],
execcmd.Shell)
self.assertEqual(fileexecutils.FORMATS["exe"],
execcmd.Shell)
self.assertEqual(fileexecutils.FORMATS["sh"],
execcmd.Bash)
self.assertEqual(fileexecutils.FORMATS["py"],
execcmd.Python)
self.assertEqual(fileexecutils.FORMATS["ps1"],
execcmd.PowershellSysnative)
@mock.patch('cloudbaseinit.plugins.common.execcmd.' @mock.patch('cloudbaseinit.plugins.common.execcmd.'
'BaseCommand.execute') 'BaseCommand.execute')
def test_exec_file_fails(self, mock_execute, _): def test_exec_file_fails(self, mock_execute, _):
@ -67,12 +64,7 @@ class TestFileExecutilsPlugin(unittest.TestCase):
@mock.patch('cloudbaseinit.plugins.common.execcmd.' @mock.patch('cloudbaseinit.plugins.common.execcmd.'
'BaseCommand.execute') 'BaseCommand.execute')
def test_exec_file_(self, mock_execute, _): def test_exec_file_(self, mock_execute, _):
mock_execute.return_value = ( mock_execute.return_value = (mock.sentinel.out, mock.sentinel.error, 0)
mock.sentinel.out,
mock.sentinel.error,
0,
)
retval = fileexecutils.exec_file("fake.py") retval = fileexecutils.exec_file("fake.py")
mock_execute.assert_called_once_with() mock_execute.assert_called_once_with()
self.assertEqual(0, retval) self.assertEqual(0, retval)

View File

@ -138,6 +138,13 @@ class UserDataPluginTest(unittest.TestCase):
self.assertEqual(response, mock_message_from_string().walk()) self.assertEqual(response, mock_message_from_string().walk())
self.assertEqual(expected_logging, snatcher.output) self.assertEqual(expected_logging, snatcher.output)
def test_get_header(self):
fake_data = "fake-user-data"
self.assertEqual(fake_data, self._userdata._get_headers(fake_data))
fake_data = None
with self.assertRaises(exception.CloudbaseInitException):
self._userdata._get_headers(fake_data)
@mock.patch('cloudbaseinit.plugins.common.userdataplugins.factory.' @mock.patch('cloudbaseinit.plugins.common.userdataplugins.factory.'
'load_plugins') 'load_plugins')
@mock.patch('cloudbaseinit.plugins.common.userdata.UserDataPlugin' @mock.patch('cloudbaseinit.plugins.common.userdata.UserDataPlugin'

View File

@ -44,12 +44,12 @@ class UserDataUtilsTest(unittest.TestCase):
If a command was obtained, then a cleanup will be added in order If a command was obtained, then a cleanup will be added in order
to remove the underlying target path of the command. to remove the underlying target path of the command.
""" """
command = userdatautils._get_command(data) command = userdatautils.get_command(data)
if command and not isinstance(command, execcmd.CommandExecutor): if command and not isinstance(command, execcmd.CommandExecutor):
self.addCleanup(_safe_remove, command._target_path) self.addCleanup(_safe_remove, command._target_path)
return command return command
def test__get_command(self, _): def test_get_command(self, _):
command = self._get_command(b'rem cmd test') command = self._get_command(b'rem cmd test')
self.assertIsInstance(command, execcmd.Shell) self.assertIsInstance(command, execcmd.Shell)
@ -83,10 +83,11 @@ class UserDataUtilsTest(unittest.TestCase):
self.assertEqual(expected_logging, snatcher.output) self.assertEqual(expected_logging, snatcher.output)
@mock.patch('cloudbaseinit.plugins.common.userdatautils.' @mock.patch('cloudbaseinit.plugins.common.userdatautils.'
'_get_command') 'get_command')
def test_execute_user_data_script_fails(self, mock_get_command, _): def test_execute_user_data_script_fails(self, mock_get_command, _):
mock_get_command.return_value.side_effect = ValueError mock_command = mock.Mock()
mock_command.execute.side_effect = ValueError
mock_get_command.return_value = mock_command
with testutils.LogSnatcher('cloudbaseinit.plugins.common.' with testutils.LogSnatcher('cloudbaseinit.plugins.common.'
'userdatautils') as snatcher: 'userdatautils') as snatcher:
retval = userdatautils.execute_user_data_script( retval = userdatautils.execute_user_data_script(
@ -94,17 +95,18 @@ class UserDataUtilsTest(unittest.TestCase):
expected_logging = [ expected_logging = [
"An error occurred during user_data execution: ''", "An error occurred during user_data execution: ''",
'User_data script ended with return code: 0' 'User_data script ended with return code: 0']
]
self.assertEqual(0, retval) self.assertEqual(0, retval)
self.assertEqual(expected_logging, snatcher.output) self.assertEqual(expected_logging, snatcher.output)
@mock.patch('cloudbaseinit.plugins.common.userdatautils.' @mock.patch('cloudbaseinit.plugins.common.userdatautils.'
'_get_command') 'get_command')
def test_execute_user_data_script(self, mock_get_command, _): def test_execute_user_data_script(self, mock_get_command, _):
mock_get_command.return_value.return_value = ( mock_command = mock.Mock()
mock_command.execute.return_value = (
mock.sentinel.output, mock.sentinel.error, -1 mock.sentinel.output, mock.sentinel.error, -1
) )
mock_get_command.return_value = mock_command
retval = userdatautils.execute_user_data_script( retval = userdatautils.execute_user_data_script(
mock.sentinel.user_data) mock.sentinel.user_data)
self.assertEqual(-1, retval) self.assertEqual(-1, retval)