Add the data source base classes and the HTTP OpenStack implementation

Change-Id: Ia5fdbe6cbd6364f8ec671fec413fdd37741a2598
This commit is contained in:
Claudiu Popa 2015-06-04 12:42:16 +03:00
parent 765fe3c1f7
commit 6a61cf95f9
13 changed files with 981 additions and 5 deletions

View File

@ -52,8 +52,8 @@ class General(general.General):
def reboot(self): def reboot(self):
raise NotImplementedError raise NotImplementedError
def set_locale(self): def set_locale(self, locale):
raise NotImplementedError raise NotImplementedError
def set_timezone(self): def set_timezone(self, timezone):
raise NotImplementedError raise NotImplementedError

View File

@ -11,11 +11,15 @@ from ctypes import wintypes
import logging import logging
import subprocess import subprocess
from six.moves import urllib_parse
from cloudinit import exceptions from cloudinit import exceptions
from cloudinit.osys import base
from cloudinit.osys import network from cloudinit.osys import network
from cloudinit.osys.windows.util import iphlpapi from cloudinit.osys.windows.util import iphlpapi
from cloudinit.osys.windows.util import kernel32 from cloudinit.osys.windows.util import kernel32
from cloudinit.osys.windows.util import ws2_32 from cloudinit.osys.windows.util import ws2_32
from cloudinit import url_helper
MIB_IPPROTO_NETMGMT = 3 MIB_IPPROTO_NETMGMT = 3
@ -26,7 +30,8 @@ _PROTOCOL_TCP = "TCP"
_PROTOCOL_UDP = "UDP" _PROTOCOL_UDP = "UDP"
_ERROR_FILE_NOT_FOUND = 2 _ERROR_FILE_NOT_FOUND = 2
_ComputerNamePhysicalDnsHostname = 5 _ComputerNamePhysicalDnsHostname = 5
LOG = logging.getLogger(__file__) _MAX_URL_CHECK_RETRIES = 3
LOG = logging.getLogger(__name__)
def _heap_alloc(heap, size): def _heap_alloc(heap, size):
@ -37,6 +42,15 @@ def _heap_alloc(heap, size):
return table_mem return table_mem
def _check_url(url, retries_count=_MAX_URL_CHECK_RETRIES):
LOG.debug("Testing url: %s", url)
try:
url_helper.read_url(url, retries=retries_count)
return True
except url_helper.UrlError:
return False
class Network(network.Network): class Network(network.Network):
"""Network namespace object tailored for the Windows platform.""" """Network namespace object tailored for the Windows platform."""
@ -110,6 +124,42 @@ class Network(network.Network):
return next((r for r in self.routes() if r.destination == '0.0.0.0'), return next((r for r in self.routes() if r.destination == '0.0.0.0'),
None) None)
def set_metadata_ip_route(self, metadata_url):
"""Set a network route if the given metadata url can't be accessed.
This is a workaround for
https://bugs.launchpad.net/quantum/+bug/1174657.
"""
osutils = base.get_osutils()
if osutils.general.check_os_version(6, 0):
# 169.254.x.x addresses are not getting routed starting from
# Windows Vista / 2008
metadata_netloc = urllib_parse.urlparse(metadata_url).netloc
metadata_host = metadata_netloc.split(':')[0]
if not metadata_host.startswith("169.254."):
return
routes = self.routes()
exists_route = any(route.destination == metadata_host
for route in routes)
if not exists_route and not _check_url(metadata_url):
default_gateway = self.default_gateway()
if default_gateway:
try:
LOG.debug('Setting gateway for host: %s',
metadata_host)
route = Route(
destination=metadata_host,
netmask="255.255.255.255",
gateway=default_gateway.gateway,
interface=None, metric=None)
Route.add(route)
except Exception as ex:
# Ignore it
LOG.exception(ex)
# These are not required by the Windows version for now, # These are not required by the Windows version for now,
# but we provide them as noop version. # but we provide them as noop version.
def hosts(self): def hosts(self):

97
cloudinit/sources/base.py Normal file
View File

@ -0,0 +1,97 @@
# Copyright 2015 Canonical Ltd.
# This file is part of cloud-init. See LICENCE file for license information.
#
# vi: ts=4 expandtab
import abc
import six
class APIResponse(object):
"""Holds API response content
To access the content in the binary format, use the
`buffer` attribute, while the unicode content can be
accessed by calling `str` over this.
"""
def __init__(self, buffer, encoding="utf-8"):
self.buffer = buffer
self._encoding = encoding
def __str__(self):
return self.buffer.decode(self._encoding)
@six.add_metaclass(abc.ABCMeta)
class BaseDataSource(object):
"""Base class for the data sources."""
datasource_config = {}
def __init__(self, config=None):
self._cache = {}
# TODO(cpopa): merge them instead.
self._config = config or self.datasource_config
def _get_cache_data(self, path):
"""Do a metadata lookup for the given *path*
This will return the available metadata under *path*,
while caching the result, so that a next call will not do
an additional API call.
"""
if path not in self._cache:
self._cache[path] = self._get_data(path)
return self._cache[path]
@abc.abstractmethod
def load(self):
"""Try to load this metadata service.
This should return ``True`` if the service was loaded properly,
``False`` otherwise.
"""
@abc.abstractmethod
def _get_data(self, path):
"""Retrieve the metadata exported under the `path` key.
This should return an instance of :class:`APIResponse`.
"""
def instance_id(self):
"""Get this instance's id."""
def user_data(self):
"""Get the user data available for this instance."""
def vendor_data(self):
"""Get the vendor data available for this instance."""
def host_name(self):
"""Get the hostname available for this instance."""
def public_keys(self):
"""Get the public keys available for this instance."""
def network_config(self):
"""Get the specified network config, if any."""
def admin_password(self):
"""Get the admin password."""
def post_password(self, password):
"""Post the password to the metadata service."""
def can_update_password(self):
"""Check if this data source can update the admin password."""
def is_password_changed(self):
"""Check if the data source has a new password for this instance."""
return False
def is_password_set(self):
"""Check if the password was already posted to the metadata service."""

View File

View File

@ -0,0 +1,112 @@
# Copyright 2015 Canonical Ltd.
# This file is part of cloud-init. See LICENCE file for license information.
#
# vi: ts=4 expandtab
"""Base classes for interacting with OpenStack data sources."""
import abc
import json
import logging
import os
import six
from cloudinit.sources import base
__all__ = ('BaseOpenStackSource', )
_PAYLOAD_KEY = "content_path"
_ADMIN_PASSWORD = "admin_pass"
LOG = logging.getLogger(__name__)
_OS_LATEST = 'latest'
_OS_FOLSOM = '2012-08-10'
_OS_GRIZZLY = '2013-04-04'
_OS_HAVANA = '2013-10-17'
# Keep this in chronological order. New supported versions go at the end.
_OS_VERSIONS = (
_OS_FOLSOM,
_OS_GRIZZLY,
_OS_HAVANA,
)
@six.add_metaclass(abc.ABCMeta)
class BaseOpenStackSource(base.BaseDataSource):
"""Base classes for interacting with an OpenStack data source.
This is useful for both the HTTP data source, as well for
ConfigDrive.
"""
def __init__(self):
super(BaseOpenStackSource, self).__init__()
self._version = None
@abc.abstractmethod
def _available_versions(self):
"""Get the available metadata versions."""
@abc.abstractmethod
def _path_join(self, path, *addons):
"""Join one or more components together."""
def _working_version(self):
versions = self._available_versions()
# OS_VERSIONS is stored in chronological order, so
# reverse it to check newest first.
supported = reversed(_OS_VERSIONS)
selected_version = next((version for version in supported
if version in versions), _OS_LATEST)
LOG.debug("Selected version %r from %s", selected_version, versions)
return selected_version
def _get_content(self, name):
path = self._path_join('openstack', 'content', name)
return self._get_cache_data(path)
def _get_meta_data(self):
path = self._path_join('openstack', self._version, 'meta_data.json')
data = self._get_cache_data(path)
if data:
return json.loads(str(data))
def load(self):
self._version = self._working_version()
super(BaseOpenStackSource, self).load()
def user_data(self):
path = self._path_join('openstack', self._version, 'user_data')
return self._get_cache_data(path).buffer
def vendor_data(self):
path = self._path_join('openstack', self._version, 'vendor_data.json')
return self._get_cache_data(path).buffer
def instance_id(self):
return self._get_meta_data().get('uuid')
def host_name(self):
return self._get_meta_data().get('hostname')
def public_keys(self):
public_keys = self._get_meta_data().get('public_keys')
if public_keys:
return list(public_keys.values())
return []
def network_config(self):
network_config = self._get_meta_data().get('network_config')
if not network_config:
return None
if _PAYLOAD_KEY not in network_config:
return None
content_path = network_config[_PAYLOAD_KEY]
content_name = os.path.basename(content_path)
return str(self._get_content(content_name))
def admin_password(self):
meta_data = self._get_meta_data()
meta = meta_data.get('meta', {})
return meta.get(_ADMIN_PASSWORD) or meta_data.get(_ADMIN_PASSWORD)

View File

@ -0,0 +1,127 @@
# Copyright 2015 Canonical Ltd.
# This file is part of cloud-init. See LICENCE file for license information.
#
# vi: ts=4 expandtab
import logging
import os
import posixpath
import re
from cloudinit import exceptions
from cloudinit.osys import base
from cloudinit.sources import base as base_source
from cloudinit.sources.openstack import base as baseopenstack
from cloudinit import url_helper
LOG = logging.getLogger(__name__)
IS_WINDOWS = os.name == 'nt'
# Not necessarily the same as using datetime.strftime,
# but should be enough for our use case.
VERSION_REGEX = re.compile('^\d{4}-\d{2}-\d{2}$')
class HttpOpenStackSource(baseopenstack.BaseOpenStackSource):
"""Class for exporting the HTTP OpenStack data source."""
datasource_config = {
'max_wait': 120,
'timeout': 10,
'metadata_url': 'http://169.254.169.254/',
'post_password_version': '2013-04-04',
'retries': 3,
}
@staticmethod
def _enable_metadata_access(metadata_url):
if IS_WINDOWS:
osutils = base.get_osutils()
osutils.network.set_metadata_ip_route(metadata_url)
@staticmethod
def _path_join(path, *addons):
return posixpath.join(path, *addons)
@staticmethod
def _valid_api_version(version):
if version == 'latest':
return version
return VERSION_REGEX.match(version)
def _available_versions(self):
content = str(self._get_cache_data("openstack"))
versions = list(filter(None, content.splitlines()))
if not versions:
msg = 'No metadata versions were found.'
raise exceptions.CloudInitError(msg)
for version in versions:
if not self._valid_api_version(version):
msg = 'Invalid API version {!r}'.format(version)
raise exceptions.CloudInitError(msg)
return versions
def _get_data(self, path):
norm_path = self._path_join(self._config['metadata_url'], path)
LOG.debug('Getting metadata from: %s', norm_path)
response = url_helper.wait_any_url([norm_path],
timeout=self._config['timeout'],
max_wait=self._config['max_wait'])
if response:
_, request = response
return base_source.APIResponse(request.contents,
encoding=request.encoding)
msg = "Metadata for url {0} was not accessible in due time"
raise exceptions.CloudInitError(msg.format(norm_path))
def _post_data(self, path, data):
norm_path = self._path_join(self._config['metadata_url'], path)
LOG.debug('Posting metadata to: %s', norm_path)
url_helper.read_url(norm_path, data=data,
retries=self._config['retries'],
timeout=self._config['timeout'])
@property
def _password_path(self):
return 'openstack/%s/password' % self._version
def load(self):
metadata_url = self._config['metadata_url']
self._enable_metadata_access(metadata_url)
super(HttpOpenStackSource, self).load()
try:
self._get_meta_data()
return True
except Exception:
LOG.warning('Metadata not found at URL %r', metadata_url)
return False
def can_update_password(self):
"""Check if the password can be posted for the current data source."""
password = map(int, self._config['post_password_version'].split("-"))
if self._version == 'latest':
current = (0, )
else:
current = map(int, self._version.split("-"))
return tuple(current) >= tuple(password)
@property
def is_password_set(self):
path = self._password_path
content = self._get_cache_data(path).buffer
return len(content) > 0
def post_password(self, password):
try:
self._post_data(self._password_path, password)
return True
except url_helper.UrlError as ex:
if ex.status_code == url_helper.CONFLICT:
# Password already set
return False
else:
raise

View File

@ -8,6 +8,7 @@ import subprocess
import unittest import unittest
from cloudinit import exceptions from cloudinit import exceptions
from cloudinit.tests.util import LogSnatcher
from cloudinit.tests.util import mock from cloudinit.tests.util import mock
@ -15,7 +16,7 @@ class TestNetworkWindows(unittest.TestCase):
def setUp(self): def setUp(self):
self._ctypes_mock = mock.MagicMock() self._ctypes_mock = mock.MagicMock()
self._moves_mock = mock.Mock() self._winreg_mock = mock.Mock()
self._win32com_mock = mock.Mock() self._win32com_mock = mock.Mock()
self._wmi_mock = mock.Mock() self._wmi_mock = mock.Mock()
@ -24,7 +25,7 @@ class TestNetworkWindows(unittest.TestCase):
{'ctypes': self._ctypes_mock, {'ctypes': self._ctypes_mock,
'win32com': self._win32com_mock, 'win32com': self._win32com_mock,
'wmi': self._wmi_mock, 'wmi': self._wmi_mock,
'six.moves': self._moves_mock}) 'six.moves.winreg': self._winreg_mock})
self._module_patcher.start() self._module_patcher.start()
self._iphlpapi = mock.Mock() self._iphlpapi = mock.Mock()
@ -251,3 +252,106 @@ class TestNetworkWindows(unittest.TestCase):
self.assertEqual('dwForwardIfIndex', given_route.interface) self.assertEqual('dwForwardIfIndex', given_route.interface)
self.assertEqual('dwForwardMetric1', given_route.metric) self.assertEqual('dwForwardMetric1', given_route.metric)
self.assertEqual('dwForwardProto', given_route.flags) self.assertEqual('dwForwardProto', given_route.flags)
@mock.patch('cloudinit.osys.base.get_osutils')
@mock.patch('cloudinit.osys.windows.network.Network.routes')
def test_set_metadata_ip_route_not_called(self, mock_routes,
mock_osutils):
general = mock_osutils.return_value.general
general.check_os_version.return_value = False
self._network.set_metadata_ip_route(mock.sentinel.url)
self.assertFalse(mock_routes.called)
general.check_os_version.assert_called_once_with(6, 0)
@mock.patch('cloudinit.osys.base.get_osutils')
@mock.patch('cloudinit.osys.windows.network.Network.routes')
def test_set_metadata_ip_route_not_invalid_url(self, mock_routes,
mock_osutils):
general = mock_osutils.return_value.general
general.check_os_version.return_value = True
self._network.set_metadata_ip_route("http://169.253.169.253")
self.assertFalse(mock_routes.called)
general.check_os_version.assert_called_once_with(6, 0)
@mock.patch('cloudinit.osys.base.get_osutils')
@mock.patch('cloudinit.osys.windows.network.Network.routes')
@mock.patch('cloudinit.osys.windows.network.Network.default_gateway')
def test_set_metadata_ip_route_route_already_exists(
self, mock_default_gateway, mock_routes, mock_osutils):
mock_route = mock.Mock()
mock_route.destination = "169.254.169.254"
mock_routes.return_value = (mock_route, )
self._network.set_metadata_ip_route("http://169.254.169.254")
self.assertTrue(mock_routes.called)
self.assertFalse(mock_default_gateway.called)
@mock.patch('cloudinit.osys.base.get_osutils')
@mock.patch('cloudinit.osys.windows.network._check_url')
@mock.patch('cloudinit.osys.windows.network.Network.routes')
@mock.patch('cloudinit.osys.windows.network.Network.default_gateway')
def test_set_metadata_ip_route_route_missing_url_accessible(
self, mock_default_gateway, mock_routes,
mock_check_url, mock_osutils):
mock_routes.return_value = ()
mock_check_url.return_value = True
self._network.set_metadata_ip_route("http://169.254.169.254")
self.assertTrue(mock_routes.called)
self.assertFalse(mock_default_gateway.called)
self.assertTrue(mock_osutils.called)
@mock.patch('cloudinit.osys.base.get_osutils')
@mock.patch('cloudinit.osys.windows.network._check_url')
@mock.patch('cloudinit.osys.windows.network.Network.routes')
@mock.patch('cloudinit.osys.windows.network.Network.default_gateway')
@mock.patch('cloudinit.osys.windows.network.Route')
def test_set_metadata_ip_route_no_default_gateway(
self, mock_Route, mock_default_gateway,
mock_routes, mock_check_url, mock_osutils):
mock_routes.return_value = ()
mock_check_url.return_value = False
mock_default_gateway.return_value = None
self._network.set_metadata_ip_route("http://169.254.169.254")
self.assertTrue(mock_osutils.called)
self.assertTrue(mock_routes.called)
self.assertTrue(mock_default_gateway.called)
self.assertFalse(mock_Route.called)
@mock.patch('cloudinit.osys.base.get_osutils')
@mock.patch('cloudinit.osys.windows.network._check_url')
@mock.patch('cloudinit.osys.windows.network.Network.routes')
@mock.patch('cloudinit.osys.windows.network.Network.default_gateway')
@mock.patch('cloudinit.osys.windows.network.Route')
def test_set_metadata_ip_route(
self, mock_Route, mock_default_gateway,
mock_routes, mock_check_url, mock_osutils):
mock_routes.return_value = ()
mock_check_url.return_value = False
with LogSnatcher('cloudinit.osys.windows.network') as snatcher:
self._network.set_metadata_ip_route("http://169.254.169.254")
expected = ['Setting gateway for host: 169.254.169.254']
self.assertEqual(expected, snatcher.output)
self.assertTrue(mock_routes.called)
self.assertTrue(mock_default_gateway.called)
mock_Route.assert_called_once_with(
destination="169.254.169.254",
netmask="255.255.255.255",
gateway=mock_default_gateway.return_value.gateway,
interface=None, metric=None)
mock_Route.add.assert_called_once_with(mock_Route.return_value)
self.assertTrue(mock_osutils.called)

View File

View File

@ -0,0 +1,176 @@
# Copyright 2015 Canonical Ltd.
# This file is part of cloud-init. See LICENCE file for license information.
#
# vi: ts=4 expandtab
from cloudinit.sources import base as base_source
from cloudinit.sources.openstack import base
from cloudinit import test
from cloudinit.tests.util import LogSnatcher
from cloudinit.tests.util import mock
class TestBaseOpenStackSource(test.TestCase):
@mock.patch('cloudinit.sources.openstack.base.BaseOpenStackSource.'
'__abstractmethods__', new=())
def setUp(self):
self._source = base.BaseOpenStackSource()
super(TestBaseOpenStackSource, self).setUp()
@mock.patch('cloudinit.sources.openstack.base.BaseOpenStackSource.'
'_available_versions')
def _test_working_version(self, mock_available_versions,
versions, expected_version):
mock_available_versions.return_value = versions
with LogSnatcher('cloudinit.sources.openstack.base') as snatcher:
version = self._source._working_version()
msg = "Selected version '{0}' from {1}"
expected_logging = [msg.format(expected_version, versions)]
self.assertEqual(expected_logging, snatcher.output)
self.assertEqual(expected_version, version)
def test_working_version_latest(self):
self._test_working_version(versions=(), expected_version='latest')
def test_working_version_other_version(self):
versions = (
base._OS_FOLSOM,
base._OS_GRIZZLY,
base._OS_HAVANA,
)
self._test_working_version(versions=versions,
expected_version=base._OS_HAVANA)
@mock.patch('cloudinit.sources.openstack.base.BaseOpenStackSource.'
'_get_meta_data')
def test_metadata_capabilities(self, mock_get_meta_data):
mock_get_meta_data.return_value = {
'uuid': mock.sentinel.id,
'hostname': mock.sentinel.hostname,
'public_keys': {'key-one': 'key-one', 'key-two': 'key-two'},
}
instance_id = self._source.instance_id()
hostname = self._source.host_name()
public_keys = self._source.public_keys()
self.assertEqual(mock.sentinel.id, instance_id)
self.assertEqual(mock.sentinel.hostname, hostname)
self.assertEqual(["key-one", "key-two"], sorted(public_keys))
@mock.patch('cloudinit.sources.openstack.base.BaseOpenStackSource.'
'_get_meta_data')
def test_no_public_keys(self, mock_get_meta_data):
mock_get_meta_data.return_value = {'public_keys': []}
public_keys = self._source.public_keys()
self.assertEqual([], public_keys)
@mock.patch('cloudinit.sources.openstack.base.BaseOpenStackSource.'
'_get_meta_data')
def test_admin_password(self, mock_get_meta_data):
mock_get_meta_data.return_value = {
'meta': {base._ADMIN_PASSWORD: mock.sentinel.password}
}
password = self._source.admin_password()
self.assertEqual(mock.sentinel.password, password)
@mock.patch('cloudinit.sources.openstack.base.BaseOpenStackSource.'
'_path_join')
@mock.patch('cloudinit.sources.openstack.base.BaseOpenStackSource.'
'_get_cache_data')
def test_get_content(self, mock_get_cache_data, mock_path_join):
result = self._source._get_content(mock.sentinel.name)
mock_path_join.assert_called_once_with(
'openstack', 'content', mock.sentinel.name)
mock_get_cache_data.assert_called_once_with(
mock_path_join.return_value)
self.assertEqual(mock_get_cache_data.return_value, result)
@mock.patch('cloudinit.sources.openstack.base.BaseOpenStackSource.'
'_path_join')
@mock.patch('cloudinit.sources.openstack.base.BaseOpenStackSource.'
'_get_cache_data')
def test_user_data(self, mock_get_cache_data, mock_path_join):
result = self._source.user_data()
mock_path_join.assert_called_once_with(
'openstack', self._source._version, 'user_data')
mock_get_cache_data.assert_called_once_with(
mock_path_join.return_value)
self.assertEqual(mock_get_cache_data.return_value.buffer, result)
@mock.patch('cloudinit.sources.openstack.base.BaseOpenStackSource.'
'_path_join')
@mock.patch('cloudinit.sources.openstack.base.BaseOpenStackSource.'
'_get_cache_data')
def test_get_metadata(self, mock_get_cache_data, mock_path_join):
mock_get_cache_data.return_value = base_source.APIResponse(b"{}")
result = self._source._get_meta_data()
mock_path_join.assert_called_once_with(
'openstack', self._source._version, 'meta_data.json')
mock_get_cache_data.assert_called_once_with(
mock_path_join.return_value)
self.assertEqual({}, result)
@mock.patch('cloudinit.sources.openstack.base.BaseOpenStackSource.'
'_path_join')
@mock.patch('cloudinit.sources.openstack.base.BaseOpenStackSource.'
'_get_cache_data')
def test_vendor_data(self, mock_get_cache_data, mock_path_join):
result = self._source.vendor_data()
mock_path_join.assert_called_once_with(
'openstack', self._source._version, 'vendor_data.json')
mock_get_cache_data.assert_called_once_with(
mock_path_join.return_value)
self.assertEqual(mock_get_cache_data.return_value.buffer, result)
@mock.patch('cloudinit.sources.openstack.base.BaseOpenStackSource.'
'_working_version')
def test_load(self, mock_working_version):
self._source.load()
self.assertTrue(mock_working_version.called)
self.assertEqual(mock_working_version.return_value,
self._source._version)
@mock.patch('cloudinit.sources.openstack.base.BaseOpenStackSource.'
'_get_meta_data')
def test_network_config_no_config(self, mock_get_metadata):
mock_get_metadata.return_value = {}
self.assertIsNone(self._source.network_config())
mock_get_metadata.return_value = {1: 2}
self.assertIsNone(self._source.network_config())
@mock.patch('cloudinit.sources.openstack.base.BaseOpenStackSource.'
'_get_meta_data')
@mock.patch('cloudinit.sources.openstack.base.BaseOpenStackSource.'
'_get_content')
def test_network_config(self, mock_get_content, mock_get_metadata):
mock_get_metadata.return_value = {
"network_config": {base._PAYLOAD_KEY: "content_path"}
}
result = self._source.network_config()
mock_get_content.assert_called_once_with("content_path")
self.assertEqual(str(mock_get_content.return_value), result)
@mock.patch('cloudinit.sources.openstack.base.BaseOpenStackSource.'
'_get_data')
def test_get_cache_data(self, mock_get_data):
mock_get_data.return_value = b'test'
result = self._source._get_cache_data(mock.sentinel.path)
mock_get_data.assert_called_once_with(mock.sentinel.path)
self.assertEqual(b'test', result)

View File

@ -0,0 +1,251 @@
# Copyright 2015 Canonical Ltd.
# This file is part of cloud-init. See LICENCE file for license information.
#
# vi: ts=4 expandtab
import textwrap
from six.moves import http_client
from cloudinit import exceptions
from cloudinit.sources import base
from cloudinit.sources.openstack import httpopenstack
from cloudinit import test
from cloudinit.tests.util import LogSnatcher
from cloudinit.tests.util import mock
from cloudinit import url_helper
class TestHttpOpenStackSource(test.TestCase):
def setUp(self):
self._source = httpopenstack.HttpOpenStackSource()
super(TestHttpOpenStackSource, self).setUp()
@mock.patch.object(httpopenstack, 'IS_WINDOWS', new=False)
@mock.patch('cloudinit.osys.windows.network.Network.'
'set_metadata_ip_route')
def test__enable_metadata_access_not_nt(self, mock_set_metadata_ip_route):
self._source._enable_metadata_access(mock.sentinel.metadata_url)
self.assertFalse(mock_set_metadata_ip_route.called)
@mock.patch.object(httpopenstack, 'IS_WINDOWS', new=True)
@mock.patch('cloudinit.osys.base.get_osutils')
def test__enable_metadata_access_nt(self, mock_get_osutils):
self._source._enable_metadata_access(mock.sentinel.metadata_url)
mock_get_osutils.assert_called_once_with()
osutils = mock_get_osutils.return_value
osutils.network.set_metadata_ip_route.assert_called_once_with(
mock.sentinel.metadata_url)
def test__path_join(self):
calls = [
(('path', 'a', 'b'), 'path/a/b'),
(('path', ), 'path'),
(('path/', 'b/'), 'path/b/'),
]
for arguments, expected in calls:
path = self._source._path_join(*arguments)
self.assertEqual(expected, path)
@mock.patch('cloudinit.sources.openstack.httpopenstack.'
'HttpOpenStackSource._get_cache_data')
def test__available_versions(self, mock_get_cache_data):
mock_get_cache_data.return_value = textwrap.dedent("""
2013-02-02
2014-04-04
2015-05-05
latest""")
versions = self._source._available_versions()
expected = ['2013-02-02', '2014-04-04', '2015-05-05', 'latest']
mock_get_cache_data.assert_called_once_with("openstack")
self.assertEqual(expected, versions)
@mock.patch('cloudinit.sources.openstack.httpopenstack.'
'HttpOpenStackSource._get_cache_data')
def _test__available_versions_invalid_versions(
self, version, mock_get_cache_data):
mock_get_cache_data.return_value = version
exc = self.assertRaises(exceptions.CloudInitError,
self._source._available_versions)
expected = 'Invalid API version {!r}'.format(version)
self.assertEqual(expected, str(exc))
def test__available_versions_invalid_versions(self):
versions = ['2013-no-worky', '2012', '2012-02',
'lates', '20004-111-222', '2004-11-11111',
' 2004-11-20']
for version in versions:
self._test__available_versions_invalid_versions(version)
@mock.patch('cloudinit.sources.openstack.httpopenstack.'
'HttpOpenStackSource._get_cache_data')
def test__available_versions_no_version_found(self, mock_get_cache_data):
mock_get_cache_data.return_value = ''
exc = self.assertRaises(exceptions.CloudInitError,
self._source._available_versions)
self.assertEqual('No metadata versions were found.', str(exc))
@mock.patch('cloudinit.sources.openstack.httpopenstack.'
'HttpOpenStackSource._get_cache_data')
def _test_is_password_set(self, mock_get_cache_data, data, expected):
mock_get_cache_data.return_value = data
result = self._source.is_password_set
self.assertEqual(expected, result)
mock_get_cache_data.assert_called_once_with(
self._source._password_path)
def test_is_password_set(self):
empty_data = base.APIResponse(b"")
non_empty_data = base.APIResponse(b"password")
self._test_is_password_set(data=empty_data, expected=False)
self._test_is_password_set(data=non_empty_data, expected=True)
def _test_can_update_password(self, version, expected):
with mock.patch.object(self._source, '_version', new=version):
self.assertEqual(self._source.can_update_password(), expected)
def test_can_update_password(self):
self._test_can_update_password('2012-08-10', expected=False)
self._test_can_update_password('2012-11-10', expected=False)
self._test_can_update_password('2013-04-04', expected=True)
self._test_can_update_password('2014-04-04', expected=True)
self._test_can_update_password('latest', expected=False)
@mock.patch('cloudinit.sources.openstack.httpopenstack.'
'HttpOpenStackSource._path_join')
@mock.patch('cloudinit.url_helper.read_url')
def test__post_data(self, mock_read_url, mock_path_join):
with LogSnatcher('cloudinit.sources.openstack.'
'httpopenstack') as snatcher:
self._source._post_data(mock.sentinel.path,
mock.sentinel.data)
expected_logging = [
'Posting metadata to: %s' % mock_path_join.return_value
]
self.assertEqual(expected_logging, snatcher.output)
mock_path_join.assert_called_once_with(
self._source._config['metadata_url'], mock.sentinel.path)
mock_read_url.assert_called_once_with(
mock_path_join.return_value, data=mock.sentinel.data,
retries=self._source._config['retries'],
timeout=self._source._config['timeout'])
@mock.patch('cloudinit.sources.openstack.httpopenstack.'
'HttpOpenStackSource._post_data')
def test_post_password(self, mock_post_data):
self.assertTrue(self._source.post_password(mock.sentinel.password))
mock_post_data.assert_called_once_with(
self._source._password_path, mock.sentinel.password)
@mock.patch('cloudinit.sources.openstack.httpopenstack.'
'HttpOpenStackSource._post_data')
def test_post_password_already_posted(self, mock_post_data):
exc = url_helper.UrlError(None)
exc.status_code = http_client.CONFLICT
mock_post_data.side_effect = exc
self.assertFalse(self._source.post_password(mock.sentinel.password))
mock_post_data.assert_called_once_with(
self._source._password_path, mock.sentinel.password)
@mock.patch('cloudinit.sources.openstack.httpopenstack.'
'HttpOpenStackSource._post_data')
def test_post_password_other_error(self, mock_post_data):
exc = url_helper.UrlError(None)
exc.status_code = http_client.NOT_FOUND
mock_post_data.side_effect = exc
self.assertRaises(url_helper.UrlError,
self._source.post_password,
mock.sentinel.password)
mock_post_data.assert_called_once_with(
self._source._password_path, mock.sentinel.password)
@mock.patch('cloudinit.sources.openstack.base.'
'BaseOpenStackSource.load')
@mock.patch('cloudinit.sources.openstack.httpopenstack.'
'HttpOpenStackSource._get_meta_data')
@mock.patch('cloudinit.sources.openstack.httpopenstack.'
'HttpOpenStackSource._enable_metadata_access')
def _test_load(self, mock_enable_metadata_access,
mock_get_metadata, mock_load, expected,
expected_logging, metadata_side_effect=None):
mock_get_metadata.side_effect = metadata_side_effect
with LogSnatcher('cloudinit.sources.openstack.'
'httpopenstack') as snatcher:
response = self._source.load()
self.assertEqual(expected, response)
mock_enable_metadata_access.assert_called_once_with(
self._source._config['metadata_url'])
mock_load.assert_called_once_with()
mock_get_metadata.assert_called_once_with()
self.assertEqual(expected_logging, snatcher.output)
def test_load_works(self):
self._test_load(expected=True, expected_logging=[])
def test_load_fails(self):
expected_logging = [
'Metadata not found at URL %r'
% self._source._config['metadata_url']
]
self._test_load(expected=False,
expected_logging=expected_logging,
metadata_side_effect=ValueError)
@mock.patch('cloudinit.sources.openstack.httpopenstack.'
'HttpOpenStackSource._path_join')
@mock.patch('cloudinit.url_helper.wait_any_url')
def test__get_data_inaccessible_metadata(self, mock_wait_any_url,
mock_path_join):
mock_wait_any_url.return_value = None
mock_path_join.return_value = mock.sentinel.path_join
msg = "Metadata for url {0} was not accessible in due time"
expected = msg.format(mock.sentinel.path_join)
expected_logging = [
'Getting metadata from: %s' % mock.sentinel.path_join
]
with LogSnatcher('cloudinit.sources.openstack.'
'httpopenstack') as snatcher:
exc = self.assertRaises(exceptions.CloudInitError,
self._source._get_data, 'test')
self.assertEqual(expected, str(exc))
self.assertEqual(expected_logging, snatcher.output)
@mock.patch('cloudinit.sources.openstack.httpopenstack.'
'HttpOpenStackSource._path_join')
@mock.patch('cloudinit.url_helper.wait_any_url')
def test__get_data(self, mock_wait_any_url, mock_path_join):
mock_response = mock.Mock()
response = b"test"
mock_response.contents = response
mock_response.encoding = 'utf-8'
mock_wait_any_url.return_value = (None, mock_response)
mock_path_join.return_value = mock.sentinel.path_join
expected_logging = [
'Getting metadata from: %s' % mock.sentinel.path_join
]
with LogSnatcher('cloudinit.sources.openstack.'
'httpopenstack') as snatcher:
result = self._source._get_data('test')
self.assertEqual(expected_logging, snatcher.output)
self.assertIsInstance(result, base.APIResponse)
self.assertEqual('test', str(result))
self.assertEqual(b'test', result.buffer)

View File

@ -1,4 +1,62 @@
# Copyright 2015 Canonical Ltd.
# This file is part of cloud-init. See LICENCE file for license information.
#
# vi: ts=4 expandtab
import logging
try: try:
from unittest import mock from unittest import mock
except ImportError: except ImportError:
import mock # noqa import mock # noqa
# This is similar with unittest.TestCase.assertLogs from Python 3.4.
class SnatchHandler(logging.Handler):
def __init__(self, *args, **kwargs):
super(SnatchHandler, self).__init__(*args, **kwargs)
self.output = []
def emit(self, record):
msg = self.format(record)
self.output.append(msg)
class LogSnatcher(object):
"""A context manager to capture emitted logged messages.
The class can be used as following::
with LogSnatcher('plugins.windows.createuser') as snatcher:
LOG.info("doing stuff")
LOG.info("doing stuff %s", 1)
LOG.warn("doing other stuff")
...
self.assertEqual(snatcher.output,
['INFO:unknown:doing stuff',
'INFO:unknown:doing stuff 1',
'WARN:unknown:doing other stuff'])
"""
@property
def output(self):
"""Get the output of this Snatcher.
The output is a list of log messages, already formatted.
"""
return self._snatch_handler.output
def __init__(self, logger_name):
self._logger_name = logger_name
self._snatch_handler = SnatchHandler()
self._logger = logging.getLogger(self._logger_name)
self._previous_level = self._logger.getEffectiveLevel()
def __enter__(self):
self._logger.setLevel(logging.DEBUG)
self._logger.handlers.append(self._snatch_handler)
return self
def __exit__(self, *args):
self._logger.handlers.remove(self._snatch_handler)
self._logger.setLevel(self._previous_level)

View File

@ -24,6 +24,7 @@ from six.moves.urllib.parse import urlparse # noqa
from six.moves.urllib.parse import urlunparse # noqa from six.moves.urllib.parse import urlunparse # noqa
from six.moves.http_client import BAD_REQUEST as _BAD_REQUEST from six.moves.http_client import BAD_REQUEST as _BAD_REQUEST
from six.moves.http_client import CONFLICT # noqa
from six.moves.http_client import MULTIPLE_CHOICES as _MULTIPLE_CHOICES from six.moves.http_client import MULTIPLE_CHOICES as _MULTIPLE_CHOICES
from six.moves.http_client import OK from six.moves.http_client import OK