Update SmartOS data source to use v2 metadata.

v2 metadata is described at 
 http://eng.joyent.com/mdata/protocol.html
This commit is contained in:
Scott Moser 2015-03-25 14:32:08 -04:00
commit b39cdfb9e7
3 changed files with 237 additions and 85 deletions

View File

@ -27,6 +27,7 @@
- CloudStack: support fetching password from virtual router [Daniel Watkins]
(LP: #1422388)
- readurl, read_file_or_url returns bytes, user must convert as necessary
- SmartOS: use v2 metadata service (LP: #1436417) [Daniel Watkins]
0.7.6:
- open 0.7.6
- Enable vendordata on CloudSigma datasource (LP: #1303986)

View File

@ -29,9 +29,12 @@
# http://us-east.manta.joyent.com/jmc/public/mdata/datadict.html
# Comments with "@datadictionary" are snippets of the definition
import base64
import binascii
import contextlib
import os
import random
import re
import serial
from cloudinit import log as logging
@ -301,6 +304,65 @@ def get_serial(seed_device, seed_timeout):
return ser
class JoyentMetadataFetchException(Exception):
pass
class JoyentMetadataClient(object):
"""
A client implementing v2 of the Joyent Metadata Protocol Specification.
The full specification can be found at
http://eng.joyent.com/mdata/protocol.html
"""
line_regex = re.compile(
r'V2 (?P<length>\d+) (?P<checksum>[0-9a-f]+)'
r' (?P<body>(?P<request_id>[0-9a-f]+) (?P<status>SUCCESS|NOTFOUND)'
r'( (?P<payload>.+))?)')
def __init__(self, serial):
self.serial = serial
def _checksum(self, body):
return '{0:08x}'.format(
binascii.crc32(body.encode('utf-8')) & 0xffffffff)
def _get_value_from_frame(self, expected_request_id, frame):
frame_data = self.line_regex.match(frame).groupdict()
if int(frame_data['length']) != len(frame_data['body']):
raise JoyentMetadataFetchException(
'Incorrect frame length given ({0} != {1}).'.format(
frame_data['length'], len(frame_data['body'])))
expected_checksum = self._checksum(frame_data['body'])
if frame_data['checksum'] != expected_checksum:
raise JoyentMetadataFetchException(
'Invalid checksum (expected: {0}; got {1}).'.format(
expected_checksum, frame_data['checksum']))
if frame_data['request_id'] != expected_request_id:
raise JoyentMetadataFetchException(
'Request ID mismatch (expected: {0}; got {1}).'.format(
expected_request_id, frame_data['request_id']))
if not frame_data.get('payload', None):
LOG.debug('No value found.')
return None
value = util.b64d(frame_data['payload'])
LOG.debug('Value "%s" found.', value)
return value
def get_metadata(self, metadata_key):
LOG.debug('Fetching metadata key "%s"...', metadata_key)
request_id = '{0:08x}'.format(random.randint(0, 0xffffffff))
message_body = '{0} GET {1}'.format(request_id,
util.b64e(metadata_key))
msg = 'V2 {0} {1} {2}\n'.format(
len(message_body), self._checksum(message_body), message_body)
LOG.debug('Writing "%s" to serial port.', msg)
self.serial.write(msg.encode('ascii'))
response = self.serial.readline().decode('ascii')
LOG.debug('Read "%s" from serial port.', response)
return self._get_value_from_frame(request_id, response)
def query_data(noun, seed_device, seed_timeout, strip=False, default=None,
b64=None):
"""Makes a request to via the serial console via "GET <NOUN>"
@ -314,34 +376,20 @@ def query_data(noun, seed_device, seed_timeout, strip=False, default=None,
encoded, so this method relies on being told if the data is base64 or
not.
"""
if not noun:
return False
ser = get_serial(seed_device, seed_timeout)
request_line = "GET %s\n" % noun.rstrip()
ser.write(request_line.encode('ascii'))
status = str(ser.readline()).rstrip()
response = []
eom_found = False
with contextlib.closing(get_serial(seed_device, seed_timeout)) as ser:
client = JoyentMetadataClient(ser)
response = client.get_metadata(noun)
if 'SUCCESS' not in status:
ser.close()
if response is None:
return default
while not eom_found:
m = ser.readline().decode('ascii')
if m.rstrip() == ".":
eom_found = True
else:
response.append(m)
ser.close()
if b64 is None:
b64 = query_data('b64-%s' % noun, seed_device=seed_device,
seed_timeout=seed_timeout, b64=False,
default=False, strip=True)
seed_timeout=seed_timeout, b64=False,
default=False, strip=True)
b64 = util.is_true(b64)
resp = None

View File

@ -24,20 +24,30 @@
from __future__ import print_function
from cloudinit import helpers as c_helpers
from cloudinit.sources import DataSourceSmartOS
from cloudinit.util import b64e
from .. import helpers
import os
import os.path
import re
import shutil
import tempfile
import stat
import tempfile
import uuid
from binascii import crc32
import serial
import six
import six
from cloudinit import helpers as c_helpers
from cloudinit.sources import DataSourceSmartOS
from cloudinit.util import b64e
from .. import helpers
try:
from unittest import mock
except ImportError:
import mock
MOCK_RETURNS = {
'hostname': 'test-host',
@ -56,63 +66,15 @@ MOCK_RETURNS = {
DMI_DATA_RETURN = (str(uuid.uuid4()), 'smartdc')
class MockSerial(object):
"""Fake a serial terminal for testing the code that
interfaces with the serial"""
def get_mock_client(mockdata):
class MockMetadataClient(object):
port = None
def __init__(self, serial):
pass
def __init__(self, mockdata):
self.last = None
self.last = None
self.new = True
self.count = 0
self.mocked_out = []
self.mockdata = mockdata
def open(self):
return True
def close(self):
return True
def isOpen(self):
return True
def write(self, line):
if not isinstance(line, six.binary_type):
raise TypeError("Should be writing binary lines.")
line = line.decode('ascii').replace('GET ', '')
self.last = line.rstrip()
def readline(self):
if self.new:
self.new = False
if self.last in self.mockdata:
line = 'SUCCESS\n'
else:
line = 'NOTFOUND %s\n' % self.last
elif self.last in self.mockdata:
if not self.mocked_out:
self.mocked_out = [x for x in self._format_out()]
if len(self.mocked_out) > self.count:
self.count += 1
line = self.mocked_out[self.count - 1]
return line.encode('ascii')
def _format_out(self):
if self.last in self.mockdata:
_mret = self.mockdata[self.last]
try:
for l in _mret.splitlines():
yield "%s\n" % l.rstrip()
except:
yield "%s\n" % _mret.rstrip()
yield '.'
yield '\n'
def get_metadata(self, metadata_key):
return mockdata.get(metadata_key)
return MockMetadataClient
class TestSmartOSDataSource(helpers.FilesystemMockingTestCase):
@ -160,9 +122,6 @@ class TestSmartOSDataSource(helpers.FilesystemMockingTestCase):
if dmi_data is None:
dmi_data = DMI_DATA_RETURN
def _get_serial(*_):
return MockSerial(mockdata)
def _dmi_data():
return dmi_data
@ -179,7 +138,9 @@ class TestSmartOSDataSource(helpers.FilesystemMockingTestCase):
sys_cfg['datasource']['SmartOS'] = ds_cfg
self.apply_patches([(mod, 'LEGACY_USER_D', self.legacy_user_d)])
self.apply_patches([(mod, 'get_serial', _get_serial)])
self.apply_patches([(mod, 'get_serial', mock.MagicMock())])
self.apply_patches([
(mod, 'JoyentMetadataClient', get_mock_client(mockdata))])
self.apply_patches([(mod, 'dmi_data', _dmi_data)])
self.apply_patches([(os, 'uname', _os_uname)])
self.apply_patches([(mod, 'device_exists', lambda d: True)])
@ -448,6 +409,18 @@ class TestSmartOSDataSource(helpers.FilesystemMockingTestCase):
self.assertEqual(dsrc.device_name_to_device('FOO'),
mydscfg['disk_aliases']['FOO'])
@mock.patch('cloudinit.sources.DataSourceSmartOS.JoyentMetadataClient')
@mock.patch('cloudinit.sources.DataSourceSmartOS.get_serial')
def test_serial_console_closed_on_error(self, get_serial, metadata_client):
class OurException(Exception):
pass
metadata_client.side_effect = OurException
try:
DataSourceSmartOS.query_data('noun', 'device', 0)
except OurException:
pass
self.assertEqual(1, get_serial.return_value.close.call_count)
def apply_patches(patches):
ret = []
@ -458,3 +431,133 @@ def apply_patches(patches):
setattr(ref, name, replace)
ret.append((ref, name, orig))
return ret
class TestJoyentMetadataClient(helpers.FilesystemMockingTestCase):
def setUp(self):
super(TestJoyentMetadataClient, self).setUp()
self.serial = mock.MagicMock(spec=serial.Serial)
self.request_id = 0xabcdef12
self.metadata_value = 'value'
self.response_parts = {
'command': 'SUCCESS',
'crc': 'b5a9ff00',
'length': 17 + len(b64e(self.metadata_value)),
'payload': b64e(self.metadata_value),
'request_id': '{0:08x}'.format(self.request_id),
}
def make_response():
payload = ''
if self.response_parts['payload']:
payload = ' {0}'.format(self.response_parts['payload'])
del self.response_parts['payload']
return (
'V2 {length} {crc} {request_id} {command}{payload}\n'.format(
payload=payload, **self.response_parts).encode('ascii'))
self.serial.readline.side_effect = make_response
self.patched_funcs.enter_context(
mock.patch('cloudinit.sources.DataSourceSmartOS.random.randint',
mock.Mock(return_value=self.request_id)))
def _get_client(self):
return DataSourceSmartOS.JoyentMetadataClient(self.serial)
def assertEndsWith(self, haystack, prefix):
self.assertTrue(haystack.endswith(prefix),
"{0} does not end with '{1}'".format(
repr(haystack), prefix))
def assertStartsWith(self, haystack, prefix):
self.assertTrue(haystack.startswith(prefix),
"{0} does not start with '{1}'".format(
repr(haystack), prefix))
def test_get_metadata_writes_a_single_line(self):
client = self._get_client()
client.get_metadata('some_key')
self.assertEqual(1, self.serial.write.call_count)
written_line = self.serial.write.call_args[0][0]
self.assertEndsWith(written_line, b'\n')
self.assertEqual(1, written_line.count(b'\n'))
def _get_written_line(self, key='some_key'):
client = self._get_client()
client.get_metadata(key)
return self.serial.write.call_args[0][0]
def test_get_metadata_writes_bytes(self):
self.assertIsInstance(self._get_written_line(), six.binary_type)
def test_get_metadata_line_starts_with_v2(self):
self.assertStartsWith(self._get_written_line(), b'V2')
def test_get_metadata_uses_get_command(self):
parts = self._get_written_line().decode('ascii').strip().split(' ')
self.assertEqual('GET', parts[4])
def test_get_metadata_base64_encodes_argument(self):
key = 'my_key'
parts = self._get_written_line(key).decode('ascii').strip().split(' ')
self.assertEqual(b64e(key), parts[5])
def test_get_metadata_calculates_length_correctly(self):
parts = self._get_written_line().decode('ascii').strip().split(' ')
expected_length = len(' '.join(parts[3:]))
self.assertEqual(expected_length, int(parts[1]))
def test_get_metadata_uses_appropriate_request_id(self):
parts = self._get_written_line().decode('ascii').strip().split(' ')
request_id = parts[3]
self.assertEqual(8, len(request_id))
self.assertEqual(request_id, request_id.lower())
def test_get_metadata_uses_random_number_for_request_id(self):
line = self._get_written_line()
request_id = line.decode('ascii').strip().split(' ')[3]
self.assertEqual('{0:08x}'.format(self.request_id), request_id)
def test_get_metadata_checksums_correctly(self):
parts = self._get_written_line().decode('ascii').strip().split(' ')
expected_checksum = '{0:08x}'.format(
crc32(' '.join(parts[3:]).encode('utf-8')) & 0xffffffff)
checksum = parts[2]
self.assertEqual(expected_checksum, checksum)
def test_get_metadata_reads_a_line(self):
client = self._get_client()
client.get_metadata('some_key')
self.assertEqual(1, self.serial.readline.call_count)
def test_get_metadata_returns_valid_value(self):
client = self._get_client()
value = client.get_metadata('some_key')
self.assertEqual(self.metadata_value, value)
def test_get_metadata_throws_exception_for_incorrect_length(self):
self.response_parts['length'] = 0
client = self._get_client()
self.assertRaises(DataSourceSmartOS.JoyentMetadataFetchException,
client.get_metadata, 'some_key')
def test_get_metadata_throws_exception_for_incorrect_crc(self):
self.response_parts['crc'] = 'deadbeef'
client = self._get_client()
self.assertRaises(DataSourceSmartOS.JoyentMetadataFetchException,
client.get_metadata, 'some_key')
def test_get_metadata_throws_exception_for_request_id_mismatch(self):
self.response_parts['request_id'] = 'deadbeef'
client = self._get_client()
client._checksum = lambda _: self.response_parts['crc']
self.assertRaises(DataSourceSmartOS.JoyentMetadataFetchException,
client.get_metadata, 'some_key')
def test_get_metadata_returns_None_if_value_not_found(self):
self.response_parts['payload'] = ''
self.response_parts['command'] = 'NOTFOUND'
self.response_parts['length'] = 17
client = self._get_client()
client._checksum = lambda _: self.response_parts['crc']
self.assertIsNone(client.get_metadata('some_key'))