Merge "Improve WindowsUtils._get_ipv4_forward_table"

This commit is contained in:
Jenkins 2015-06-25 09:56:21 +00:00 committed by Gerrit Code Review
commit a89b4b126d
2 changed files with 200 additions and 43 deletions

View File

@ -12,7 +12,7 @@
# License for the specific language governing permissions and limitations
# under the License.
import contextlib
import ctypes
from ctypes import wintypes
import os
@ -32,7 +32,6 @@ import wmi
from cloudbaseinit import exception
from cloudbaseinit.openstack.common import log as logging
from cloudbaseinit.osutils import base
from cloudbaseinit.utils import encoding
from cloudbaseinit.utils.windows import network
from cloudbaseinit.utils.windows import privilege
from cloudbaseinit.utils.windows import timezone
@ -739,62 +738,69 @@ class WindowsUtils(base.BaseOSUtils):
else:
return (None, None)
def _get_ipv4_routing_table(self):
routing_table = []
heap = kernel32.GetProcessHeap()
size = wintypes.ULONG(ctypes.sizeof(Win32_MIB_IPFORWARDTABLE))
p = kernel32.HeapAlloc(heap, 0, ctypes.c_size_t(size.value))
if not p:
@staticmethod
def _heap_alloc(heap, size):
table_mem = kernel32.HeapAlloc(heap, 0, ctypes.c_size_t(size.value))
if not table_mem:
raise exception.CloudbaseInitException(
'Unable to allocate memory for the IP forward table')
return table_mem
@contextlib.contextmanager
def _get_forward_table(self):
heap = kernel32.GetProcessHeap()
forward_table_size = ctypes.sizeof(Win32_MIB_IPFORWARDTABLE)
size = wintypes.ULONG(forward_table_size)
table_mem = self._heap_alloc(heap, size)
p_forward_table = ctypes.cast(
p, ctypes.POINTER(Win32_MIB_IPFORWARDTABLE))
table_mem, ctypes.POINTER(Win32_MIB_IPFORWARDTABLE))
try:
err = iphlpapi.GetIpForwardTable(p_forward_table,
ctypes.byref(size), 0)
if err == self.ERROR_INSUFFICIENT_BUFFER:
kernel32.HeapFree(heap, 0, p_forward_table)
p = kernel32.HeapAlloc(heap, 0, ctypes.c_size_t(size.value))
if not p:
raise exception.CloudbaseInitException(
'Unable to allocate memory for the IP forward table')
table_mem = self._heap_alloc(heap, size)
p_forward_table = ctypes.cast(
p, ctypes.POINTER(Win32_MIB_IPFORWARDTABLE))
table_mem,
ctypes.POINTER(Win32_MIB_IPFORWARDTABLE))
err = iphlpapi.GetIpForwardTable(p_forward_table,
ctypes.byref(size), 0)
err = iphlpapi.GetIpForwardTable(p_forward_table,
ctypes.byref(size), 0)
if err != self.ERROR_NO_DATA:
if err:
raise exception.CloudbaseInitException(
'Unable to get IP forward table. Error: %s' % err)
if err and err != kernel32.ERROR_NO_DATA:
raise exception.CloudbaseInitException(
'Unable to get IP forward table. Error: %s' % err)
forward_table = p_forward_table.contents
table = ctypes.cast(
ctypes.addressof(forward_table.table),
ctypes.POINTER(Win32_MIB_IPFORWARDROW *
forward_table.dwNumEntries)).contents
i = 0
while i < forward_table.dwNumEntries:
row = table[i]
routing_table.append((
encoding.get_as_string(Ws2_32.inet_ntoa(
row.dwForwardDest)),
encoding.get_as_string(Ws2_32.inet_ntoa(
row.dwForwardMask)),
encoding.get_as_string(Ws2_32.inet_ntoa(
row.dwForwardNextHop)),
row.dwForwardIfIndex,
row.dwForwardMetric1))
i += 1
return routing_table
yield p_forward_table
finally:
kernel32.HeapFree(heap, 0, p_forward_table)
def _get_ipv4_routing_table(self):
routing_table = []
with self._get_forward_table() as p_forward_table:
forward_table = p_forward_table.contents
table = ctypes.cast(
ctypes.addressof(forward_table.table),
ctypes.POINTER(Win32_MIB_IPFORWARDROW *
forward_table.dwNumEntries)).contents
for row in table:
destination = Ws2_32.inet_ntoa(
row.dwForwardDest).decode()
netmask = Ws2_32.inet_ntoa(
row.dwForwardMask).decode()
gateway = Ws2_32.inet_ntoa(
row.dwForwardNextHop).decode()
routing_table.append((
destination,
netmask,
gateway,
row.dwForwardIfIndex,
row.dwForwardMetric1))
return routing_table
def check_static_route_exists(self, destination):
return len([r for r in self._get_ipv4_routing_table()
if r[0] == destination]) > 0

View File

@ -83,6 +83,8 @@ class TestWindowsUtils(testutils.CloudbaseInitTestBase):
self.windows_utils.WindowsError = mock.MagicMock()
self._winutils = self.windows_utils.WindowsUtils()
self._kernel32 = self._windll_mock.kernel32
self._iphlpapi = self._windll_mock.iphlpapi
def tearDown(self):
self._module_patcher.stop()
@ -1538,3 +1540,152 @@ class TestWindowsUtils(testutils.CloudbaseInitTestBase):
mock.sentinel.windows_timezone)
mock_timezone.Timezone.return_value.set.assert_called_once_with(
self._winutils)
def _test__heap_alloc(self, fail):
mock_heap = mock.Mock()
mock_size = mock.Mock()
if fail:
self._kernel32.HeapAlloc.return_value = None
with self.assertRaises(exception.CloudbaseInitException) as cm:
self._winutils._heap_alloc(mock_heap, mock_size)
self.assertEqual('Unable to allocate memory for the IP '
'forward table',
str(cm.exception))
else:
result = self._winutils._heap_alloc(mock_heap, mock_size)
self.assertEqual(self._kernel32.HeapAlloc.return_value, result)
self._kernel32.HeapAlloc.assert_called_once_with(
mock_heap, 0, self._ctypes_mock.c_size_t(mock_size.value))
def test__heap_alloc_error(self):
self._test__heap_alloc(fail=True)
def test__heap_alloc_no_error(self):
self._test__heap_alloc(fail=False)
def test__get_forward_table_no_memory(self):
self._winutils._heap_alloc = mock.Mock()
error_msg = 'Unable to allocate memory for the IP forward table'
exc = exception.CloudbaseInitException(error_msg)
self._winutils._heap_alloc.side_effect = exc
with self.assertRaises(exception.CloudbaseInitException) as cm:
with self._winutils._get_forward_table():
pass
self.assertEqual(error_msg, str(cm.exception))
self._winutils._heap_alloc.assert_called_once_with(
self._kernel32.GetProcessHeap.return_value,
self._ctypes_mock.wintypes.ULONG.return_value)
def test__get_forward_table_insufficient_buffer_no_memory(self):
self._kernel32.HeapAlloc.side_effect = (mock.sentinel.table_mem, None)
self._iphlpapi.GetIpForwardTable.return_value = (
self._winutils.ERROR_INSUFFICIENT_BUFFER)
with self.assertRaises(exception.CloudbaseInitException):
with self._winutils._get_forward_table():
pass
table = self._ctypes_mock.cast.return_value
self._iphlpapi.GetIpForwardTable.assert_called_once_with(
table,
self._ctypes_mock.byref.return_value, 0)
heap_calls = [
mock.call(self._kernel32.GetProcessHeap.return_value, 0, table),
mock.call(self._kernel32.GetProcessHeap.return_value, 0, table)
]
self.assertEqual(heap_calls, self._kernel32.HeapFree.mock_calls)
def _test__get_forward_table(self, reallocation=False,
insufficient_buffer=False,
fail=False):
if fail:
with self.assertRaises(exception.CloudbaseInitException) as cm:
with self._winutils._get_forward_table():
pass
msg = ('Unable to get IP forward table. Error: %s'
% mock.sentinel.error)
self.assertEqual(msg, str(cm.exception))
else:
with self._winutils._get_forward_table() as table:
pass
pointer = self._ctypes_mock.POINTER(
self._iphlpapi.Win32_MIB_IPFORWARDTABLE)
expected_forward_table = self._ctypes_mock.cast(
self._kernel32.HeapAlloc.return_value, pointer)
self.assertEqual(expected_forward_table, table)
heap_calls = [
mock.call(self._kernel32.GetProcessHeap.return_value, 0,
self._ctypes_mock.cast.return_value)
]
forward_calls = [
mock.call(self._ctypes_mock.cast.return_value,
self._ctypes_mock.byref.return_value, 0),
]
if insufficient_buffer:
# We expect two calls for GetIpForwardTable
forward_calls.append(forward_calls[0])
if reallocation:
heap_calls.append(heap_calls[0])
self.assertEqual(heap_calls, self._kernel32.HeapFree.mock_calls)
self.assertEqual(forward_calls,
self._iphlpapi.GetIpForwardTable.mock_calls)
def test__get_forward_table_sufficient_buffer(self):
self._iphlpapi.GetIpForwardTable.return_value = None
self._test__get_forward_table()
def test__get_forward_table_insufficient_buffer_reallocate(self):
self._kernel32.HeapAlloc.side_effect = (
mock.sentinel.table_mem, mock.sentinel.table_mem)
self._iphlpapi.GetIpForwardTable.side_effect = (
self._winutils.ERROR_INSUFFICIENT_BUFFER, None)
self._test__get_forward_table(reallocation=True,
insufficient_buffer=True)
def test__get_forward_table_insufficient_buffer_other_error(self):
self._kernel32.HeapAlloc.side_effect = (
mock.sentinel.table_mem, mock.sentinel.table_mem)
self._iphlpapi.GetIpForwardTable.side_effect = (
self._winutils.ERROR_INSUFFICIENT_BUFFER, mock.sentinel.error)
self._test__get_forward_table(reallocation=True,
insufficient_buffer=True,
fail=True)
@mock.patch('cloudbaseinit.osutils.windows.WindowsUtils.'
'_get_forward_table')
def test_routes(self, mock_forward_table):
def _same(arg):
return arg._mock_name.encode()
route = mock.MagicMock()
mock_cast_result = mock.Mock()
mock_cast_result.contents = [route]
self._ctypes_mock.cast.return_value = mock_cast_result
self.windows_utils.Ws2_32.inet_ntoa.side_effect = _same
route.dwForwardIfIndex = 'dwForwardIfIndex'
route.dwForwardProto = 'dwForwardProto'
route.dwForwardMetric1 = 'dwForwardMetric1'
routes = self._winutils._get_ipv4_routing_table()
mock_forward_table.assert_called_once_with()
enter = mock_forward_table.return_value.__enter__
enter.assert_called_once_with()
exit_ = mock_forward_table.return_value.__exit__
exit_.assert_called_once_with(None, None, None)
self.assertEqual(1, len(routes))
given_route = routes[0]
self.assertEqual('dwForwardDest', given_route[0])
self.assertEqual('dwForwardMask', given_route[1])
self.assertEqual('dwForwardNextHop', given_route[2])
self.assertEqual('dwForwardIfIndex', given_route[3])
self.assertEqual('dwForwardMetric1', given_route[4])