Normalize all metadata providers and plugins

Every meta data service should return bytes only for these capabilities:
    * get_content
    * get_user_data
While `_get_meta_data` and any other method derrived from it
(including public keys, certificates etc.) should return homogeneous
data types and only strings, not bytes.
The decoding procedure is handled at its roots, not in the plugins
and is done by only using `encoding.get_as_string` function.

Fixed bugs:
    * invalid certificate splitting under maas service which usually
      generated an extra invalid certificate (empty string + footer)
    * text operations on bytes in maas and cloudstack (split, comparing)
    * multiple types for certificates (now only strings)
    * not receiving bytes from opennebula service when using `get_user_data`
      (which leads to crash under later processing through io.BytesIO)
    * erroneous certificate parsing/stripping/replacing under x509 importing
      (footer remains, not all possible EOLs replaced as it should)

Also added new and refined actual misleading unittests.

Change-Id: I704c43f5f784458a881293d761a21e62aed85732
This commit is contained in:
Cosmin Poieana 2015-06-04 16:44:16 +03:00
parent 55880b8ff8
commit ae15fee086
14 changed files with 163 additions and 90 deletions

View File

@ -20,6 +20,7 @@ import time
from oslo.config import cfg
from cloudbaseinit.openstack.common import log as logging
from cloudbaseinit.utils import encoding
opts = [
@ -88,13 +89,17 @@ class BaseMetadataService(object):
else:
raise
def _get_cache_data(self, path):
if path in self._cache:
def _get_cache_data(self, path, decode=False):
"""Get meta data with caching and decoding support."""
key = (path, decode)
if key in self._cache:
LOG.debug("Using cached copy of metadata: '%s'" % path)
return self._cache[path]
return self._cache[key]
else:
data = self._exec_with_retry(lambda: self._get_data(path))
self._cache[path] = data
if decode:
data = encoding.get_as_string(data)
self._cache[key] = data
return data
def get_instance_id(self):

View File

@ -51,9 +51,9 @@ class BaseOpenStackService(base.BaseMetadataService):
def _get_meta_data(self, version='latest'):
path = posixpath.normpath(
posixpath.join('openstack', version, 'meta_data.json'))
data = self._get_cache_data(path)
data = self._get_cache_data(path, decode=True)
if data:
return json.loads(encoding.get_as_string(data))
return json.loads(data)
def get_instance_id(self):
return self._get_meta_data().get('uuid')
@ -136,10 +136,10 @@ class BaseOpenStackService(base.BaseMetadataService):
if not certs:
# Look if the user_data contains a PEM certificate
try:
user_data = self.get_user_data()
user_data = self.get_user_data().strip()
if user_data.startswith(
x509constants.PEM_HEADER.encode()):
certs.append(user_data)
certs.append(encoding.get_as_string(user_data))
except base.NotExistingMetadataException:
LOG.debug("user_data metadata not present")

View File

@ -21,6 +21,7 @@ from six.moves import urllib
from cloudbaseinit.metadata.services import base
from cloudbaseinit.openstack.common import log as logging
from cloudbaseinit.osutils import factory as osutils_factory
from cloudbaseinit.utils import encoding
LOG = logging.getLogger(__name__)
@ -104,11 +105,11 @@ class CloudStack(base.BaseMetadataService):
def get_instance_id(self):
"""Instance name of the virtual machine."""
return self._get_cache_data('instance-id')
return self._get_cache_data('instance-id', decode=True)
def get_host_name(self):
"""Hostname of the virtual machine."""
return self._get_cache_data('local-hostname')
return self._get_cache_data('local-hostname', decode=True)
def get_user_data(self):
"""User data for this virtual machine."""
@ -117,7 +118,9 @@ class CloudStack(base.BaseMetadataService):
def get_public_keys(self):
"""Available ssh public keys."""
ssh_keys = []
for ssh_key in self._get_cache_data('public-keys').splitlines():
ssh_chunks = self._get_cache_data('public-keys',
decode=True).splitlines()
for ssh_key in ssh_chunks:
ssh_key = ssh_key.strip()
if not ssh_key:
continue
@ -155,14 +158,13 @@ class CloudStack(base.BaseMetadataService):
if response.status != 200:
LOG.warning("Getting password failed: %(status)s "
"%(reason)s - %(message)s",
"%(reason)s - %(message)r",
{"status": response.status,
"reason": response.reason,
"message": response.read()})
continue
content = response.read()
content = content.strip()
content = response.read().strip()
if not content:
LOG.warning("The Password Server did not have any "
"password for the current instance.")
@ -180,7 +182,7 @@ class CloudStack(base.BaseMetadataService):
LOG.info("The password server return a valid password "
"for the current instance.")
password = content.decode()
password = encoding.get_as_string(content)
break
return password
@ -201,14 +203,14 @@ class CloudStack(base.BaseMetadataService):
response = connection.getresponse()
if response.status != 200:
LOG.warning("Removing password failed: %(status)s "
"%(reason)s - %(message)s",
"%(reason)s - %(message)r",
{"status": response.status,
"reason": response.reason,
"message": response.read()})
continue
content = response.read()
if content.decode() != BAD_REQUEST:
if content != BAD_REQUEST: # comparing bytes with bytes
LOG.info("The password was removed from the Password Server.")
break
else:

View File

@ -78,25 +78,25 @@ class EC2Service(base.BaseMetadataService):
def get_host_name(self):
return self._get_cache_data('%s/meta-data/local-hostname' %
self._metadata_version)
self._metadata_version, decode=True)
def get_instance_id(self):
return self._get_cache_data('%s/meta-data/instance-id' %
self._metadata_version)
self._metadata_version, decode=True)
def get_public_keys(self):
ssh_keys = []
keys_info = self._get_cache_data(
'%s/meta-data/public-keys' %
self._metadata_version).split("\n")
self._metadata_version, decode=True).splitlines()
for key_info in keys_info:
(idx, key_name) = key_info.split('=')
ssh_key = self._get_cache_data(
'%(version)s/meta-data/public-keys/%(idx)s/openssh-key' %
{'version': self._metadata_version, 'idx': idx})
{'version': self._metadata_version, 'idx': idx}, decode=True)
ssh_keys.append(ssh_key.strip())
return ssh_keys

View File

@ -13,6 +13,7 @@
# under the License.
import posixpath
import re
from oauthlib import oauth1
from oslo.config import cfg
@ -108,23 +109,25 @@ class MaaSHttpService(base.BaseMetadataService):
def get_host_name(self):
return self._get_cache_data('%s/meta-data/local-hostname' %
self._metadata_version)
self._metadata_version, decode=True)
def get_instance_id(self):
return self._get_cache_data('%s/meta-data/instance-id' %
self._metadata_version)
def _get_list_from_text(self, text, delimiter):
return [v + delimiter for v in text.split(delimiter)]
self._metadata_version, decode=True)
def get_public_keys(self):
return self._get_cache_data('%s/meta-data/public-keys' %
self._metadata_version).splitlines()
self._metadata_version,
decode=True).splitlines()
def get_client_auth_certs(self):
return self._get_list_from_text(
self._get_cache_data('%s/meta-data/x509' % self._metadata_version),
"%s\n" % x509constants.PEM_FOOTER)
certs_data = self._get_cache_data('%s/meta-data/x509' %
self._metadata_version,
decode=True)
pattern = r"{begin}[\s\S]+?{end}".format(
begin=x509constants.PEM_HEADER,
end=x509constants.PEM_FOOTER)
return re.findall(pattern, certs_data)
def get_user_data(self):
return self._get_cache_data('%s/user-data' % self._metadata_version)

View File

@ -146,7 +146,7 @@ class OpenNebulaService(base.BaseMetadataService):
raise base.NotExistingMetadataException(msg)
return self._dict_content[name]
def _get_cache_data(self, names, iid=None):
def _get_cache_data(self, names, iid=None, decode=False):
# Solves caching issues when working with
# multiple names (lists not hashable).
# This happens because the caching function used
@ -160,7 +160,8 @@ class OpenNebulaService(base.BaseMetadataService):
names[ind] = value.format(iid=iid)
for name in names:
try:
return super(OpenNebulaService, self)._get_cache_data(name)
return super(OpenNebulaService, self)._get_cache_data(
name, decode=decode)
except base.NotExistingMetadataException:
pass
msg = "None of {} metadata was found".format(", ".join(names))
@ -192,14 +193,13 @@ class OpenNebulaService(base.BaseMetadataService):
return INSTANCE_ID
def get_host_name(self):
return encoding.get_as_string(self._get_cache_data(HOST_NAME))
return self._get_cache_data(HOST_NAME, decode=True)
def get_user_data(self):
return self._get_cache_data(USER_DATA)
def get_public_keys(self):
return encoding.get_as_string(
self._get_cache_data(PUBLIC_KEY)).splitlines()
return self._get_cache_data(PUBLIC_KEY, decode=True).splitlines()
def get_network_details(self):
"""Return a list of NetworkDetails objects.
@ -215,19 +215,17 @@ class OpenNebulaService(base.BaseMetadataService):
for iid in range(ncount):
try:
# get existing values
mac = encoding.get_as_string(
self._get_cache_data(MAC, iid=iid)).upper()
address = encoding.get_as_string(self._get_cache_data(ADDRESS,
iid=iid))
mac = self._get_cache_data(MAC, iid=iid, decode=True).upper()
address = self._get_cache_data(ADDRESS, iid=iid, decode=True)
# try to find/predict and compute the rest
try:
gateway = encoding.get_as_string(
self._get_cache_data(GATEWAY, iid=iid))
gateway = self._get_cache_data(GATEWAY, iid=iid,
decode=True)
except base.NotExistingMetadataException:
gateway = None
try:
netmask = encoding.get_as_string(
self._get_cache_data(NETMASK, iid=iid))
netmask = self._get_cache_data(NETMASK, iid=iid,
decode=True)
except base.NotExistingMetadataException:
if not gateway:
raise
@ -244,8 +242,8 @@ class OpenNebulaService(base.BaseMetadataService):
broadcast=broadcast,
gateway=gateway,
gateway6=None,
dnsnameservers=encoding.get_as_string(
self._get_cache_data(DNSNS, iid=iid)).split(" ")
dnsnameservers=self._get_cache_data(
DNSNS, iid=iid, decode=True).split(" ")
)
except base.NotExistingMetadataException:
LOG.debug("Incomplete NIC details")

View File

@ -70,11 +70,11 @@ class TestBaseOpenStackService(unittest.TestCase):
@mock.patch(MODPATH +
".BaseOpenStackService._get_cache_data")
def test_get_meta_data(self, mock_get_cache_data):
mock_get_cache_data.return_value = b'{"fake": "data"}'
mock_get_cache_data.return_value = '{"fake": "data"}'
response = self._service._get_meta_data(
version='fake version')
path = posixpath.join('openstack', 'fake version', 'meta_data.json')
mock_get_cache_data.assert_called_with(path)
mock_get_cache_data.assert_called_with(path, decode=True)
self.assertEqual({"fake": "data"}, response)
@mock.patch(MODPATH +
@ -151,7 +151,7 @@ class TestBaseOpenStackService(unittest.TestCase):
if isinstance(ret_value, bytes) and ret_value.startswith(
x509constants.PEM_HEADER.encode()):
mock_get_user_data.assert_called_once_with()
self.assertEqual([ret_value], response)
self.assertEqual([ret_value.decode()], response)
elif ret_value is base.NotExistingMetadataException:
self.assertFalse(response)
else:

View File

@ -12,6 +12,7 @@
# License for the specific language governing permissions and limitations
# under the License.
import functools
import socket
import unittest
@ -149,12 +150,17 @@ class CloudStackTest(unittest.TestCase):
@mock.patch('cloudbaseinit.metadata.services.cloudstack.CloudStack'
'._get_cache_data')
def _test_cache_response(self, mock_get_cache_data, method, metadata):
def _test_cache_response(self, mock_get_cache_data, method, metadata,
decode=True):
mock_get_cache_data.side_effect = [mock.sentinel.response]
response = method()
self.assertEqual(mock.sentinel.response, response)
mock_get_cache_data.assert_called_once_with(metadata)
cache_assert = functools.partial(
mock_get_cache_data.assert_called_once_with,
metadata)
if decode:
cache_assert(decode=decode)
def test_get_instance_id(self):
self._test_cache_response(method=self._service.get_instance_id,
@ -166,7 +172,7 @@ class CloudStackTest(unittest.TestCase):
def test_get_user_data(self):
self._test_cache_response(method=self._service.get_user_data,
metadata='../user-data')
metadata='../user-data', decode=False)
@mock.patch('cloudbaseinit.metadata.services.cloudstack.CloudStack'
'._get_cache_data')

View File

@ -99,7 +99,8 @@ class EC2ServiceTest(unittest.TestCase):
def test_get_host_name(self, mock_get_cache_data):
response = self._service.get_host_name()
mock_get_cache_data.assert_called_once_with(
'%s/meta-data/local-hostname' % self._service._metadata_version)
'%s/meta-data/local-hostname' % self._service._metadata_version,
decode=True)
self.assertEqual(mock_get_cache_data.return_value, response)
@mock.patch('cloudbaseinit.metadata.services.ec2service.EC2Service'
@ -107,7 +108,8 @@ class EC2ServiceTest(unittest.TestCase):
def test_get_instance_id(self, mock_get_cache_data):
response = self._service.get_instance_id()
mock_get_cache_data.assert_called_once_with(
'%s/meta-data/instance-id' % self._service._metadata_version)
'%s/meta-data/instance-id' % self._service._metadata_version,
decode=True)
self.assertEqual(mock_get_cache_data.return_value, response)
@mock.patch('cloudbaseinit.metadata.services.ec2service.EC2Service'
@ -117,10 +119,11 @@ class EC2ServiceTest(unittest.TestCase):
response = self._service.get_public_keys()
expected = [
mock.call('%s/meta-data/public-keys' %
self._service._metadata_version),
self._service._metadata_version,
decode=True),
mock.call('%(version)s/meta-data/public-keys/%('
'idx)s/openssh-key' %
{'version': self._service._metadata_version,
'idx': 'key'})]
'idx': 'key'}, decode=True)]
self.assertEqual(expected, mock_get_cache_data.call_args_list)
self.assertEqual(['fake key'], response)

View File

@ -148,7 +148,8 @@ class MaaSHttpServiceTest(unittest.TestCase):
response = self._maasservice.get_host_name()
mock_get_cache_data.assert_called_once_with(
'%s/meta-data/local-hostname' %
self._maasservice._metadata_version)
self._maasservice._metadata_version,
decode=True)
self.assertEqual(mock_get_cache_data.return_value, response)
@mock.patch("cloudbaseinit.metadata.services.maasservice.MaaSHttpService"
@ -156,13 +157,10 @@ class MaaSHttpServiceTest(unittest.TestCase):
def test_get_instance_id(self, mock_get_cache_data):
response = self._maasservice.get_instance_id()
mock_get_cache_data.assert_called_once_with(
'%s/meta-data/instance-id' % self._maasservice._metadata_version)
'%s/meta-data/instance-id' % self._maasservice._metadata_version,
decode=True)
self.assertEqual(mock_get_cache_data.return_value, response)
def test_get_list_from_text(self):
response = self._maasservice._get_list_from_text('fake:text', ':')
self.assertEqual(['fake:', 'text:'], response)
@mock.patch("cloudbaseinit.metadata.services.maasservice.MaaSHttpService"
"._get_cache_data")
def test_get_public_keys(self, mock_get_cache_data):
@ -174,21 +172,26 @@ class MaaSHttpServiceTest(unittest.TestCase):
mock_get_cache_data.return_value = public_key
response = self._maasservice.get_public_keys()
mock_get_cache_data.assert_called_with(
'%s/meta-data/public-keys' % self._maasservice._metadata_version)
'%s/meta-data/public-keys' % self._maasservice._metadata_version,
decode=True)
self.assertEqual(public_keys, response)
@mock.patch("cloudbaseinit.metadata.services.maasservice.MaaSHttpService"
"._get_list_from_text")
@mock.patch("cloudbaseinit.metadata.services.maasservice.MaaSHttpService"
"._get_cache_data")
def test_get_client_auth_certs(self, mock_get_cache_data,
mock_get_list_from_text):
def test_get_client_auth_certs(self, mock_get_cache_data):
certs = [
"{begin}\n{cert}\n{end}".format(
begin=x509constants.PEM_HEADER,
end=x509constants.PEM_FOOTER,
cert=cert)
for cert in ("first cert", "second cert")
]
mock_get_cache_data.return_value = "\n".join(certs) + "\n"
response = self._maasservice.get_client_auth_certs()
mock_get_cache_data.assert_called_with(
'%s/meta-data/x509' % self._maasservice._metadata_version)
mock_get_list_from_text.assert_called_once_with(
mock_get_cache_data(), "%s\n" % x509constants.PEM_FOOTER)
self.assertEqual(mock_get_list_from_text.return_value, response)
'%s/meta-data/x509' % self._maasservice._metadata_version,
decode=True)
self.assertEqual(certs, response)
@mock.patch("cloudbaseinit.metadata.services.maasservice.MaaSHttpService"
"._get_cache_data")

View File

@ -0,0 +1,54 @@
# Copyright 2014 Cloudbase Solutions Srl
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import os
import tempfile
import unittest
from cloudbaseinit.tests import testutils
from cloudbaseinit.utils import encoding
class TestEncoding(unittest.TestCase):
def test_get_as_string(self):
content_map = [
("data", "data"),
(b"data", "data"),
("data".encode(), "data"),
("data".encode("utf-16"), None)
]
with testutils.LogSnatcher("cloudbaseinit.utils.encoding") as snatch:
for content, expect in content_map:
self.assertEqual(expect, encoding.get_as_string(content))
self.assertIn("couldn't decode", snatch.output[0].lower())
def test_write_file(self):
mode_map = [
(("w", "r"), "my test\ndata\n\n", False),
(("wb", "rb"), "\r\n".join((chr(x) for x in
(32, 125, 0))).encode(), False),
(("wb", "rb"), "my test\ndata\n\n", True)
]
with testutils.create_tempdir() as temp:
fd, path = tempfile.mkstemp(dir=temp)
os.close(fd)
for (write, read), data, encode in mode_map:
encoding.write_file(path, data, mode=write)
with open(path, read) as stream:
content = stream.read()
if encode:
data = data.encode()
self.assertEqual(data, content)

View File

@ -285,7 +285,6 @@ class CryptoAPICertManagerTests(unittest.TestCase):
fake_cert_data += x509constants.PEM_HEADER + '\n'
fake_cert_data += 'fake cert' + '\n'
fake_cert_data += x509constants.PEM_FOOTER
fake_cert_data = fake_cert_data.encode()
response = self._x509_manager._get_cert_base64(fake_cert_data)
self.assertEqual('fake cert', response)

View File

@ -14,6 +14,11 @@
import six
from cloudbaseinit.openstack.common import log as logging
LOG = logging.getLogger(__name__)
def get_as_string(value):
if value is None or isinstance(value, six.text_type):
@ -22,7 +27,9 @@ def get_as_string(value):
try:
return value.decode()
except Exception:
pass
# This is important, because None will be returned,
# but not that serious to raise an exception.
LOG.error("Couldn't decode: %r", value)
def write_file(target_path, data, mode='wb'):
@ -31,13 +38,3 @@ def write_file(target_path, data, mode='wb'):
with open(target_path, mode) as f:
f.write(data)
def read_file(target_path, mode='rb'):
with open(target_path, mode) as f:
data = f.read()
if 'b' in mode:
data = data.decode()
return data

View File

@ -20,7 +20,6 @@ import uuid
import six
from cloudbaseinit.utils import encoding
from cloudbaseinit.utils.windows import cryptoapi
from cloudbaseinit.utils import x509constants
@ -205,13 +204,17 @@ class CryptoAPICertManager(object):
free(subject_encoded)
def _get_cert_base64(self, cert_data):
base64_cert_data = encoding.get_as_string(cert_data)
if base64_cert_data.startswith(x509constants.PEM_HEADER):
base64_cert_data = base64_cert_data[len(x509constants.PEM_HEADER):]
if base64_cert_data.endswith(x509constants.PEM_FOOTER):
base64_cert_data = base64_cert_data[:len(base64_cert_data) -
len(x509constants.PEM_FOOTER)]
return base64_cert_data.replace("\n", "")
"""Remove certificate header and footer and also new lines."""
# It's assured that the certificate is already a string.
removal = [
x509constants.PEM_HEADER,
x509constants.PEM_FOOTER,
"\r",
"\n"
]
for remove in removal:
cert_data = cert_data.replace(remove, "")
return cert_data
def import_cert(self, cert_data, machine_keyset=True,
store_name=STORE_NAME_MY):