Fixes in the course of using library in practice

- Most recent register_uri should win.
- Test for responses should check None as '' is legitimate text.
- More tests, partially coverage driven.
This commit is contained in:
Jamie Lennox 2014-06-16 10:50:42 +10:00
parent dba1cdf98c
commit 074f9d3dec
6 changed files with 252 additions and 40 deletions

View File

@ -47,6 +47,8 @@ class _Context(object):
class _MatcherResponse(object):
_BODY_ARGS = ['raw', 'body', 'content', 'text', 'json']
def __init__(self, **kwargs):
"""
:param int status_code: The status code to return upon a successful
@ -62,28 +64,27 @@ class _MatcherResponse(object):
:param dict headers: A dictionary object containing headers that are
returned upon a successful match.
"""
# mutual exclusion, only 1 body method may be provided
provided = [x for x in self._BODY_ARGS if kwargs.get(x) is not None]
self.status_code = kwargs.pop('status_code', 200)
self.raw = kwargs.pop('raw', None)
self.body = kwargs.pop('body', None)
self.content = kwargs.pop('content', None)
self.text = kwargs.pop('text', None)
self.json = kwargs.pop('json', None)
self.reason = kwargs.pop('reason', None)
self.headers = kwargs.pop('headers', {})
if kwargs:
raise TypeError('Too many arguments provided to _MatcherResponse')
raise TypeError('Too many arguments provided. Unexpected '
'arguments %s' % ', '.join(kwargs.keys()))
# mutual exclusion, only 1 body method may be provided
provided = sum([bool(x) for x in (self.raw,
self.body,
self.content,
self.text,
self.json)])
if provided == 0:
if len(provided) == 0:
self.body = six.BytesIO(six.b(''))
elif provided > 1:
raise RuntimeError('You may only supply one body element')
elif len(provided) > 1:
raise RuntimeError('You may only supply one body element. You '
'supplied %s' % ', '.join(provided))
# whilst in general you shouldn't do type checking in python this
# makes sure we don't end up with differences between the way types
@ -104,21 +105,22 @@ class _MatcherResponse(object):
body = self.body
raw = self.raw
if self.json:
if self.json is not None:
data = context.call(self.json)
text = jsonutils.dumps(data)
if text:
if text is not None:
data = context.call(text)
encoding = 'utf-8'
content = data.encode(encoding)
if content:
if content is not None:
data = context.call(content)
body = six.BytesIO(data)
if body:
if body is not None:
data = context.call(body)
raw = HTTPResponse(status=context.status_code,
body=data,
headers=context.headers,
reason=self.reason,
decode_content=False,
preload_content=False)
@ -138,7 +140,7 @@ class _Matcher(object):
require that the entire query string needs to match.
"""
self.method = method
self.url = urlparse.urlsplit(url.lower())
self.url = urlparse.urlparse(url.lower())
self.responses = responses
self.complete_qs = complete_qs
@ -161,7 +163,7 @@ class _Matcher(object):
if request.method.lower() != self.method.lower():
return False
url = urlparse.urlsplit(request.url.lower())
url = urlparse.urlparse(request.url.lower())
if self.url.scheme and url.scheme != self.url.scheme:
return False
@ -199,7 +201,7 @@ class Adapter(BaseAdapter):
self.request_history = []
def send(self, request, **kwargs):
for matcher in self._matchers:
for matcher in reversed(self._matchers):
if matcher.match(request):
self.request_history.append(request)
return matcher.create_response(request)
@ -209,7 +211,7 @@ class Adapter(BaseAdapter):
def close(self):
pass
def register_uri(self, method, url, request_list=None, **kwargs):
def register_uri(self, method, url, response_list=None, **kwargs):
"""Register a new URI match and fake response.
:param str method: The HTTP method to match.
@ -217,13 +219,13 @@ class Adapter(BaseAdapter):
"""
complete_qs = kwargs.pop('complete_qs', False)
if request_list and kwargs:
if response_list and kwargs:
raise RuntimeError('You should specify either a list of '
'responses OR response kwargs. Not both.')
elif not request_list:
request_list = [kwargs]
elif not response_list:
response_list = [kwargs]
responses = [_MatcherResponse(**k) for k in request_list]
responses = [_MatcherResponse(**k) for k in response_list]
self._matchers.append(_Matcher(method,
url,
responses,

View File

@ -20,3 +20,7 @@ class NoMockAddress(MockException):
def __init__(self, request):
self.request = request
def __str__(self):
return "No mock address: %s %s" % (self.request.method,
self.request.url)

View File

@ -19,7 +19,8 @@ from requests_mock import adapter
class Fixture(fixtures.Fixture):
PROXY_FUNCS = set(['last_request',
'register_uri'])
'register_uri',
'request_history'])
def __init__(self):
super(Fixture, self).__init__()
@ -51,4 +52,4 @@ class Fixture(fixtures.Fixture):
except AttributeError:
pass
return super(Fixture, self)._getattr__(name)
raise AttributeError(name)

View File

@ -159,3 +159,126 @@ class SessionAdapterTests(base.TestCase):
last = self.session.get(self.url)
for k, v in six.iteritems(inp[-1]):
self.assertEqual(v, getattr(last, k))
def test_callback_optional_status(self):
headers = {'a': 'b'}
def _test_cb(request):
return None, headers, ''
self.adapter.register_uri('GET',
self.url,
text=_test_cb,
status_code=300)
resp = self.session.get(self.url)
self.assertEqual(300, resp.status_code)
for k, v in six.iteritems(headers):
self.assertEqual(v, resp.headers[k])
def test_callback_optional_headers(self):
headers = {'a': 'b'}
def _test_cb(request):
return 300, None, ''
self.adapter.register_uri('GET',
self.url,
text=_test_cb,
headers=headers)
resp = self.session.get(self.url)
self.assertEqual(300, resp.status_code)
for k, v in six.iteritems(headers):
self.assertEqual(v, resp.headers[k])
def test_callback_adds_headers(self):
headers_a = {'a': 'b'}
headers_b = {'c': 'd'}
def _test_cb(request):
return 200, headers_b, ''
self.adapter.register_uri('GET',
self.url,
text=_test_cb,
headers=headers_a)
resp = self.session.get(self.url)
self.assertEqual(200, resp.status_code)
for headers in (headers_a, headers_b):
for k, v in six.iteritems(headers):
self.assertEqual(v, resp.headers[k])
def test_latest_register_overrides(self):
self.adapter.register_uri('GET', self.url, text='abc')
self.adapter.register_uri('GET', self.url, text='def')
resp = self.session.get(self.url)
self.assertEqual('def', resp.text)
def test_no_last_request(self):
self.assertIsNone(self.adapter.last_request)
self.assertEqual(0, len(self.adapter.request_history))
def test_dont_pass_list_and_kwargs(self):
self.assertRaises(RuntimeError,
self.adapter.register_uri,
'GET',
self.url,
[{'text': 'a'}],
headers={'a': 'b'})
def test_empty_string_return(self):
# '' evaluates as False, so make sure an empty string is not ignored.
self.adapter.register_uri('GET', self.url, text='')
resp = self.session.get(self.url)
self.assertEqual('', resp.text)
def test_dont_pass_multiple_bodies(self):
self.assertRaises(RuntimeError,
self.adapter.register_uri,
'GET',
self.url,
json={'abc': 'def'},
text='ghi')
def test_dont_pass_unexpected_kwargs(self):
self.assertRaises(TypeError,
self.adapter.register_uri,
'GET',
self.url,
unknown='argument')
def test_dont_pass_unicode_as_content(self):
self.assertRaises(TypeError,
self.adapter.register_uri,
'GET',
self.url,
content=six.u('unicode'))
def test_dont_pass_bytes_as_text(self):
if six.PY2:
self.skipTest('Cannot enforce byte behaviour in PY2')
self.assertRaises(TypeError,
self.adapter.register_uri,
'GET',
self.url,
text=six.b('bytes'))
def test_dont_pass_non_str_as_content(self):
self.assertRaises(TypeError,
self.adapter.register_uri,
'GET',
self.url,
content=5)
def test_dont_pass_non_str_as_text(self):
self.assertRaises(TypeError,
self.adapter.register_uri,
'GET',
self.url,
text=5)

View File

@ -37,3 +37,6 @@ class MockingTests(base.TestCase):
resp = requests.get(test_url)
self.assertEqual('response', resp.text)
self.assertEqual(test_url, self.mocker.last_request.url)
def test_fixture_has_normal_attr_error(self):
self.assertRaises(AttributeError, lambda: self.mocker.unknown)

View File

@ -18,26 +18,100 @@ from requests_mock.tests import base
class TestMatcher(base.TestCase):
def match(self, target, url, complete_qs=False):
matcher = adapter._Matcher('GET', target, [], complete_qs)
request = requests.Request('GET', url).prepare()
def match(self,
target,
url,
matcher_method='GET',
request_method='GET',
complete_qs=False):
matcher = adapter._Matcher(matcher_method, target, [], complete_qs)
request = requests.Request(request_method, url).prepare()
return matcher.match(request)
def assertMatch(self, target, url, **kwargs):
self.assertEqual(True, self.match(target, url, **kwargs),
'Matcher %s failed to match %s' % (target, url))
def assertMatch(self,
target,
url,
matcher_method='GET',
request_method='GET',
**kwargs):
self.assertEqual(True,
self.match(target,
url,
matcher_method=matcher_method,
request_method=request_method,
**kwargs),
'Matcher %s %s failed to match %s %s' %
(matcher_method, target, request_method, url))
def assertMatchBoth(self, target, url, **kwargs):
self.assertMatch(target, url, **kwargs)
self.assertMatch(url, target, **kwargs)
def assertMatchBoth(self,
target,
url,
matcher_method='GET',
request_method='GET',
**kwargs):
self.assertMatch(target,
url,
matcher_method=matcher_method,
request_method=request_method,
**kwargs)
self.assertMatch(url,
target,
matcher_method=request_method,
request_method=matcher_method,
**kwargs)
def assertNoMatch(self, target, url, **kwargs):
self.assertEqual(False, self.match(target, url, **kwargs),
'Matcher %s unexpectedly matched %s' % (target, url))
def assertNoMatch(self,
target,
url,
matcher_method='GET',
request_method='GET',
**kwargs):
self.assertEqual(False,
self.match(target,
url,
matcher_method=matcher_method,
request_method=request_method,
**kwargs),
'Matcher %s %s unexpectedly matched %s %s' %
(matcher_method, target, request_method, url))
def assertNoMatchBoth(self, target, url, **kwargs):
self.assertNoMatch(target, url, **kwargs)
self.assertNoMatch(url, target, **kwargs)
def assertNoMatchBoth(self,
target,
url,
matcher_method='GET',
request_method='GET',
**kwargs):
self.assertNoMatch(target,
url,
matcher_method=matcher_method,
request_method=request_method,
**kwargs)
self.assertNoMatch(url,
target,
matcher_method=request_method,
request_method=matcher_method,
**kwargs)
def assertMatchMethodBoth(self, matcher_method, request_method, **kwargs):
url = 'http://www.test.com'
self.assertMatchBoth(url,
url,
request_method=request_method,
matcher_method=matcher_method,
**kwargs)
def assertNoMatchMethodBoth(self,
matcher_method,
request_method,
**kwargs):
url = 'http://www.test.com'
self.assertNoMatchBoth(url,
url,
request_method=request_method,
matcher_method=matcher_method,
**kwargs)
def test_url_matching(self):
self.assertMatchBoth('http://www.test.com',
@ -83,3 +157,8 @@ class TestMatcher(base.TestCase):
complete_qs=True)
self.assertNoMatch('/path?a=1&b=2',
'http://www.test.com/path?a=1')
def test_method_match(self):
self.assertNoMatchMethodBoth('GET', 'POST')
self.assertMatchMethodBoth('GET', 'get')
self.assertMatchMethodBoth('GeT', 'geT')