From 5cd9947a3c92dccc76966d7fe8698cf24839f0c8 Mon Sep 17 00:00:00 2001 From: Claudiu Popa Date: Wed, 17 Jun 2015 18:58:22 +0300 Subject: [PATCH] Improve WindowsUtils._get_ipv4_forward_table This patch improves _get_ipv4_forward_table by doing only one call to iphlpapi.GetIpForwardTable. Previously two calls were made, where the second call should have been called only when the first one failed with INSUFFICIENT_BUFFER. This patch also adds tests for _get_ipv4_forward_table, which was previously untested. Change-Id: I5728948557e749d0e0005a9d90c12e4093dc3307 --- cloudbaseinit/osutils/windows.py | 92 ++++++------ cloudbaseinit/tests/osutils/test_windows.py | 151 ++++++++++++++++++++ 2 files changed, 200 insertions(+), 43 deletions(-) diff --git a/cloudbaseinit/osutils/windows.py b/cloudbaseinit/osutils/windows.py index 8dde6677..910af416 100644 --- a/cloudbaseinit/osutils/windows.py +++ b/cloudbaseinit/osutils/windows.py @@ -12,7 +12,7 @@ # License for the specific language governing permissions and limitations # under the License. - +import contextlib import ctypes from ctypes import wintypes import os @@ -32,7 +32,6 @@ import wmi from cloudbaseinit import exception from cloudbaseinit.openstack.common import log as logging from cloudbaseinit.osutils import base -from cloudbaseinit.utils import encoding from cloudbaseinit.utils.windows import network from cloudbaseinit.utils.windows import privilege from cloudbaseinit.utils.windows import timezone @@ -739,62 +738,69 @@ class WindowsUtils(base.BaseOSUtils): else: return (None, None) - def _get_ipv4_routing_table(self): - routing_table = [] - - heap = kernel32.GetProcessHeap() - - size = wintypes.ULONG(ctypes.sizeof(Win32_MIB_IPFORWARDTABLE)) - p = kernel32.HeapAlloc(heap, 0, ctypes.c_size_t(size.value)) - if not p: + @staticmethod + def _heap_alloc(heap, size): + table_mem = kernel32.HeapAlloc(heap, 0, ctypes.c_size_t(size.value)) + if not table_mem: raise exception.CloudbaseInitException( 'Unable to allocate memory for the IP forward table') + return table_mem + + @contextlib.contextmanager + def _get_forward_table(self): + heap = kernel32.GetProcessHeap() + forward_table_size = ctypes.sizeof(Win32_MIB_IPFORWARDTABLE) + size = wintypes.ULONG(forward_table_size) + table_mem = self._heap_alloc(heap, size) + p_forward_table = ctypes.cast( - p, ctypes.POINTER(Win32_MIB_IPFORWARDTABLE)) + table_mem, ctypes.POINTER(Win32_MIB_IPFORWARDTABLE)) try: err = iphlpapi.GetIpForwardTable(p_forward_table, ctypes.byref(size), 0) if err == self.ERROR_INSUFFICIENT_BUFFER: kernel32.HeapFree(heap, 0, p_forward_table) - p = kernel32.HeapAlloc(heap, 0, ctypes.c_size_t(size.value)) - if not p: - raise exception.CloudbaseInitException( - 'Unable to allocate memory for the IP forward table') + table_mem = self._heap_alloc(heap, size) p_forward_table = ctypes.cast( - p, ctypes.POINTER(Win32_MIB_IPFORWARDTABLE)) + table_mem, + ctypes.POINTER(Win32_MIB_IPFORWARDTABLE)) + err = iphlpapi.GetIpForwardTable(p_forward_table, + ctypes.byref(size), 0) - err = iphlpapi.GetIpForwardTable(p_forward_table, - ctypes.byref(size), 0) - if err != self.ERROR_NO_DATA: - if err: - raise exception.CloudbaseInitException( - 'Unable to get IP forward table. Error: %s' % err) + if err and err != kernel32.ERROR_NO_DATA: + raise exception.CloudbaseInitException( + 'Unable to get IP forward table. Error: %s' % err) - forward_table = p_forward_table.contents - table = ctypes.cast( - ctypes.addressof(forward_table.table), - ctypes.POINTER(Win32_MIB_IPFORWARDROW * - forward_table.dwNumEntries)).contents - - i = 0 - while i < forward_table.dwNumEntries: - row = table[i] - routing_table.append(( - encoding.get_as_string(Ws2_32.inet_ntoa( - row.dwForwardDest)), - encoding.get_as_string(Ws2_32.inet_ntoa( - row.dwForwardMask)), - encoding.get_as_string(Ws2_32.inet_ntoa( - row.dwForwardNextHop)), - row.dwForwardIfIndex, - row.dwForwardMetric1)) - i += 1 - - return routing_table + yield p_forward_table finally: kernel32.HeapFree(heap, 0, p_forward_table) + def _get_ipv4_routing_table(self): + routing_table = [] + with self._get_forward_table() as p_forward_table: + forward_table = p_forward_table.contents + table = ctypes.cast( + ctypes.addressof(forward_table.table), + ctypes.POINTER(Win32_MIB_IPFORWARDROW * + forward_table.dwNumEntries)).contents + + for row in table: + destination = Ws2_32.inet_ntoa( + row.dwForwardDest).decode() + netmask = Ws2_32.inet_ntoa( + row.dwForwardMask).decode() + gateway = Ws2_32.inet_ntoa( + row.dwForwardNextHop).decode() + routing_table.append(( + destination, + netmask, + gateway, + row.dwForwardIfIndex, + row.dwForwardMetric1)) + + return routing_table + def check_static_route_exists(self, destination): return len([r for r in self._get_ipv4_routing_table() if r[0] == destination]) > 0 diff --git a/cloudbaseinit/tests/osutils/test_windows.py b/cloudbaseinit/tests/osutils/test_windows.py index c4f40cc4..cc1e8625 100644 --- a/cloudbaseinit/tests/osutils/test_windows.py +++ b/cloudbaseinit/tests/osutils/test_windows.py @@ -83,6 +83,8 @@ class TestWindowsUtils(testutils.CloudbaseInitTestBase): self.windows_utils.WindowsError = mock.MagicMock() self._winutils = self.windows_utils.WindowsUtils() + self._kernel32 = self._windll_mock.kernel32 + self._iphlpapi = self._windll_mock.iphlpapi def tearDown(self): self._module_patcher.stop() @@ -1538,3 +1540,152 @@ class TestWindowsUtils(testutils.CloudbaseInitTestBase): mock.sentinel.windows_timezone) mock_timezone.Timezone.return_value.set.assert_called_once_with( self._winutils) + + def _test__heap_alloc(self, fail): + mock_heap = mock.Mock() + mock_size = mock.Mock() + + if fail: + self._kernel32.HeapAlloc.return_value = None + + with self.assertRaises(exception.CloudbaseInitException) as cm: + self._winutils._heap_alloc(mock_heap, mock_size) + + self.assertEqual('Unable to allocate memory for the IP ' + 'forward table', + str(cm.exception)) + else: + result = self._winutils._heap_alloc(mock_heap, mock_size) + self.assertEqual(self._kernel32.HeapAlloc.return_value, result) + + self._kernel32.HeapAlloc.assert_called_once_with( + mock_heap, 0, self._ctypes_mock.c_size_t(mock_size.value)) + + def test__heap_alloc_error(self): + self._test__heap_alloc(fail=True) + + def test__heap_alloc_no_error(self): + self._test__heap_alloc(fail=False) + + def test__get_forward_table_no_memory(self): + self._winutils._heap_alloc = mock.Mock() + error_msg = 'Unable to allocate memory for the IP forward table' + exc = exception.CloudbaseInitException(error_msg) + self._winutils._heap_alloc.side_effect = exc + + with self.assertRaises(exception.CloudbaseInitException) as cm: + with self._winutils._get_forward_table(): + pass + + self.assertEqual(error_msg, str(cm.exception)) + self._winutils._heap_alloc.assert_called_once_with( + self._kernel32.GetProcessHeap.return_value, + self._ctypes_mock.wintypes.ULONG.return_value) + + def test__get_forward_table_insufficient_buffer_no_memory(self): + self._kernel32.HeapAlloc.side_effect = (mock.sentinel.table_mem, None) + self._iphlpapi.GetIpForwardTable.return_value = ( + self._winutils.ERROR_INSUFFICIENT_BUFFER) + + with self.assertRaises(exception.CloudbaseInitException): + with self._winutils._get_forward_table(): + pass + + table = self._ctypes_mock.cast.return_value + self._iphlpapi.GetIpForwardTable.assert_called_once_with( + table, + self._ctypes_mock.byref.return_value, 0) + heap_calls = [ + mock.call(self._kernel32.GetProcessHeap.return_value, 0, table), + mock.call(self._kernel32.GetProcessHeap.return_value, 0, table) + ] + self.assertEqual(heap_calls, self._kernel32.HeapFree.mock_calls) + + def _test__get_forward_table(self, reallocation=False, + insufficient_buffer=False, + fail=False): + if fail: + with self.assertRaises(exception.CloudbaseInitException) as cm: + with self._winutils._get_forward_table(): + pass + + msg = ('Unable to get IP forward table. Error: %s' + % mock.sentinel.error) + self.assertEqual(msg, str(cm.exception)) + else: + with self._winutils._get_forward_table() as table: + pass + pointer = self._ctypes_mock.POINTER( + self._iphlpapi.Win32_MIB_IPFORWARDTABLE) + expected_forward_table = self._ctypes_mock.cast( + self._kernel32.HeapAlloc.return_value, pointer) + self.assertEqual(expected_forward_table, table) + + heap_calls = [ + mock.call(self._kernel32.GetProcessHeap.return_value, 0, + self._ctypes_mock.cast.return_value) + ] + forward_calls = [ + mock.call(self._ctypes_mock.cast.return_value, + self._ctypes_mock.byref.return_value, 0), + ] + if insufficient_buffer: + # We expect two calls for GetIpForwardTable + forward_calls.append(forward_calls[0]) + if reallocation: + heap_calls.append(heap_calls[0]) + self.assertEqual(heap_calls, self._kernel32.HeapFree.mock_calls) + self.assertEqual(forward_calls, + self._iphlpapi.GetIpForwardTable.mock_calls) + + def test__get_forward_table_sufficient_buffer(self): + self._iphlpapi.GetIpForwardTable.return_value = None + self._test__get_forward_table() + + def test__get_forward_table_insufficient_buffer_reallocate(self): + self._kernel32.HeapAlloc.side_effect = ( + mock.sentinel.table_mem, mock.sentinel.table_mem) + self._iphlpapi.GetIpForwardTable.side_effect = ( + self._winutils.ERROR_INSUFFICIENT_BUFFER, None) + + self._test__get_forward_table(reallocation=True, + insufficient_buffer=True) + + def test__get_forward_table_insufficient_buffer_other_error(self): + self._kernel32.HeapAlloc.side_effect = ( + mock.sentinel.table_mem, mock.sentinel.table_mem) + self._iphlpapi.GetIpForwardTable.side_effect = ( + self._winutils.ERROR_INSUFFICIENT_BUFFER, mock.sentinel.error) + + self._test__get_forward_table(reallocation=True, + insufficient_buffer=True, + fail=True) + + @mock.patch('cloudbaseinit.osutils.windows.WindowsUtils.' + '_get_forward_table') + def test_routes(self, mock_forward_table): + def _same(arg): + return arg._mock_name.encode() + + route = mock.MagicMock() + mock_cast_result = mock.Mock() + mock_cast_result.contents = [route] + self._ctypes_mock.cast.return_value = mock_cast_result + self.windows_utils.Ws2_32.inet_ntoa.side_effect = _same + route.dwForwardIfIndex = 'dwForwardIfIndex' + route.dwForwardProto = 'dwForwardProto' + route.dwForwardMetric1 = 'dwForwardMetric1' + routes = self._winutils._get_ipv4_routing_table() + + mock_forward_table.assert_called_once_with() + enter = mock_forward_table.return_value.__enter__ + enter.assert_called_once_with() + exit_ = mock_forward_table.return_value.__exit__ + exit_.assert_called_once_with(None, None, None) + self.assertEqual(1, len(routes)) + given_route = routes[0] + self.assertEqual('dwForwardDest', given_route[0]) + self.assertEqual('dwForwardMask', given_route[1]) + self.assertEqual('dwForwardNextHop', given_route[2]) + self.assertEqual('dwForwardIfIndex', given_route[3]) + self.assertEqual('dwForwardMetric1', given_route[4])