Add request_headers matching

If you provide request_headers then the headers must also match the
requeest to match.
This commit is contained in:
Jamie Lennox 2014-06-23 20:28:50 +10:00
parent 833ee180c3
commit 595deb3dfd
3 changed files with 66 additions and 6 deletions

View File

@ -93,3 +93,20 @@ The URL is then matched using :py:meth:`re.regex.search` which means that it wil
'resp'
If you use regular expression matching then *requests-mock* can't do it's normal query string or path only matching, that will need to be part of the expression.
Request Headers
===============
A dictionary of headers can be supplied such that the request will only match if the available headers also match.
Only the headers that are provided need match, any additional headers will be ignored.
.. code:: python
>>> adapter.register_uri('POST', 'mock://test.com', request_headers={'key', 'val'}, text='resp')
>>> session.post('mock://test.com', headers={'key': 'val', 'another': 'header'}).text
'resp'
>>> session.post('mock://test.com')
Traceback (most recent call last):
...
requests_mock.exceptions.NoMockAddress: No mock address: POST mock://test.com/

View File

@ -124,7 +124,7 @@ class _Matcher(object):
_http_adapter = HTTPAdapter()
def __init__(self, method, url, responses, complete_qs):
def __init__(self, method, url, responses, complete_qs, request_headers):
"""
:param bool complete_qs: Match the entire query string. By default URLs
match if all the provided matcher query arguments are matched and
@ -139,6 +139,7 @@ class _Matcher(object):
self.url_parts = None
self.responses = responses
self.complete_qs = complete_qs
self.request_headers = request_headers
def create_response(self, request):
if len(self.responses) > 1:
@ -200,8 +201,22 @@ class _Matcher(object):
return True
def _match_headers(self, request):
for k, vals in six.iteritems(self.request_headers):
try:
header = request.headers[k]
except KeyError:
return False
else:
if header != vals:
return False
return True
def match(self, request):
return self._match_method(request) and self._match_url(request)
return (self._match_method(request) and
self._match_url(request) and
self._match_headers(request))
class Adapter(BaseAdapter):
@ -231,6 +246,7 @@ class Adapter(BaseAdapter):
:param str url: The URL to match.
"""
complete_qs = kwargs.pop('complete_qs', False)
request_headers = kwargs.pop('request_headers', {})
if response_list and kwargs:
raise RuntimeError('You should specify either a list of '
@ -242,7 +258,8 @@ class Adapter(BaseAdapter):
self._matchers.append(_Matcher(method,
url,
responses,
complete_qs=complete_qs))
complete_qs=complete_qs,
request_headers=request_headers))
@property
def last_request(self):

View File

@ -27,9 +27,15 @@ class TestMatcher(base.TestCase):
url,
matcher_method='GET',
request_method='GET',
complete_qs=False):
matcher = adapter._Matcher(matcher_method, target, [], complete_qs)
request = requests.Request(request_method, url).prepare()
complete_qs=False,
headers=None,
request_headers={}):
matcher = adapter._Matcher(matcher_method,
target,
[],
complete_qs,
request_headers)
request = requests.Request(request_method, url, headers).prepare()
return matcher.match(request)
def assertMatch(self,
@ -194,3 +200,23 @@ class TestMatcher(base.TestCase):
self.assertMatch(r2, 'http://anything/a/b/c/d')
self.assertMatch(r2, 'mock://anything/a/b/c/d')
def test_match_with_headers(self):
self.assertMatch('/path',
'http://www.test.com/path',
headers={'A': 'abc', 'b': 'def'},
request_headers={'a': 'abc'})
self.assertMatch('/path',
'http://www.test.com/path',
headers={'A': 'abc', 'b': 'def'})
self.assertNoMatch('/path',
'http://www.test.com/path',
headers={'A': 'abc', 'b': 'def'},
request_headers={'b': 'abc'})
self.assertNoMatch('/path',
'http://www.test.com/path',
headers={'A': 'abc', 'b': 'def'},
request_headers={'c': 'ghi'})