diff --git a/requests_mock/mocker.py b/requests_mock/mocker.py index bb9e1e5..80a367f 100644 --- a/requests_mock/mocker.py +++ b/requests_mock/mocker.py @@ -25,6 +25,8 @@ PATCH = 'PATCH' POST = 'POST' PUT = 'PUT' +_original_send = requests.Session.send + class MockerCore(object): """A wrapper around common mocking functions. @@ -63,7 +65,7 @@ class MockerCore(object): self._adapter = adapter.Adapter(case_sensitive=case_sensitive) self._real_http = kwargs.pop('real_http', False) - self._real_send = None + self._last_send = None if kwargs: raise TypeError('Unexpected Arguments: %s' % ', '.join(kwargs)) @@ -73,10 +75,10 @@ class MockerCore(object): Install the adapter and the wrappers required to intercept requests. """ - if self._real_send: + if self._last_send: raise RuntimeError('Mocker has already been started') - self._real_send = requests.Session.send + self._last_send = requests.Session.send def _fake_get_adapter(session, url): return self._adapter @@ -85,8 +87,18 @@ class MockerCore(object): real_get_adapter = requests.Session.get_adapter requests.Session.get_adapter = _fake_get_adapter + # NOTE(jamielennox): self._last_send vs _original_send. Whilst it + # seems like here we would use _last_send there is the possibility + # that the user has messed up and is somehow nesting their mockers. + # If we call last_send at this point then we end up calling this + # function again and the outer level adapter ends up winning. + # All we really care about here is that our adapter is in place + # before calling send so we always jump directly to the real + # function so that our most recently patched send call ends up + # putting in the most recent adapter. It feels funny, but it works. + try: - return self._real_send(session, request, **kwargs) + return _original_send(session, request, **kwargs) except exceptions.NoMockAddress: if not self._real_http: raise @@ -97,7 +109,7 @@ class MockerCore(object): finally: requests.Session.get_adapter = real_get_adapter - return self._real_send(session, request, **kwargs) + return _original_send(session, request, **kwargs) requests.Session.send = _fake_send @@ -106,9 +118,9 @@ class MockerCore(object): This should have no impact if mocking has not been started. """ - if self._real_send: - requests.Session.send = self._real_send - self._real_send = None + if self._last_send: + requests.Session.send = self._last_send + self._last_send = None def __getattr__(self, name): if name in self._PROXY_FUNCS: diff --git a/requests_mock/tests/test_mocker.py b/requests_mock/tests/test_mocker.py index c119b10..05fa5bc 100644 --- a/requests_mock/tests/test_mocker.py +++ b/requests_mock/tests/test_mocker.py @@ -339,3 +339,73 @@ class MockerHttpMethodsTests(base.TestCase): for k, v in query.items(): self.assertEqual([v], m.last_request.qs[k]) + + def test_nested_mocking(self): + url1 = 'http://url1.com/path1' + url2 = 'http://url2.com/path2' + url3 = 'http://url3.com/path3' + + data1 = 'data1' + data2 = 'data2' + data3 = 'data3' + + with requests_mock.mock() as m1: + + r1 = m1.get(url1, text=data1) + + resp1a = requests.get(url1) + self.assertRaises(exceptions.NoMockAddress, requests.get, url2) + self.assertRaises(exceptions.NoMockAddress, requests.get, url3) + + self.assertEqual(data1, resp1a.text) + + # call count = 3 because there are 3 calls above, url 1-3 + self.assertEqual(3, m1.call_count) + self.assertEqual(1, r1.call_count) + + with requests_mock.mock() as m2: + + r2 = m2.get(url2, text=data2) + + self.assertRaises(exceptions.NoMockAddress, requests.get, url1) + resp2a = requests.get(url2) + self.assertRaises(exceptions.NoMockAddress, requests.get, url3) + + self.assertEqual(data2, resp2a.text) + + with requests_mock.mock() as m3: + + r3 = m3.get(url3, text=data3) + + self.assertRaises(exceptions.NoMockAddress, + requests.get, + url1) + self.assertRaises(exceptions.NoMockAddress, + requests.get, + url2) + resp3 = requests.get(url3) + + self.assertEqual(data3, resp3.text) + + self.assertEqual(3, m3.call_count) + self.assertEqual(1, r3.call_count) + + resp2b = requests.get(url2) + self.assertRaises(exceptions.NoMockAddress, requests.get, url1) + self.assertEqual(data2, resp2b.text) + self.assertRaises(exceptions.NoMockAddress, requests.get, url3) + + self.assertEqual(3, m1.call_count) + self.assertEqual(1, r1.call_count) + self.assertEqual(6, m2.call_count) + self.assertEqual(2, r2.call_count) + self.assertEqual(3, m3.call_count) + self.assertEqual(1, r3.call_count) + + resp1b = requests.get(url1) + self.assertEqual(data1, resp1b.text) + self.assertRaises(exceptions.NoMockAddress, requests.get, url2) + self.assertRaises(exceptions.NoMockAddress, requests.get, url3) + + self.assertEqual(6, m1.call_count) + self.assertEqual(2, r1.call_count)