add some unit tests, fix things found by doing so

This commit is contained in:
Scott Moser 2013-07-09 01:35:49 -04:00
parent 62b20b4b0a
commit 9911c8f5cd
2 changed files with 188 additions and 6 deletions

View File

@ -18,6 +18,7 @@
import base64
import os
import os.path
from xml.dom import minidom
from cloudinit import log as logging
@ -29,7 +30,7 @@ LOG = logging.getLogger(__name__)
DS_NAME = 'Azure'
DEFAULT_METADATA = {"instance-id": "iid-AZURE-NODE"}
AGENT_START = ['service', 'walinuxagent', 'start']
DEFAULT_DS_CONFIG = {'datasource': {DS_NAME: {'agent_command': AGENT_START}}}
BUILTIN_DS_CONFIG = {'datasource': {DS_NAME: {'agent_command': AGENT_START}}}
class DataSourceAzureNet(sources.DataSource):
@ -56,16 +57,17 @@ class DataSourceAzureNet(sources.DataSource):
ret = load_azure_ds_dir(cdev)
except NonAzureDataSource:
pass
continue
except BrokenAzureDataSource as exc:
raise exc
except util.MountFailedError:
LOG.warn("%s was not mountable" % cdev)
continue
(md, self.userdata_raw, cfg) = ret
self.seed = cdev
self.metadata = util.mergemanydict([md, DEFAULT_METADATA])
self.cfg = util.mergemanydict([cfg, DEFAULT_DS_CONFIG])
self.cfg = cfg
found = cdev
LOG.debug("found datasource in %s", cdev)
@ -76,17 +78,25 @@ class DataSourceAzureNet(sources.DataSource):
path = ['datasource', DS_NAME, 'agent_command']
cmd = None
for cfg in (self.cfg, self.sys_cfg):
for cfg in (self.cfg, self.sys_cfg, BUILTIN_DS_CONFIG):
cmd = util.get_cfg_by_path(cfg, keyp=path)
if cmd is not None:
break
invoke_agent(cmd)
try:
invoke_agent(cmd)
except util.ProcessExecutionError:
# claim the datasource even if the command failed
util.logexc(LOG, "agent command '%s' failed.", cmd)
return True
def get_config_obj(self):
return self.cfg
def invoke_agent(cmd):
# this is a function itself to simplify patching it for test
if cmd:
LOG.debug("invoking agent: %s" % cmd)
util.subp(cmd, shell=(not isinstance(cmd, list)))
@ -105,7 +115,11 @@ def find_child(node, filter_func):
def read_azure_ovf(contents):
dom = minidom.parseString(contents)
try:
dom = minidom.parseString(contents)
except Exception as e:
raise NonAzureDataSource("invalid xml: %s" % e)
results = find_child(dom.documentElement,
lambda n: n.localName == "ProvisioningSection")

View File

@ -0,0 +1,168 @@
from cloudinit import helpers
from cloudinit.sources import DataSourceAzure
from tests.unittests.helpers import populate_dir
import base64
from mocker import MockerTestCase
import os
import yaml
def construct_valid_ovf_env(data=None, pubkeys=None, userdata=None):
if data is None:
data = {'HostName': 'FOOHOST'}
if pubkeys is None:
pubkeys = {}
content = """<?xml version="1.0" encoding="utf-8"?>
<Environment xmlns="http://schemas.dmtf.org/ovf/environment/1"
xmlns:oe="http://schemas.dmtf.org/ovf/environment/1"
xmlns:wa="http://schemas.microsoft.com/windowsazure"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance">
<wa:ProvisioningSection><wa:Version>1.0</wa:Version>
<LinuxProvisioningConfigurationSet
xmlns="http://schemas.microsoft.com/windowsazure"
xmlns:i="http://www.w3.org/2001/XMLSchema-instance">
<ConfigurationSetType>LinuxProvisioningConfiguration</ConfigurationSetType>
"""
for key, val in data.items():
content += "<%s>%s</%s>\n" % (key, val, key)
if userdata:
content += "<UserData>%s</UserData>\n" % (base64.b64encode(userdata))
if pubkeys:
content += "<SSH><PublicKeys>\n"
for fp, path in pubkeys.items():
content += " <PublicKey>"
content += ("<Fingerprint>%s</Fingerprint><Path>%s</Path>" %
(fp, path))
content += " </PublicKey>"
content += """
</LinuxProvisioningConfigurationSet>
</wa:ProvisioningSection>
<wa:PlatformSettingsSection><wa:Version>1.0</wa:Version>
<PlatformSettings xmlns="http://schemas.microsoft.com/windowsazure"
xmlns:i="http://www.w3.org/2001/XMLSchema-instance">
<KmsServerHostname>kms.core.windows.net</KmsServerHostname>
<ProvisionGuestAgent>false</ProvisionGuestAgent>
<GuestAgentPackageName i:nil="true" />
</PlatformSettings></wa:PlatformSettingsSection>
</Environment>
"""
return content
class TestAzureDataSource(MockerTestCase):
def setUp(self):
# makeDir comes from MockerTestCase
self.tmp = self.makeDir()
# patch cloud_dir, so our 'seed_dir' is guaranteed empty
self.paths = helpers.Paths({'cloud_dir': self.tmp})
self.unapply = []
super(TestAzureDataSource, self).setUp()
def tearDown(self):
apply_patches([i for i in reversed(self.unapply)])
super(TestAzureDataSource, self).tearDown()
def apply_patches(self, patches):
ret = apply_patches(patches)
self.unapply += ret
def _get_ds(self, data):
def dsdevs():
return data.get('dsdevs', [])
def invoker(cmd):
data['agent_invoked'] = cmd
if data.get('ovfcontent') is not None:
populate_dir(os.path.join(self.paths.seed_dir, "azure"),
{'ovf-env.xml': data['ovfcontent']})
mod = DataSourceAzure
if data.get('dsdevs'):
self.apply_patches([(mod, 'list_possible_azure_ds_devs', dsdevs)])
self.apply_patches([(mod, 'invoke_agent', invoker)])
dsrc = mod.DataSourceAzureNet(
data.get('sys_cfg', {}), distro=None, paths=self.paths)
return dsrc
def test_basic_seed_dir(self):
odata = {'HostName': "myhost", 'UserName': "myuser"}
data = {'ovfcontent': construct_valid_ovf_env(data=odata),
'sys_cfg': {}}
dsrc = self._get_ds(data)
ret = dsrc.get_data()
self.assertTrue(ret)
self.assertEqual(dsrc.userdata_raw, "")
self.assertEqual(dsrc.metadata['local-hostname'], odata['HostName'])
def test_user_cfg_set_agent_command(self):
cfg = {'agent_command': "my_command"}
odata = {'HostName': "myhost", 'UserName': "myuser",
'dscfg': yaml.dump(cfg)}
data = {'ovfcontent': construct_valid_ovf_env(data=odata)}
dsrc = self._get_ds(data)
ret = dsrc.get_data()
self.assertTrue(ret)
self.assertEqual(data['agent_invoked'], cfg['agent_command'])
def test_sys_cfg_set_agent_command(self):
sys_cfg = {'datasource': {'Azure': {'agent_command': '_COMMAND'}}}
data = {'ovfcontent': construct_valid_ovf_env(data={}),
'sys_cfg': sys_cfg}
dsrc = self._get_ds(data)
ret = dsrc.get_data()
self.assertTrue(ret)
self.assertEqual(data['agent_invoked'], '_COMMAND')
def test_userdata_found(self):
mydata = "FOOBAR"
odata = {'UserData': base64.b64encode(mydata)}
data = {'ovfcontent': construct_valid_ovf_env(data=odata)}
dsrc = self._get_ds(data)
ret = dsrc.get_data()
self.assertTrue(ret)
self.assertEqual(dsrc.userdata_raw, mydata)
def test_no_datasource_expected(self):
#no source should be found if no seed_dir and no devs
data = {}
dsrc = self._get_ds({})
ret = dsrc.get_data()
self.assertFalse(ret)
self.assertFalse('agent_invoked' in data)
class TestReadAzureOvf(MockerTestCase):
def test_invalid_xml_raises_non_azure_ds(self):
invalid_xml = "<foo>" + construct_valid_ovf_env(data={})
self.assertRaises(DataSourceAzure.NonAzureDataSource,
DataSourceAzure.read_azure_ovf, invalid_xml)
def apply_patches(patches):
ret = []
for (ref, name, replace) in patches:
if replace is None:
continue
orig = getattr(ref, name)
setattr(ref, name, replace)
ret.append((ref, name, orig))
return ret