# Copyright 2015 Cloudbase Solutions Srl
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

"""Tests for the cloud-config ``write_files`` plugin."""

import os
import sys
import tempfile
import textwrap
import unittest

try:
    import unittest.mock as mock
except ImportError:
    import mock

from cloudbaseinit import conf as cloudbaseinit_conf
from cloudbaseinit import exception
from cloudbaseinit.plugins.common.userdataplugins import cloudconfig
from cloudbaseinit.plugins.common.userdataplugins.cloudconfigplugins import (
    write_files
)
from cloudbaseinit.tests import testutils

CONF = cloudbaseinit_conf.CONF

# Logger names whose output the tests below capture with LogSnatcher.
_CLOUDCONFIG_LOGGER = ('cloudbaseinit.plugins.common.'
                       'userdataplugins.cloudconfig')
_WRITE_FILES_LOGGER = ('cloudbaseinit.plugins.common.'
                       'userdataplugins.cloudconfigplugins.'
                       'write_files')


def _create_tempfile():
    """Create an empty temporary file and return its path.

    The descriptor returned by ``tempfile.mkstemp`` is closed right away;
    the caller is responsible for removing the file.
    """
    fd, tmp = tempfile.mkstemp()
    os.close(fd)
    return tmp


class WriteFilesPluginTests(unittest.TestCase):
    """Exercise write_files directly and through CloudConfigPlugin."""

    @classmethod
    def setUpClass(cls):
        cls.plugin = cloudconfig.CloudConfigPlugin()

    def _get_tempfile(self):
        """Get a temporary file, usable by write_files plugin.

        The file is removed automatically during test cleanup.
        """
        tmp = _create_tempfile()
        self.addCleanup(os.remove, tmp)
        # In order to remove the file, we'll need to reset the permissions
        # set by write_file. Cleanups run in LIFO order, so this chmod is
        # executed before the os.remove registered above.
        self.addCleanup(os.chmod, tmp, 0o666)
        return tmp

    def test_decode_steps(self):
        pairs = [
            ('gz', [write_files.GZIP_MIME]),
            ('gzip', [write_files.GZIP_MIME]),
            ('b64', [write_files.BASE64_MIME]),
            ('base64', [write_files.BASE64_MIME]),
            ('gz+b64', [write_files.BASE64_MIME, write_files.GZIP_MIME]),
            ('gzip+b64', [write_files.BASE64_MIME, write_files.GZIP_MIME]),
            ('gz+base64', [write_files.BASE64_MIME, write_files.GZIP_MIME]),
            ('gzip+base64', [write_files.BASE64_MIME, write_files.GZIP_MIME]),
            ('fake', []),
            ('', []),
        ]
        for param, expected in pairs:
            self.assertEqual(expected, write_files._decode_steps(param))

    def test_process_permissions(self):
        for permissions in (0o644, '0644', '0o644', 420, 420.1):
            self.assertEqual(
                420, write_files._convert_permissions(permissions))

        with testutils.LogSnatcher(_WRITE_FILES_LOGGER) as snatcher:
            response = write_files._convert_permissions(mock.sentinel.invalid)

        expected_logging = [
            'Fail to process permissions %s, assuming 420' %
            mock.sentinel.invalid
        ]
        self.assertEqual(expected_logging, snatcher.output)
        self.assertEqual(write_files.DEFAULT_PERMISSIONS, response)

    @mock.patch('os.makedirs')
    def test_write_file(self, mock_makedirs):
        # Write into a managed temporary file instead of a fixed name in the
        # current directory: nothing is left behind if an assertion fails,
        # and a failed write is reported by the assertion itself rather than
        # being masked by an os.remove() of a missing file.
        path = self._get_tempfile()
        content = u'fake_content'

        result = write_files._write_file(path, content, open_mode="w")

        self.assertTrue(result)
        with open(path) as stream:
            self.assertEqual(content, stream.read())

    @mock.patch('os.makedirs')
    def test_write_file_excp(self, mock_makedirs):
        mock_makedirs.side_effect = OSError
        result = write_files._write_file(u'fake_path', u'fake_content')
        self.assertFalse(result)

    def test_write_file_list(self):
        expected_logging = [
            "Plugin 'invalid' is currently not supported",
        ]
        code = textwrap.dedent("""
        write_files:
          - encoding: b64
            content: NDI=
            path: {}
            permissions: '0o466'
        invalid:
          - stuff: 1
        """)
        self._test_write_file(code, expected_logging)

    def test_write_file_dict(self):
        code = textwrap.dedent("""
        write_files:
            encoding: b64
            content: NDI=
            path: {}
            permissions: '0o466'
        """)
        self._test_write_file(code)

    def _test_write_file(self, code, expected_logging=None):
        """Process ``code`` through the plugin and check the written file.

        :param code: cloud-config template with one ``{}`` placeholder for
            the target path. It must write base64 ``NDI=`` (i.e. ``'42'``)
            with permissions ``'0o466'``.
        :param expected_logging: if not None, the exact list of messages
            expected from the cloudconfig logger.
        """
        tmp = self._get_tempfile()
        code = code.format(tmp)
        with testutils.LogSnatcher(_CLOUDCONFIG_LOGGER) as snatcher:
            self.plugin.process_non_multipart(code)

        self.assertTrue(os.path.exists(tmp), "Expected path does not exist.")
        with open(tmp) as stream:
            self.assertEqual('42', stream.read())
        if expected_logging is not None:
            self.assertEqual(expected_logging, snatcher.output)

        # Test that the proper permissions were set. On Windows,
        # only the read bit is processed, the rest are ignored.
        permission = os.stat(tmp).st_mode & 0o777
        expected_permission = 0o444 if sys.platform == 'win32' else 0o466
        self.assertEqual(expected_permission, permission)

    def test_missing_required_keys(self):
        code = textwrap.dedent("""
        write_files:
          - c0ntent: NDI=
        """)
        expected_return = [
            "Missing required keys from file "
            "information {'c0ntent': 'NDI='}"
        ]

        with testutils.LogSnatcher(_WRITE_FILES_LOGGER) as snatcher:
            self.plugin.process_non_multipart(code)

        self.assertEqual(expected_return, snatcher.output)

    @mock.patch('cloudbaseinit.plugins.common.userdataplugins.'
                'cloudconfigplugins.write_files.WriteFilesPlugin.process')
    def test_processing_plugin_failed(self, mock_write_files):
        mock_write_files.side_effect = ValueError
        code = textwrap.dedent("""
        write_files:
          - content: NDI=
            path: random_cloudbaseinit_test
        """)

        with testutils.LogSnatcher(_CLOUDCONFIG_LOGGER) as snatcher:
            self.plugin.process_non_multipart(code)

        self.assertTrue(snatcher.output[0].startswith(
            "Processing plugin write_files failed"))
        self.assertTrue(snatcher.output[0].endswith("ValueError"))
        self.assertFalse(os.path.exists('random_cloudbaseinit_test'))

    def test_wrong_gzip_content(self):
        tmp = self._get_tempfile()
        code = textwrap.dedent("""
        write_files:
          - content: lala
            encoding: gz
            path: {}
        """.format(tmp))

        with testutils.LogSnatcher(_WRITE_FILES_LOGGER) as snatcher:
            self.plugin.process_non_multipart(code)

        self.assertTrue(snatcher.output[0].startswith(
            "Fail to decompress gzip content"))

    def test_wrong_b64_content(self):
        tmp = self._get_tempfile()
        code = textwrap.dedent("""
        write_files:
          - content: l
            encoding: b64
            path: {}
        """.format(tmp))

        with testutils.LogSnatcher(_WRITE_FILES_LOGGER) as snatcher:
            self.plugin.process_non_multipart(code)

        self.assertTrue(snatcher.output[0].startswith(
            "Fail to decode base64 content."))

    def test_unknown_encoding(self):
        tmp = self._get_tempfile()
        code = textwrap.dedent("""
        write_files:
          - content: NDI=
            path: {}
            encoding: unknown_encoding
            permissions: '0o466'
        """.format(tmp))

        with testutils.LogSnatcher(_WRITE_FILES_LOGGER) as snatcher:
            self.plugin.process_non_multipart(code)

        self.assertTrue(os.path.exists(tmp), "Expected path does not exist.")
        # An unknown encoding falls back to writing the content verbatim.
        with open(tmp) as stream:
            self.assertEqual('NDI=', stream.read())
        self.assertEqual(["Unknown encoding, assuming plain text."],
                         snatcher.output)

    def test_missing_encoding(self):
        tmp = self._get_tempfile()
        code = textwrap.dedent("""
        write_files:
          - content: NDI=
            path: {}
            permissions: '0o466'
        """.format(tmp))

        with testutils.LogSnatcher(_WRITE_FILES_LOGGER) as snatcher:
            self.plugin.process_non_multipart(code)

        self.assertTrue(os.path.exists(tmp), "Expected path does not exist.")
        # Without an encoding the content is written verbatim, silently.
        with open(tmp) as stream:
            self.assertEqual('NDI=', stream.read())
        self.assertEqual([], snatcher.output)

    def test_invalid_object_passed(self):
        with self.assertRaises(exception.CloudbaseInitException) as cm:
            write_files.WriteFilesPlugin().process(1)

        expected = "Can't process the type of data %r" % type(1)
        self.assertEqual(expected, str(cm.exception))