"""Unit tests for the project-local ``serial`` module.

Covers ``connect``, ``disconnect``, ``get_output``, ``expect_bytes`` and
``send_bytes``. Sockets, streams, logging, ``time`` and ``stdout`` are all
replaced with mocks, so no real connection is opened.
"""
import socket
import unittest
from unittest.mock import ANY, MagicMock, patch

import serial


class ConnectTestCase(unittest.TestCase):
    """ Class to test connect method """

    @patch("serial.LOG.info")
    @patch("socket.socket")
    def test_connect_windows(self, mock_socket, mock_log_info):
        """Test that connect builds a TCP socket, connects it and returns it."""
        # Setup: keep the socket *instance* distinct from the patched
        # ``socket.socket`` constructor so the assertions below really check
        # calls made on the object returned by the constructor.
        sock = mock_socket.return_value
        hostname = 'hostname'
        port = 10000

        # Run
        result = serial.connect(hostname, port)

        # Assert
        mock_socket.assert_called_once_with(
            socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP)
        # NOTE(review): the expected target is 'localhost' although
        # 'hostname' is passed in. Confirm serial.connect is meant to ignore
        # its hostname argument rather than this test hiding a bug.
        sock.connect.assert_called_once_with(('localhost', port))
        self.assertIs(result, sock)

    @patch("serial.LOG.info")
    @patch("socket.socket")
    def test_connect_fail(self, mock_socket, mock_log_info):
        """Test that connect logs the failure and returns None if connecting raises."""
        # Setup: the socket instance's connect() raises.
        sock = mock_socket.return_value
        sock.connect.side_effect = Exception
        hostname = 'hostname'
        port = 10000

        # Run
        result = serial.connect(hostname, port)

        # Assert
        mock_socket.assert_called_once_with(
            socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP)
        sock.connect.assert_called_once_with(('localhost', port))
        mock_log_info.assert_called_with("Failed sock connection")
        self.assertIsNone(result)


class DisconnectTestCase(unittest.TestCase):
    """ Class to test disconnect method """

    @patch("serial.LOG.info")
    def test_disconnect(self, mock_log_info):
        """Test that disconnect shuts the socket down both ways, then closes it."""
        # Setup
        sock = MagicMock()

        # Run
        serial.disconnect(sock)

        # Assert
        sock.shutdown.assert_called_once_with(socket.SHUT_RDWR)
        sock.close.assert_called_once()
        # Some single-argument info message must have been logged.
        mock_log_info.assert_any_call(ANY)


# TODO: This test is just for coverage purposes; get_output needs a heavy
# refactoring.
class GetOutputTestCase(unittest.TestCase):
    """ Class to test get_output method """

    @patch("serial.LOG.info")
    @patch("serial.time")
    def test_get_output(self, mock_time, mock_log_info):
        """Test get_output: sends the command, logs the lines it reads, then raises.

        The mocked clock returns 0, 1, 2, 3 against a timeout of 2, and the
        test only checks that *some* exception escapes get_output.
        """
        # Setup
        stream = MagicMock()
        stream.poll.return_value = None
        stream.gettimeout.return_value = 1
        stream.recv.side_effect = ['cmd\n', 'test\n', ':~$ ']
        mock_time.time.side_effect = [0, 1, 2, 3]
        cmd = "cmd"
        prompts = [':~$ ', ':~# ', ':/home/wrsroot# ',
                   '(keystone_.*)]$ ', '(keystone_.*)]# ']
        timeout = 2
        log = True
        as_lines = True
        flush = True

        # Run
        # NOTE(review): assertRaises(Exception) is very broad. It would also
        # pass on e.g. a StopIteration from an exhausted mock side_effect.
        # Narrow it to the exception get_output actually raises.
        with self.assertRaises(Exception):
            serial.get_output(stream, cmd, prompts, timeout, log, as_lines, flush)

        # Assert
        stream.sendall.assert_called_once_with(f"{cmd}\n".encode('utf-8'))
        mock_log_info.assert_any_call('cmd')
        mock_log_info.assert_any_call('test')


class ExpectBytesTestCase(unittest.TestCase):
    """ Class to test expect_bytes method """

    @patch("serial.LOG.debug")
    @patch("serial.LOG.info")
    @patch("serial.stdout.write")
    def test_expect_bytes(self, mock_stdout_write, mock_log_info, mock_log_debug):
        """Test that expect_bytes waits for the encoded text and returns 0 when found."""
        # Setup
        stream = MagicMock()
        stream.expect_bytes.return_value = None
        stream.poll.return_value = None
        text = "Hello, world!"
        timeout = 180
        fail_ok = False
        flush = True

        # Run
        result = serial.expect_bytes(stream, text, timeout, fail_ok, flush)

        # Assert
        self.assertEqual(result, 0)
        stream.expect_bytes.assert_called_once_with(
            text.encode('utf-8'), timeout=timeout)
        mock_stdout_write.assert_any_call('\n')
        mock_log_info.assert_any_call(
            "Expecting text within %s minutes: %s\n", timeout / 60, text)
        mock_log_debug.assert_any_call("Found expected text: %s", text)


class SendBytesTestCase(unittest.TestCase):
    """ Class to test send_bytes method """

    @patch("serial.LOG.info")
    @patch("serial.expect_bytes")
    def test_send_bytes(self, mock_expect_bytes, mock_log_info):
        """Test that send_bytes sends text plus a newline, then waits for a prompt."""
        # Setup
        stream = MagicMock()
        stream.poll.return_value = None
        text = "Hello, world!"
        fail_ok = False
        expect_prompt = True
        prompt = None
        timeout = 180
        send = True
        flush = True
        mock_expect_bytes.return_value = 0

        # Run
        result = serial.send_bytes(
            stream, text, fail_ok, expect_prompt, prompt, timeout, send, flush)

        # Assert
        self.assertEqual(result, 0)
        mock_expect_bytes.assert_called()
        stream.sendall.assert_called_once_with(f"{text}\n".encode('utf-8'))


if __name__ == '__main__':
    unittest.main()