More reorganizing

Added response lists, new tests, renamed and changed stuff.
This commit is contained in:
Jamie Lennox 2014-06-15 00:15:32 +10:00
parent e01f313529
commit 81439f4fa2
4 changed files with 245 additions and 58 deletions

View File

@ -14,20 +14,26 @@ import json as jsonutils
from requests.adapters import BaseAdapter, HTTPAdapter
from requests.packages.urllib3.response import HTTPResponse
from six import BytesIO
import six
from six.moves.urllib import parse as urlparse
from requests_mock import exceptions
class Context(object):
class _Context(object):
"""Stores the data being used to process a current URL match."""
def __init__(self, request, headers, status_code):
self.request = request
self.headers = headers.copy()
self.headers = headers
self.status_code = status_code
def call(self, f, *args, **kwargs):
"""Test and call a callback if one was provided.
If the object provided is callable, then call it and return otherwise
just return the object.
"""
if callable(f):
status_code, headers, data = f(self.request, *args, **kwargs)
if status_code:
@ -39,40 +45,59 @@ class Context(object):
return f
class Matcher(object):
class _MatcherResponse(object):
_http_adapter = HTTPAdapter()
def __init__(self, **kwargs):
"""
:param int status_code: The status code to return upon a successful
match. Defaults to 200.
:param HTTPResponse raw: A HTTPResponse object to return upon a
successful match.
:param io.IOBase body: An IO object with a read() method that can
return a body on successful match.
:param bytes content: A byte string to return upon a successful match.
:param unicode text: A text string to return upon a successful match.
:param object json: A python object to be converted to a JSON string
and returned upon a successful match.
:param dict headers: A dictionary object containing headers that are
returned upon a successful match.
"""
self.status_code = kwargs.pop('status_code', 200)
self.raw = kwargs.pop('raw', None)
self.body = kwargs.pop('body', None)
self.content = kwargs.pop('content', None)
self.text = kwargs.pop('text', None)
self.json = kwargs.pop('json', None)
self.headers = kwargs.pop('headers', {})
def __init__(self,
method,
url,
status_code=200,
raw=None,
body=None,
content=None,
text=None,
json=None,
headers=None):
self.method = method
self.url = urlparse.urlsplit(url.lower())
self.status_code = status_code
self.raw = raw
self.body = body
self.content = content
self.text = text
self.json = json
self.headers = headers or {}
if kwargs:
raise TypeError('Too many arguments provided to _MatcherResponse')
if sum([bool(x) for x in (self.raw,
self.body,
self.content,
self.text,
self.json)]) > 1:
# mutual exclusion, only 1 body method may be provided
provided = sum([bool(x) for x in (self.raw,
self.body,
self.content,
self.text,
self.json)])
if provided == 0:
self.body = six.BytesIO(six.b(''))
elif provided > 1:
raise RuntimeError('You may only supply one body element')
def _get_response(self, request):
# whilst in general you shouldn't do type checking in python this
# makes sure we don't end up with differences between the way types
# are handled between python 2 and 3.
if self.content and not (callable(self.content) or
isinstance(self.content, six.binary_type)):
raise TypeError('Content should be a callback or binary data')
if self.text and not (callable(self.text) or
isinstance(self.text, six.string_types)):
raise TypeError('Text should be a callback or string data')
def get_response(self, request):
encoding = None
context = Context(request, self.headers, self.status_code)
context = _Context(request, self.headers.copy(), self.status_code)
content = self.content
text = self.text
@ -88,7 +113,7 @@ class Matcher(object):
content = data.encode(encoding)
if content:
data = context.call(content)
body = BytesIO(data)
body = six.BytesIO(data)
if body:
data = context.call(body)
raw = HTTPResponse(status=context.status_code,
@ -99,8 +124,31 @@ class Matcher(object):
return encoding, raw
class _Matcher(object):
"""Contains all the information about a provided URL to match."""
_http_adapter = HTTPAdapter()
def __init__(self, method, url, responses, complete_qs):
"""
:param bool complete_qs: Match the entire query string. By default URLs
match if all the provided matcher query arguments are matched and
extra query arguments are ignored. Set complete_qs to true to
require that the entire query string needs to match.
"""
self.method = method
self.url = urlparse.urlsplit(url.lower())
self.responses = responses
self.complete_qs = complete_qs
def create_response(self, request):
encoding, response = self._get_response(request)
if len(self.responses) > 1:
response_matcher = self.responses.pop(0)
else:
response_matcher = self.responses[0]
encoding, response = response_matcher.get_response(request)
req_resp = self._http_adapter.build_response(request, response)
req_resp.connection = self
req_resp.encoding = encoding
@ -109,36 +157,51 @@ class Matcher(object):
def close(self):
pass
def match(self, request):
if request.method.lower() != self.method.lower():
return False
def match(request, matcher):
if request.method.lower() != matcher.method.lower():
return False
url = urlparse.urlsplit(request.url.lower())
url = urlparse.urlsplit(request.url.lower())
if self.url.scheme and url.scheme != self.url.scheme:
return False
if matcher.url.scheme and url.scheme != matcher.url.scheme:
return False
if self.url.netloc and url.netloc != self.url.netloc:
return False
if matcher.url.netloc and url.netloc != matcher.url.netloc:
return False
if (url.path or '/') != (self.url.path or '/'):
return False
if matcher.url.path and url.path != matcher.url.path:
return False
matcher_qs = urlparse.parse_qs(self.url.query)
request_qs = urlparse.parse_qs(url.query)
return True
for k, vals in six.iteritems(matcher_qs):
for v in vals:
try:
request_qs.get(k, []).remove(v)
except ValueError:
return False
if self.complete_qs:
for v in six.itervalues(request_qs):
if v:
return False
return True
class Adapter(BaseAdapter):
"""A fake adapter than can return predefined responses.
"""
def __init__(self):
self._matchers = []
self._request_history = []
self.request_history = []
def send(self, request, **kwargs):
self._request_history.append(request)
for matcher in self._matchers:
if match(request, matcher):
if matcher.match(request):
self.request_history.append(request)
return matcher.create_response(request)
raise exceptions.NoMockAddress(request)
@ -146,13 +209,31 @@ class Adapter(BaseAdapter):
def close(self):
pass
def register_uri(self, *args, **kwargs):
self._matchers.append(Matcher(*args, **kwargs))
def register_uri(self, method, url, request_list=None, **kwargs):
"""Register a new URI match and fake response.
:param str method: The HTTP method to match.
:param str url: The URL to match.
"""
complete_qs = kwargs.pop('complete_qs', False)
if request_list and kwargs:
raise RuntimeError('You should specify either a list of '
'responses OR response kwargs. Not both.')
elif not request_list:
request_list = [kwargs]
responses = [_MatcherResponse(**k) for k in request_list]
self._matchers.append(_Matcher(method,
url,
responses,
complete_qs=complete_qs))
@property
def last_request(self):
"""Retrieve the latest request sent"""
try:
return self._request_history[-1]
return self.request_history[-1]
except IndexError:
return None

View File

@ -36,20 +36,19 @@ class SessionAdapterTests(base.TestCase):
self.assertEqual(v, resp.headers[k])
def assertLastRequest(self, method='GET', body=None):
last = self.adapter.last_request()
self.assertEqual(self.url, last.url)
self.assertEqual(method, last.method)
self.assertEqual(body, last.body)
self.assertEqual(self.url, self.adapter.last_request.url)
self.assertEqual(method, self.adapter.last_request.method)
self.assertEqual(body, self.adapter.last_request.body)
def test_content(self):
data = 'testdata'
data = six.b('testdata')
self.adapter.register_uri('GET',
self.url,
content=data,
headers=self.headers)
resp = self.session.get(self.url)
self.assertEqual(six.b(data), resp.content)
self.assertEqual(data, resp.content)
self.assertHeaders(resp)
self.assertLastRequest()
@ -131,6 +130,12 @@ class SessionAdapterTests(base.TestCase):
self.assertHeaders(resp)
self.assertLastRequest()
def test_no_body(self):
self.adapter.register_uri('GET', self.url)
resp = self.session.get(self.url)
self.assertEqual(six.b(''), resp.content)
self.assertEqual(200, resp.status_code)
def test_multiple_body_elements(self):
self.assertRaises(RuntimeError,
self.adapter.register_uri,
@ -138,3 +143,19 @@ class SessionAdapterTests(base.TestCase):
'GET',
content=six.b('b'),
text=six.u('u'))
def test_multiple_responses(self):
inp = [{'status_code': 400, 'text': 'abcd'},
{'status_code': 300, 'text': 'defg'},
{'status_code': 200, 'text': 'hijk'}]
self.adapter.register_uri('GET', self.url, inp)
out = [self.session.get(self.url) for i in range(0, len(inp))]
for i, o in zip(inp, out):
for k, v in six.iteritems(i):
self.assertEqual(v, getattr(o, k))
last = self.session.get(self.url)
for k, v in six.iteritems(inp[-1]):
self.assertEqual(v, getattr(last, k))

View File

@ -36,4 +36,4 @@ class MockingTests(base.TestCase):
resp = requests.get(test_url)
self.assertEqual('response', resp.text)
self.assertEqual(test_url, self.mocker.last_request().url)
self.assertEqual(test_url, self.mocker.last_request.url)

View File

@ -0,0 +1,85 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import requests
from requests_mock import adapter
from requests_mock.tests import base
class TestMatcher(base.TestCase):
def match(self, target, url, complete_qs=False):
matcher = adapter._Matcher('GET', target, [], complete_qs)
request = requests.Request('GET', url).prepare()
return matcher.match(request)
def assertMatch(self, target, url, **kwargs):
self.assertEqual(True, self.match(target, url, **kwargs),
'Matcher %s failed to match %s' % (target, url))
def assertMatchBoth(self, target, url, **kwargs):
self.assertMatch(target, url, **kwargs)
self.assertMatch(url, target, **kwargs)
def assertNoMatch(self, target, url, **kwargs):
self.assertEqual(False, self.match(target, url, **kwargs),
'Matcher %s unexpectedly matched %s' % (target, url))
def assertNoMatchBoth(self, target, url, **kwargs):
self.assertNoMatch(target, url, **kwargs)
self.assertNoMatch(url, target, **kwargs)
def test_url_matching(self):
self.assertMatchBoth('http://www.test.com',
'http://www.test.com')
self.assertMatchBoth('http://www.test.com',
'http://www.test.com/')
self.assertMatchBoth('http://www.test.com/abc',
'http://www.test.com/abc')
self.assertMatchBoth('http://www.test.com:5000/abc',
'http://www.test.com:5000/abc')
self.assertNoMatchBoth('https://www.test.com',
'http://www.test.com')
self.assertNoMatchBoth('http://www.test.com/abc',
'http://www.test.com')
self.assertNoMatchBoth('http://test.com',
'http://www.test.com')
self.assertNoMatchBoth('http://test.com',
'http://www.test.com')
self.assertNoMatchBoth('http://test.com/abc',
'http://www.test.com/abc/')
self.assertNoMatchBoth('http://test.com/abc/',
'http://www.test.com/abc')
self.assertNoMatchBoth('http://test.com:5000/abc/',
'http://www.test.com/abc')
self.assertNoMatchBoth('http://test.com/abc/',
'http://www.test.com:5000/abc')
def test_subset_match(self):
self.assertMatch('/path', 'http://www.test.com/path')
self.assertMatch('/path', 'http://www.test.com/path')
self.assertMatch('//www.test.com/path', 'http://www.test.com/path')
self.assertMatch('//www.test.com/path', 'https://www.test.com/path')
def test_query_string(self):
self.assertMatch('/path?a=1&b=2',
'http://www.test.com/path?a=1&b=2')
self.assertMatch('/path?a=1',
'http://www.test.com/path?a=1&b=2',
complete_qs=False)
self.assertNoMatch('/path?a=1',
'http://www.test.com/path?a=1&b=2',
complete_qs=True)
self.assertNoMatch('/path?a=1&b=2',
'http://www.test.com/path?a=1')