diff --git a/oslo_vmware/rw_handles.py b/oslo_vmware/rw_handles.py index 9746e1f..b56c982 100644 --- a/oslo_vmware/rw_handles.py +++ b/oslo_vmware/rw_handles.py @@ -62,16 +62,47 @@ class FileHandle(object): self._last_logged_progress = 0 self._last_progress_udpate = 0 - def _create_read_connection(self, url, cookies=None, cacerts=False): + def _create_connection(self, url, method, cacerts=False, + ssl_thumbprint=None): + _urlparse = urlparse.urlparse(url) + scheme, netloc, path, params, query, fragment = _urlparse + if scheme == 'http': + conn = httplib.HTTPConnection(netloc) + elif scheme == 'https': + conn = httplib.HTTPSConnection(netloc) + cert_reqs = None + + # cacerts can be either True or False or contain + # actual certificates. If it is a boolean, then + # we need to set cert_reqs and clear the cacerts + if isinstance(cacerts, bool): + if cacerts: + cert_reqs = ssl.CERT_REQUIRED + else: + cert_reqs = ssl.CERT_NONE + cacerts = None + conn.set_cert(ca_certs=cacerts, cert_reqs=cert_reqs, + assert_fingerprint=ssl_thumbprint) + else: + excep_msg = _("Invalid scheme: %s.") % scheme + LOG.error(excep_msg) + raise ValueError(excep_msg) + + if query: + path = path + '?' + query + conn.putrequest(method, path) + return conn + + def _create_read_connection(self, url, cookies=None, cacerts=False, + ssl_thumbprint=None): LOG.debug("Opening URL: %s for reading.", url) try: - headers = {'User-Agent': USER_AGENT} - if cookies: - headers.update({'Cookie': - self._build_vim_cookie_header(cookies)}) - response = requests.get(url, headers=headers, stream=True, - verify=cacerts) - return response.raw + conn = self._create_connection(url, 'GET', cacerts, ssl_thumbprint) + vim_cookie = self._build_vim_cookie_header(cookies) + conn.putheader('User-Agent', USER_AGENT) + conn.putheader('Cookie', vim_cookie) + conn.endheaders() + return conn.getresponse() except Exception as excep: # TODO(vbala) We need to catch and raise specific exceptions # related to connection problems, invalid request and invalid @@ -86,41 +117,15 @@ class FileHandle(object): cookies=None, overwrite=None, content_type=None, - cacerts=False): + cacerts=False, + ssl_thumbprint=None): """Create HTTP connection to write to VMDK file.""" LOG.debug("Creating HTTP connection to write to file with " "size = %(file_size)d and URL = %(url)s.", {'file_size': file_size, 'url': url}) - _urlparse = urlparse.urlparse(url) - scheme, netloc, path, params, query, fragment = _urlparse - try: - if scheme == 'http': - conn = httplib.HTTPConnection(netloc) - elif scheme == 'https': - conn = httplib.HTTPSConnection(netloc) - cert_reqs = None - - # cacerts can be either True or False or contain - # actual certificates. If it is a boolean, then - # we need to set cert_reqs and clear the cacerts - if isinstance(cacerts, bool): - if cacerts: - cert_reqs = ssl.CERT_REQUIRED - else: - cert_reqs = ssl.CERT_NONE - cacerts = None - - conn.set_cert(ca_certs=cacerts, cert_reqs=cert_reqs) - else: - excep_msg = _("Invalid scheme: %s.") % scheme - LOG.error(excep_msg) - raise ValueError(excep_msg) - - if query: - path = path + '?' + query - + conn = self._create_connection(url, 'PUT', cacerts, ssl_thumbprint) headers = {'User-Agent': USER_AGENT} if file_size: headers.update({'Content-Length': str(file_size)}) @@ -131,8 +136,6 @@ class FileHandle(object): self._build_vim_cookie_header(cookies)}) if content_type: headers.update({'Content-Type': content_type}) - - conn.putrequest('PUT', path) for key, value in six.iteritems(headers): conn.putheader(key, value) conn.endheaders() @@ -213,16 +216,18 @@ class FileHandle(object): def _find_vmdk_url(self, lease_info, host, port): """Find the URL corresponding to a VMDK file in lease info.""" url = None + ssl_thumbprint = None for deviceUrl in lease_info.deviceUrl: if deviceUrl.disk: url = self._fix_esx_url(deviceUrl.url, host, port) + ssl_thumbprint = deviceUrl.sslThumbprint break if not url: excep_msg = _("Could not retrieve VMDK URL from lease info.") LOG.error(excep_msg) raise exceptions.VimException(excep_msg) LOG.debug("Found VMDK URL: %s from lease info.", url) - return url + return url, ssl_thumbprint def _log_progress(self, progress): """Log data transfer progress.""" @@ -338,7 +343,7 @@ class VmdkWriteHandle(FileHandle): 'info') # Find VMDK URL where data is to be written - self._url = self._find_vmdk_url(lease_info, host, port) + self._url, thumbprint = self._find_vmdk_url(lease_info, host, port) self._vm_ref = lease_info.entity cookies = session.vim.client.options.transport.cookiejar @@ -349,7 +354,7 @@ class VmdkWriteHandle(FileHandle): cookies=cookies, overwrite='t', content_type=octet_stream, - cacerts=session._cacert) + ssl_thumbprint=thumbprint) FileHandle.__init__(self, self._conn) def get_imported_vm(self): @@ -499,12 +504,11 @@ class VmdkReadHandle(FileHandle): 'info') # find URL of the VMDK file to be read and open connection - self._url = self._find_vmdk_url(lease_info, host, port) + self._url, thumbprint = self._find_vmdk_url(lease_info, host, port) cookies = session.vim.client.options.transport.cookiejar - cacerts = session.vim.client.options.transport.verify self._conn = self._create_read_connection(self._url, cookies=cookies, - cacerts=cacerts) + ssl_thumbprint=thumbprint) FileHandle.__init__(self, self._conn) def _create_and_wait_for_lease(self, session, vm_ref): diff --git a/oslo_vmware/tests/test_rw_handles.py b/oslo_vmware/tests/test_rw_handles.py index f6c5623..163fc5a 100644 --- a/oslo_vmware/tests/test_rw_handles.py +++ b/oslo_vmware/tests/test_rw_handles.py @@ -24,6 +24,7 @@ from oslo_vmware import exceptions from oslo_vmware import rw_handles from oslo_vmware.tests import base from oslo_vmware import vim_util +from urllib3 import connection as httplib class FileHandleTest(base.TestCase): @@ -41,15 +42,23 @@ class FileHandleTest(base.TestCase): device_url_1 = mock.Mock() device_url_1.disk = True device_url_1.url = 'https://*/ds1/vm1.vmdk' + device_url_1.sslThumbprint = '11:22:33:44:55' lease_info = mock.Mock() lease_info.deviceUrl = [device_url_0, device_url_1] host = '10.1.2.3' port = 443 exp_url = 'https://%s:%d/ds1/vm1.vmdk' % (host, port) vmw_http_file = rw_handles.FileHandle(None) - self.assertEqual(exp_url, vmw_http_file._find_vmdk_url(lease_info, - host, - port)) + url, thumbprint = vmw_http_file._find_vmdk_url(lease_info, host, port) + self.assertEqual(exp_url, url) + self.assertEqual('11:22:33:44:55', thumbprint) + + def test_create_connection(self): + handle = rw_handles.FileHandle(None) + conn = handle._create_connection('http://fira', 'GET') + self.assertIsInstance(conn, httplib.HTTPConnection) + conn = handle._create_connection('https://fira', 'GET') + self.assertIsInstance(conn, httplib.HTTPSConnection) class FileWriteHandleTest(base.TestCase): @@ -182,12 +191,15 @@ class VmdkReadHandleTest(base.TestCase): def setUp(self): super(VmdkReadHandleTest, self).setUp() - - send_patcher = mock.patch('requests.sessions.Session.send') - self.addCleanup(send_patcher.stop) - send_mock = send_patcher.start() - self._response = mock.Mock() - send_mock.return_value = self._response + self._resp = mock.Mock() + self._resp.read.return_value = 'fake-data' + self._conn = mock.Mock() + self._conn.getresponse.return_value = self._resp + patcher = mock.patch( + 'urllib3.connection.HTTPConnection') + self.addCleanup(patcher.stop) + HTTPConnectionMock = patcher.start() + HTTPConnectionMock.return_value = self._conn def _create_mock_session(self, disk=True, progress=-1): device_url = mock.Mock() @@ -227,25 +239,22 @@ class VmdkReadHandleTest(base.TestCase): def test_read(self): chunk_size = rw_handles.READ_CHUNKSIZE session = self._create_mock_session() - self._response.raw.read.return_value = [1] * chunk_size handle = rw_handles.VmdkReadHandle(session, '10.1.2.3', 443, 'vm-1', '[ds] disk1.vmdk', chunk_size * 10) - handle.read(chunk_size) - self.assertEqual(chunk_size, handle._bytes_read) - self._response.raw.read.assert_called_once_with(chunk_size) + data = handle.read(chunk_size) + self.assertEqual('fake-data', data) def test_update_progress(self): - chunk_size = rw_handles.READ_CHUNKSIZE + chunk_size = len('fake-data') vmdk_size = chunk_size * 10 session = self._create_mock_session(True, 10) - self._response.raw.read.return_value = [1] * chunk_size handle = rw_handles.VmdkReadHandle(session, '10.1.2.3', 443, 'vm-1', '[ds] disk1.vmdk', vmdk_size) - handle.read(chunk_size) + data = handle.read(chunk_size) handle.update_progress() - self._response.raw.read.assert_called_once_with(chunk_size) + self.assertEqual('fake-data', data) def test_update_progress_with_error(self): session = self._create_mock_session(True, 10)