Merge "Improve WindowsUtils._get_ipv4_forward_table"
This commit is contained in:
commit
a89b4b126d
@ -12,7 +12,7 @@
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
|
||||
import contextlib
|
||||
import ctypes
|
||||
from ctypes import wintypes
|
||||
import os
|
||||
@ -32,7 +32,6 @@ import wmi
|
||||
from cloudbaseinit import exception
|
||||
from cloudbaseinit.openstack.common import log as logging
|
||||
from cloudbaseinit.osutils import base
|
||||
from cloudbaseinit.utils import encoding
|
||||
from cloudbaseinit.utils.windows import network
|
||||
from cloudbaseinit.utils.windows import privilege
|
||||
from cloudbaseinit.utils.windows import timezone
|
||||
@ -739,61 +738,68 @@ class WindowsUtils(base.BaseOSUtils):
|
||||
else:
|
||||
return (None, None)
|
||||
|
||||
def _get_ipv4_routing_table(self):
|
||||
routing_table = []
|
||||
|
||||
heap = kernel32.GetProcessHeap()
|
||||
|
||||
size = wintypes.ULONG(ctypes.sizeof(Win32_MIB_IPFORWARDTABLE))
|
||||
p = kernel32.HeapAlloc(heap, 0, ctypes.c_size_t(size.value))
|
||||
if not p:
|
||||
@staticmethod
|
||||
def _heap_alloc(heap, size):
|
||||
table_mem = kernel32.HeapAlloc(heap, 0, ctypes.c_size_t(size.value))
|
||||
if not table_mem:
|
||||
raise exception.CloudbaseInitException(
|
||||
'Unable to allocate memory for the IP forward table')
|
||||
return table_mem
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _get_forward_table(self):
|
||||
heap = kernel32.GetProcessHeap()
|
||||
forward_table_size = ctypes.sizeof(Win32_MIB_IPFORWARDTABLE)
|
||||
size = wintypes.ULONG(forward_table_size)
|
||||
table_mem = self._heap_alloc(heap, size)
|
||||
|
||||
p_forward_table = ctypes.cast(
|
||||
p, ctypes.POINTER(Win32_MIB_IPFORWARDTABLE))
|
||||
table_mem, ctypes.POINTER(Win32_MIB_IPFORWARDTABLE))
|
||||
|
||||
try:
|
||||
err = iphlpapi.GetIpForwardTable(p_forward_table,
|
||||
ctypes.byref(size), 0)
|
||||
if err == self.ERROR_INSUFFICIENT_BUFFER:
|
||||
kernel32.HeapFree(heap, 0, p_forward_table)
|
||||
p = kernel32.HeapAlloc(heap, 0, ctypes.c_size_t(size.value))
|
||||
if not p:
|
||||
raise exception.CloudbaseInitException(
|
||||
'Unable to allocate memory for the IP forward table')
|
||||
table_mem = self._heap_alloc(heap, size)
|
||||
p_forward_table = ctypes.cast(
|
||||
p, ctypes.POINTER(Win32_MIB_IPFORWARDTABLE))
|
||||
|
||||
table_mem,
|
||||
ctypes.POINTER(Win32_MIB_IPFORWARDTABLE))
|
||||
err = iphlpapi.GetIpForwardTable(p_forward_table,
|
||||
ctypes.byref(size), 0)
|
||||
if err != self.ERROR_NO_DATA:
|
||||
if err:
|
||||
|
||||
if err and err != kernel32.ERROR_NO_DATA:
|
||||
raise exception.CloudbaseInitException(
|
||||
'Unable to get IP forward table. Error: %s' % err)
|
||||
|
||||
yield p_forward_table
|
||||
finally:
|
||||
kernel32.HeapFree(heap, 0, p_forward_table)
|
||||
|
||||
def _get_ipv4_routing_table(self):
|
||||
routing_table = []
|
||||
with self._get_forward_table() as p_forward_table:
|
||||
forward_table = p_forward_table.contents
|
||||
table = ctypes.cast(
|
||||
ctypes.addressof(forward_table.table),
|
||||
ctypes.POINTER(Win32_MIB_IPFORWARDROW *
|
||||
forward_table.dwNumEntries)).contents
|
||||
|
||||
i = 0
|
||||
while i < forward_table.dwNumEntries:
|
||||
row = table[i]
|
||||
for row in table:
|
||||
destination = Ws2_32.inet_ntoa(
|
||||
row.dwForwardDest).decode()
|
||||
netmask = Ws2_32.inet_ntoa(
|
||||
row.dwForwardMask).decode()
|
||||
gateway = Ws2_32.inet_ntoa(
|
||||
row.dwForwardNextHop).decode()
|
||||
routing_table.append((
|
||||
encoding.get_as_string(Ws2_32.inet_ntoa(
|
||||
row.dwForwardDest)),
|
||||
encoding.get_as_string(Ws2_32.inet_ntoa(
|
||||
row.dwForwardMask)),
|
||||
encoding.get_as_string(Ws2_32.inet_ntoa(
|
||||
row.dwForwardNextHop)),
|
||||
destination,
|
||||
netmask,
|
||||
gateway,
|
||||
row.dwForwardIfIndex,
|
||||
row.dwForwardMetric1))
|
||||
i += 1
|
||||
|
||||
return routing_table
|
||||
finally:
|
||||
kernel32.HeapFree(heap, 0, p_forward_table)
|
||||
|
||||
def check_static_route_exists(self, destination):
|
||||
return len([r for r in self._get_ipv4_routing_table()
|
||||
|
@ -83,6 +83,8 @@ class TestWindowsUtils(testutils.CloudbaseInitTestBase):
|
||||
self.windows_utils.WindowsError = mock.MagicMock()
|
||||
|
||||
self._winutils = self.windows_utils.WindowsUtils()
|
||||
self._kernel32 = self._windll_mock.kernel32
|
||||
self._iphlpapi = self._windll_mock.iphlpapi
|
||||
|
||||
def tearDown(self):
|
||||
self._module_patcher.stop()
|
||||
@ -1538,3 +1540,152 @@ class TestWindowsUtils(testutils.CloudbaseInitTestBase):
|
||||
mock.sentinel.windows_timezone)
|
||||
mock_timezone.Timezone.return_value.set.assert_called_once_with(
|
||||
self._winutils)
|
||||
|
||||
def _test__heap_alloc(self, fail):
|
||||
mock_heap = mock.Mock()
|
||||
mock_size = mock.Mock()
|
||||
|
||||
if fail:
|
||||
self._kernel32.HeapAlloc.return_value = None
|
||||
|
||||
with self.assertRaises(exception.CloudbaseInitException) as cm:
|
||||
self._winutils._heap_alloc(mock_heap, mock_size)
|
||||
|
||||
self.assertEqual('Unable to allocate memory for the IP '
|
||||
'forward table',
|
||||
str(cm.exception))
|
||||
else:
|
||||
result = self._winutils._heap_alloc(mock_heap, mock_size)
|
||||
self.assertEqual(self._kernel32.HeapAlloc.return_value, result)
|
||||
|
||||
self._kernel32.HeapAlloc.assert_called_once_with(
|
||||
mock_heap, 0, self._ctypes_mock.c_size_t(mock_size.value))
|
||||
|
||||
def test__heap_alloc_error(self):
|
||||
self._test__heap_alloc(fail=True)
|
||||
|
||||
def test__heap_alloc_no_error(self):
|
||||
self._test__heap_alloc(fail=False)
|
||||
|
||||
def test__get_forward_table_no_memory(self):
|
||||
self._winutils._heap_alloc = mock.Mock()
|
||||
error_msg = 'Unable to allocate memory for the IP forward table'
|
||||
exc = exception.CloudbaseInitException(error_msg)
|
||||
self._winutils._heap_alloc.side_effect = exc
|
||||
|
||||
with self.assertRaises(exception.CloudbaseInitException) as cm:
|
||||
with self._winutils._get_forward_table():
|
||||
pass
|
||||
|
||||
self.assertEqual(error_msg, str(cm.exception))
|
||||
self._winutils._heap_alloc.assert_called_once_with(
|
||||
self._kernel32.GetProcessHeap.return_value,
|
||||
self._ctypes_mock.wintypes.ULONG.return_value)
|
||||
|
||||
def test__get_forward_table_insufficient_buffer_no_memory(self):
|
||||
self._kernel32.HeapAlloc.side_effect = (mock.sentinel.table_mem, None)
|
||||
self._iphlpapi.GetIpForwardTable.return_value = (
|
||||
self._winutils.ERROR_INSUFFICIENT_BUFFER)
|
||||
|
||||
with self.assertRaises(exception.CloudbaseInitException):
|
||||
with self._winutils._get_forward_table():
|
||||
pass
|
||||
|
||||
table = self._ctypes_mock.cast.return_value
|
||||
self._iphlpapi.GetIpForwardTable.assert_called_once_with(
|
||||
table,
|
||||
self._ctypes_mock.byref.return_value, 0)
|
||||
heap_calls = [
|
||||
mock.call(self._kernel32.GetProcessHeap.return_value, 0, table),
|
||||
mock.call(self._kernel32.GetProcessHeap.return_value, 0, table)
|
||||
]
|
||||
self.assertEqual(heap_calls, self._kernel32.HeapFree.mock_calls)
|
||||
|
||||
def _test__get_forward_table(self, reallocation=False,
|
||||
insufficient_buffer=False,
|
||||
fail=False):
|
||||
if fail:
|
||||
with self.assertRaises(exception.CloudbaseInitException) as cm:
|
||||
with self._winutils._get_forward_table():
|
||||
pass
|
||||
|
||||
msg = ('Unable to get IP forward table. Error: %s'
|
||||
% mock.sentinel.error)
|
||||
self.assertEqual(msg, str(cm.exception))
|
||||
else:
|
||||
with self._winutils._get_forward_table() as table:
|
||||
pass
|
||||
pointer = self._ctypes_mock.POINTER(
|
||||
self._iphlpapi.Win32_MIB_IPFORWARDTABLE)
|
||||
expected_forward_table = self._ctypes_mock.cast(
|
||||
self._kernel32.HeapAlloc.return_value, pointer)
|
||||
self.assertEqual(expected_forward_table, table)
|
||||
|
||||
heap_calls = [
|
||||
mock.call(self._kernel32.GetProcessHeap.return_value, 0,
|
||||
self._ctypes_mock.cast.return_value)
|
||||
]
|
||||
forward_calls = [
|
||||
mock.call(self._ctypes_mock.cast.return_value,
|
||||
self._ctypes_mock.byref.return_value, 0),
|
||||
]
|
||||
if insufficient_buffer:
|
||||
# We expect two calls for GetIpForwardTable
|
||||
forward_calls.append(forward_calls[0])
|
||||
if reallocation:
|
||||
heap_calls.append(heap_calls[0])
|
||||
self.assertEqual(heap_calls, self._kernel32.HeapFree.mock_calls)
|
||||
self.assertEqual(forward_calls,
|
||||
self._iphlpapi.GetIpForwardTable.mock_calls)
|
||||
|
||||
def test__get_forward_table_sufficient_buffer(self):
|
||||
self._iphlpapi.GetIpForwardTable.return_value = None
|
||||
self._test__get_forward_table()
|
||||
|
||||
def test__get_forward_table_insufficient_buffer_reallocate(self):
|
||||
self._kernel32.HeapAlloc.side_effect = (
|
||||
mock.sentinel.table_mem, mock.sentinel.table_mem)
|
||||
self._iphlpapi.GetIpForwardTable.side_effect = (
|
||||
self._winutils.ERROR_INSUFFICIENT_BUFFER, None)
|
||||
|
||||
self._test__get_forward_table(reallocation=True,
|
||||
insufficient_buffer=True)
|
||||
|
||||
def test__get_forward_table_insufficient_buffer_other_error(self):
|
||||
self._kernel32.HeapAlloc.side_effect = (
|
||||
mock.sentinel.table_mem, mock.sentinel.table_mem)
|
||||
self._iphlpapi.GetIpForwardTable.side_effect = (
|
||||
self._winutils.ERROR_INSUFFICIENT_BUFFER, mock.sentinel.error)
|
||||
|
||||
self._test__get_forward_table(reallocation=True,
|
||||
insufficient_buffer=True,
|
||||
fail=True)
|
||||
|
||||
@mock.patch('cloudbaseinit.osutils.windows.WindowsUtils.'
|
||||
'_get_forward_table')
|
||||
def test_routes(self, mock_forward_table):
|
||||
def _same(arg):
|
||||
return arg._mock_name.encode()
|
||||
|
||||
route = mock.MagicMock()
|
||||
mock_cast_result = mock.Mock()
|
||||
mock_cast_result.contents = [route]
|
||||
self._ctypes_mock.cast.return_value = mock_cast_result
|
||||
self.windows_utils.Ws2_32.inet_ntoa.side_effect = _same
|
||||
route.dwForwardIfIndex = 'dwForwardIfIndex'
|
||||
route.dwForwardProto = 'dwForwardProto'
|
||||
route.dwForwardMetric1 = 'dwForwardMetric1'
|
||||
routes = self._winutils._get_ipv4_routing_table()
|
||||
|
||||
mock_forward_table.assert_called_once_with()
|
||||
enter = mock_forward_table.return_value.__enter__
|
||||
enter.assert_called_once_with()
|
||||
exit_ = mock_forward_table.return_value.__exit__
|
||||
exit_.assert_called_once_with(None, None, None)
|
||||
self.assertEqual(1, len(routes))
|
||||
given_route = routes[0]
|
||||
self.assertEqual('dwForwardDest', given_route[0])
|
||||
self.assertEqual('dwForwardMask', given_route[1])
|
||||
self.assertEqual('dwForwardNextHop', given_route[2])
|
||||
self.assertEqual('dwForwardIfIndex', given_route[3])
|
||||
self.assertEqual('dwForwardMetric1', given_route[4])
|
||||
|
Loading…
x
Reference in New Issue
Block a user