diff --git a/requests_mock/adapter.py b/requests_mock/adapter.py index 887af31..76ca647 100644 --- a/requests_mock/adapter.py +++ b/requests_mock/adapter.py @@ -14,20 +14,26 @@ import json as jsonutils from requests.adapters import BaseAdapter, HTTPAdapter from requests.packages.urllib3.response import HTTPResponse -from six import BytesIO +import six from six.moves.urllib import parse as urlparse from requests_mock import exceptions -class Context(object): +class _Context(object): + """Stores the data being used to process a current URL match.""" def __init__(self, request, headers, status_code): self.request = request - self.headers = headers.copy() + self.headers = headers self.status_code = status_code def call(self, f, *args, **kwargs): + """Test and call a callback if one was provided. + + If the object provided is callable, then call it and return otherwise + just return the object. + """ if callable(f): status_code, headers, data = f(self.request, *args, **kwargs) if status_code: @@ -39,40 +45,59 @@ class Context(object): return f -class Matcher(object): +class _MatcherResponse(object): - _http_adapter = HTTPAdapter() + def __init__(self, **kwargs): + """ + :param int status_code: The status code to return upon a successful + match. Defaults to 200. + :param HTTPResponse raw: A HTTPResponse object to return upon a + successful match. + :param io.IOBase body: An IO object with a read() method that can + return a body on successful match. + :param bytes content: A byte string to return upon a successful match. + :param unicode text: A text string to return upon a successful match. + :param object json: A python object to be converted to a JSON string + and returned upon a successful match. + :param dict headers: A dictionary object containing headers that are + returned upon a successful match. + """ + self.status_code = kwargs.pop('status_code', 200) + self.raw = kwargs.pop('raw', None) + self.body = kwargs.pop('body', None) + self.content = kwargs.pop('content', None) + self.text = kwargs.pop('text', None) + self.json = kwargs.pop('json', None) + self.headers = kwargs.pop('headers', {}) - def __init__(self, - method, - url, - status_code=200, - raw=None, - body=None, - content=None, - text=None, - json=None, - headers=None): - self.method = method - self.url = urlparse.urlsplit(url.lower()) - self.status_code = status_code - self.raw = raw - self.body = body - self.content = content - self.text = text - self.json = json - self.headers = headers or {} + if kwargs: + raise TypeError('Too many arguments provided to _MatcherResponse') - if sum([bool(x) for x in (self.raw, - self.body, - self.content, - self.text, - self.json)]) > 1: + # mutual exclusion, only 1 body method may be provided + provided = sum([bool(x) for x in (self.raw, + self.body, + self.content, + self.text, + self.json)]) + + if provided == 0: + self.body = six.BytesIO(six.b('')) + elif provided > 1: raise RuntimeError('You may only supply one body element') - def _get_response(self, request): + # whilst in general you shouldn't do type checking in python this + # makes sure we don't end up with differences between the way types + # are handled between python 2 and 3. + if self.content and not (callable(self.content) or + isinstance(self.content, six.binary_type)): + raise TypeError('Content should be a callback or binary data') + if self.text and not (callable(self.text) or + isinstance(self.text, six.string_types)): + raise TypeError('Text should be a callback or string data') + + def get_response(self, request): encoding = None - context = Context(request, self.headers, self.status_code) + context = _Context(request, self.headers.copy(), self.status_code) content = self.content text = self.text @@ -88,7 +113,7 @@ class Matcher(object): content = data.encode(encoding) if content: data = context.call(content) - body = BytesIO(data) + body = six.BytesIO(data) if body: data = context.call(body) raw = HTTPResponse(status=context.status_code, @@ -99,8 +124,31 @@ class Matcher(object): return encoding, raw + +class _Matcher(object): + """Contains all the information about a provided URL to match.""" + + _http_adapter = HTTPAdapter() + + def __init__(self, method, url, responses, complete_qs): + """ + :param bool complete_qs: Match the entire query string. By default URLs + match if all the provided matcher query arguments are matched and + extra query arguments are ignored. Set complete_qs to true to + require that the entire query string needs to match. + """ + self.method = method + self.url = urlparse.urlsplit(url.lower()) + self.responses = responses + self.complete_qs = complete_qs + def create_response(self, request): - encoding, response = self._get_response(request) + if len(self.responses) > 1: + response_matcher = self.responses.pop(0) + else: + response_matcher = self.responses[0] + + encoding, response = response_matcher.get_response(request) req_resp = self._http_adapter.build_response(request, response) req_resp.connection = self req_resp.encoding = encoding @@ -109,36 +157,51 @@ class Matcher(object): def close(self): pass + def match(self, request): + if request.method.lower() != self.method.lower(): + return False -def match(request, matcher): - if request.method.lower() != matcher.method.lower(): - return False + url = urlparse.urlsplit(request.url.lower()) - url = urlparse.urlsplit(request.url.lower()) + if self.url.scheme and url.scheme != self.url.scheme: + return False - if matcher.url.scheme and url.scheme != matcher.url.scheme: - return False + if self.url.netloc and url.netloc != self.url.netloc: + return False - if matcher.url.netloc and url.netloc != matcher.url.netloc: - return False + if (url.path or '/') != (self.url.path or '/'): + return False - if matcher.url.path and url.path != matcher.url.path: - return False + matcher_qs = urlparse.parse_qs(self.url.query) + request_qs = urlparse.parse_qs(url.query) - return True + for k, vals in six.iteritems(matcher_qs): + for v in vals: + try: + request_qs.get(k, []).remove(v) + except ValueError: + return False + + if self.complete_qs: + for v in six.itervalues(request_qs): + if v: + return False + + return True class Adapter(BaseAdapter): + """A fake adapter than can return predefined responses. + """ def __init__(self): self._matchers = [] - self._request_history = [] + self.request_history = [] def send(self, request, **kwargs): - self._request_history.append(request) - for matcher in self._matchers: - if match(request, matcher): + if matcher.match(request): + self.request_history.append(request) return matcher.create_response(request) raise exceptions.NoMockAddress(request) @@ -146,13 +209,31 @@ class Adapter(BaseAdapter): def close(self): pass - def register_uri(self, *args, **kwargs): - self._matchers.append(Matcher(*args, **kwargs)) + def register_uri(self, method, url, request_list=None, **kwargs): + """Register a new URI match and fake response. + :param str method: The HTTP method to match. + :param str url: The URL to match. + """ + complete_qs = kwargs.pop('complete_qs', False) + + if request_list and kwargs: + raise RuntimeError('You should specify either a list of ' + 'responses OR response kwargs. Not both.') + elif not request_list: + request_list = [kwargs] + + responses = [_MatcherResponse(**k) for k in request_list] + self._matchers.append(_Matcher(method, + url, + responses, + complete_qs=complete_qs)) + + @property def last_request(self): """Retrieve the latest request sent""" try: - return self._request_history[-1] + return self.request_history[-1] except IndexError: return None diff --git a/requests_mock/tests/test_adapter.py b/requests_mock/tests/test_adapter.py index 6e9e65f..014770f 100644 --- a/requests_mock/tests/test_adapter.py +++ b/requests_mock/tests/test_adapter.py @@ -36,20 +36,19 @@ class SessionAdapterTests(base.TestCase): self.assertEqual(v, resp.headers[k]) def assertLastRequest(self, method='GET', body=None): - last = self.adapter.last_request() - self.assertEqual(self.url, last.url) - self.assertEqual(method, last.method) - self.assertEqual(body, last.body) + self.assertEqual(self.url, self.adapter.last_request.url) + self.assertEqual(method, self.adapter.last_request.method) + self.assertEqual(body, self.adapter.last_request.body) def test_content(self): - data = 'testdata' + data = six.b('testdata') self.adapter.register_uri('GET', self.url, content=data, headers=self.headers) resp = self.session.get(self.url) - self.assertEqual(six.b(data), resp.content) + self.assertEqual(data, resp.content) self.assertHeaders(resp) self.assertLastRequest() @@ -131,6 +130,12 @@ class SessionAdapterTests(base.TestCase): self.assertHeaders(resp) self.assertLastRequest() + def test_no_body(self): + self.adapter.register_uri('GET', self.url) + resp = self.session.get(self.url) + self.assertEqual(six.b(''), resp.content) + self.assertEqual(200, resp.status_code) + def test_multiple_body_elements(self): self.assertRaises(RuntimeError, self.adapter.register_uri, @@ -138,3 +143,19 @@ class SessionAdapterTests(base.TestCase): 'GET', content=six.b('b'), text=six.u('u')) + + def test_multiple_responses(self): + inp = [{'status_code': 400, 'text': 'abcd'}, + {'status_code': 300, 'text': 'defg'}, + {'status_code': 200, 'text': 'hijk'}] + + self.adapter.register_uri('GET', self.url, inp) + out = [self.session.get(self.url) for i in range(0, len(inp))] + + for i, o in zip(inp, out): + for k, v in six.iteritems(i): + self.assertEqual(v, getattr(o, k)) + + last = self.session.get(self.url) + for k, v in six.iteritems(inp[-1]): + self.assertEqual(v, getattr(last, k)) diff --git a/requests_mock/tests/test_fixture.py b/requests_mock/tests/test_fixture.py index 36fc7a4..8a9f5c8 100644 --- a/requests_mock/tests/test_fixture.py +++ b/requests_mock/tests/test_fixture.py @@ -36,4 +36,4 @@ class MockingTests(base.TestCase): resp = requests.get(test_url) self.assertEqual('response', resp.text) - self.assertEqual(test_url, self.mocker.last_request().url) + self.assertEqual(test_url, self.mocker.last_request.url) diff --git a/requests_mock/tests/test_matcher.py b/requests_mock/tests/test_matcher.py new file mode 100644 index 0000000..34b2198 --- /dev/null +++ b/requests_mock/tests/test_matcher.py @@ -0,0 +1,85 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import requests + +from requests_mock import adapter +from requests_mock.tests import base + + +class TestMatcher(base.TestCase): + + def match(self, target, url, complete_qs=False): + matcher = adapter._Matcher('GET', target, [], complete_qs) + request = requests.Request('GET', url).prepare() + return matcher.match(request) + + def assertMatch(self, target, url, **kwargs): + self.assertEqual(True, self.match(target, url, **kwargs), + 'Matcher %s failed to match %s' % (target, url)) + + def assertMatchBoth(self, target, url, **kwargs): + self.assertMatch(target, url, **kwargs) + self.assertMatch(url, target, **kwargs) + + def assertNoMatch(self, target, url, **kwargs): + self.assertEqual(False, self.match(target, url, **kwargs), + 'Matcher %s unexpectedly matched %s' % (target, url)) + + def assertNoMatchBoth(self, target, url, **kwargs): + self.assertNoMatch(target, url, **kwargs) + self.assertNoMatch(url, target, **kwargs) + + def test_url_matching(self): + self.assertMatchBoth('http://www.test.com', + 'http://www.test.com') + self.assertMatchBoth('http://www.test.com', + 'http://www.test.com/') + self.assertMatchBoth('http://www.test.com/abc', + 'http://www.test.com/abc') + self.assertMatchBoth('http://www.test.com:5000/abc', + 'http://www.test.com:5000/abc') + + self.assertNoMatchBoth('https://www.test.com', + 'http://www.test.com') + self.assertNoMatchBoth('http://www.test.com/abc', + 'http://www.test.com') + self.assertNoMatchBoth('http://test.com', + 'http://www.test.com') + self.assertNoMatchBoth('http://test.com', + 'http://www.test.com') + self.assertNoMatchBoth('http://test.com/abc', + 'http://www.test.com/abc/') + self.assertNoMatchBoth('http://test.com/abc/', + 'http://www.test.com/abc') + self.assertNoMatchBoth('http://test.com:5000/abc/', + 'http://www.test.com/abc') + self.assertNoMatchBoth('http://test.com/abc/', + 'http://www.test.com:5000/abc') + + def test_subset_match(self): + self.assertMatch('/path', 'http://www.test.com/path') + self.assertMatch('/path', 'http://www.test.com/path') + self.assertMatch('//www.test.com/path', 'http://www.test.com/path') + self.assertMatch('//www.test.com/path', 'https://www.test.com/path') + + def test_query_string(self): + self.assertMatch('/path?a=1&b=2', + 'http://www.test.com/path?a=1&b=2') + self.assertMatch('/path?a=1', + 'http://www.test.com/path?a=1&b=2', + complete_qs=False) + self.assertNoMatch('/path?a=1', + 'http://www.test.com/path?a=1&b=2', + complete_qs=True) + self.assertNoMatch('/path?a=1&b=2', + 'http://www.test.com/path?a=1')