More reorganizing
Added response lists, new tests, renamed and changed stuff.
This commit is contained in:
parent
e01f313529
commit
81439f4fa2
@ -14,20 +14,26 @@ import json as jsonutils
|
||||
|
||||
from requests.adapters import BaseAdapter, HTTPAdapter
|
||||
from requests.packages.urllib3.response import HTTPResponse
|
||||
from six import BytesIO
|
||||
import six
|
||||
from six.moves.urllib import parse as urlparse
|
||||
|
||||
from requests_mock import exceptions
|
||||
|
||||
|
||||
class Context(object):
|
||||
class _Context(object):
|
||||
"""Stores the data being used to process a current URL match."""
|
||||
|
||||
def __init__(self, request, headers, status_code):
|
||||
self.request = request
|
||||
self.headers = headers.copy()
|
||||
self.headers = headers
|
||||
self.status_code = status_code
|
||||
|
||||
def call(self, f, *args, **kwargs):
|
||||
"""Test and call a callback if one was provided.
|
||||
|
||||
If the object provided is callable, then call it and return otherwise
|
||||
just return the object.
|
||||
"""
|
||||
if callable(f):
|
||||
status_code, headers, data = f(self.request, *args, **kwargs)
|
||||
if status_code:
|
||||
@ -39,40 +45,59 @@ class Context(object):
|
||||
return f
|
||||
|
||||
|
||||
class Matcher(object):
|
||||
class _MatcherResponse(object):
|
||||
|
||||
_http_adapter = HTTPAdapter()
|
||||
def __init__(self, **kwargs):
|
||||
"""
|
||||
:param int status_code: The status code to return upon a successful
|
||||
match. Defaults to 200.
|
||||
:param HTTPResponse raw: A HTTPResponse object to return upon a
|
||||
successful match.
|
||||
:param io.IOBase body: An IO object with a read() method that can
|
||||
return a body on successful match.
|
||||
:param bytes content: A byte string to return upon a successful match.
|
||||
:param unicode text: A text string to return upon a successful match.
|
||||
:param object json: A python object to be converted to a JSON string
|
||||
and returned upon a successful match.
|
||||
:param dict headers: A dictionary object containing headers that are
|
||||
returned upon a successful match.
|
||||
"""
|
||||
self.status_code = kwargs.pop('status_code', 200)
|
||||
self.raw = kwargs.pop('raw', None)
|
||||
self.body = kwargs.pop('body', None)
|
||||
self.content = kwargs.pop('content', None)
|
||||
self.text = kwargs.pop('text', None)
|
||||
self.json = kwargs.pop('json', None)
|
||||
self.headers = kwargs.pop('headers', {})
|
||||
|
||||
def __init__(self,
|
||||
method,
|
||||
url,
|
||||
status_code=200,
|
||||
raw=None,
|
||||
body=None,
|
||||
content=None,
|
||||
text=None,
|
||||
json=None,
|
||||
headers=None):
|
||||
self.method = method
|
||||
self.url = urlparse.urlsplit(url.lower())
|
||||
self.status_code = status_code
|
||||
self.raw = raw
|
||||
self.body = body
|
||||
self.content = content
|
||||
self.text = text
|
||||
self.json = json
|
||||
self.headers = headers or {}
|
||||
if kwargs:
|
||||
raise TypeError('Too many arguments provided to _MatcherResponse')
|
||||
|
||||
if sum([bool(x) for x in (self.raw,
|
||||
self.body,
|
||||
self.content,
|
||||
self.text,
|
||||
self.json)]) > 1:
|
||||
# mutual exclusion, only 1 body method may be provided
|
||||
provided = sum([bool(x) for x in (self.raw,
|
||||
self.body,
|
||||
self.content,
|
||||
self.text,
|
||||
self.json)])
|
||||
|
||||
if provided == 0:
|
||||
self.body = six.BytesIO(six.b(''))
|
||||
elif provided > 1:
|
||||
raise RuntimeError('You may only supply one body element')
|
||||
|
||||
def _get_response(self, request):
|
||||
# whilst in general you shouldn't do type checking in python this
|
||||
# makes sure we don't end up with differences between the way types
|
||||
# are handled between python 2 and 3.
|
||||
if self.content and not (callable(self.content) or
|
||||
isinstance(self.content, six.binary_type)):
|
||||
raise TypeError('Content should be a callback or binary data')
|
||||
if self.text and not (callable(self.text) or
|
||||
isinstance(self.text, six.string_types)):
|
||||
raise TypeError('Text should be a callback or string data')
|
||||
|
||||
def get_response(self, request):
|
||||
encoding = None
|
||||
context = Context(request, self.headers, self.status_code)
|
||||
context = _Context(request, self.headers.copy(), self.status_code)
|
||||
|
||||
content = self.content
|
||||
text = self.text
|
||||
@ -88,7 +113,7 @@ class Matcher(object):
|
||||
content = data.encode(encoding)
|
||||
if content:
|
||||
data = context.call(content)
|
||||
body = BytesIO(data)
|
||||
body = six.BytesIO(data)
|
||||
if body:
|
||||
data = context.call(body)
|
||||
raw = HTTPResponse(status=context.status_code,
|
||||
@ -99,8 +124,31 @@ class Matcher(object):
|
||||
|
||||
return encoding, raw
|
||||
|
||||
|
||||
class _Matcher(object):
|
||||
"""Contains all the information about a provided URL to match."""
|
||||
|
||||
_http_adapter = HTTPAdapter()
|
||||
|
||||
def __init__(self, method, url, responses, complete_qs):
|
||||
"""
|
||||
:param bool complete_qs: Match the entire query string. By default URLs
|
||||
match if all the provided matcher query arguments are matched and
|
||||
extra query arguments are ignored. Set complete_qs to true to
|
||||
require that the entire query string needs to match.
|
||||
"""
|
||||
self.method = method
|
||||
self.url = urlparse.urlsplit(url.lower())
|
||||
self.responses = responses
|
||||
self.complete_qs = complete_qs
|
||||
|
||||
def create_response(self, request):
|
||||
encoding, response = self._get_response(request)
|
||||
if len(self.responses) > 1:
|
||||
response_matcher = self.responses.pop(0)
|
||||
else:
|
||||
response_matcher = self.responses[0]
|
||||
|
||||
encoding, response = response_matcher.get_response(request)
|
||||
req_resp = self._http_adapter.build_response(request, response)
|
||||
req_resp.connection = self
|
||||
req_resp.encoding = encoding
|
||||
@ -109,36 +157,51 @@ class Matcher(object):
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
def match(self, request):
|
||||
if request.method.lower() != self.method.lower():
|
||||
return False
|
||||
|
||||
def match(request, matcher):
|
||||
if request.method.lower() != matcher.method.lower():
|
||||
return False
|
||||
url = urlparse.urlsplit(request.url.lower())
|
||||
|
||||
url = urlparse.urlsplit(request.url.lower())
|
||||
if self.url.scheme and url.scheme != self.url.scheme:
|
||||
return False
|
||||
|
||||
if matcher.url.scheme and url.scheme != matcher.url.scheme:
|
||||
return False
|
||||
if self.url.netloc and url.netloc != self.url.netloc:
|
||||
return False
|
||||
|
||||
if matcher.url.netloc and url.netloc != matcher.url.netloc:
|
||||
return False
|
||||
if (url.path or '/') != (self.url.path or '/'):
|
||||
return False
|
||||
|
||||
if matcher.url.path and url.path != matcher.url.path:
|
||||
return False
|
||||
matcher_qs = urlparse.parse_qs(self.url.query)
|
||||
request_qs = urlparse.parse_qs(url.query)
|
||||
|
||||
return True
|
||||
for k, vals in six.iteritems(matcher_qs):
|
||||
for v in vals:
|
||||
try:
|
||||
request_qs.get(k, []).remove(v)
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
if self.complete_qs:
|
||||
for v in six.itervalues(request_qs):
|
||||
if v:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class Adapter(BaseAdapter):
|
||||
"""A fake adapter than can return predefined responses.
|
||||
|
||||
"""
|
||||
def __init__(self):
|
||||
self._matchers = []
|
||||
self._request_history = []
|
||||
self.request_history = []
|
||||
|
||||
def send(self, request, **kwargs):
|
||||
self._request_history.append(request)
|
||||
|
||||
for matcher in self._matchers:
|
||||
if match(request, matcher):
|
||||
if matcher.match(request):
|
||||
self.request_history.append(request)
|
||||
return matcher.create_response(request)
|
||||
|
||||
raise exceptions.NoMockAddress(request)
|
||||
@ -146,13 +209,31 @@ class Adapter(BaseAdapter):
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
def register_uri(self, *args, **kwargs):
|
||||
self._matchers.append(Matcher(*args, **kwargs))
|
||||
def register_uri(self, method, url, request_list=None, **kwargs):
|
||||
"""Register a new URI match and fake response.
|
||||
|
||||
:param str method: The HTTP method to match.
|
||||
:param str url: The URL to match.
|
||||
"""
|
||||
complete_qs = kwargs.pop('complete_qs', False)
|
||||
|
||||
if request_list and kwargs:
|
||||
raise RuntimeError('You should specify either a list of '
|
||||
'responses OR response kwargs. Not both.')
|
||||
elif not request_list:
|
||||
request_list = [kwargs]
|
||||
|
||||
responses = [_MatcherResponse(**k) for k in request_list]
|
||||
self._matchers.append(_Matcher(method,
|
||||
url,
|
||||
responses,
|
||||
complete_qs=complete_qs))
|
||||
|
||||
@property
|
||||
def last_request(self):
|
||||
"""Retrieve the latest request sent"""
|
||||
try:
|
||||
return self._request_history[-1]
|
||||
return self.request_history[-1]
|
||||
except IndexError:
|
||||
return None
|
||||
|
||||
|
@ -36,20 +36,19 @@ class SessionAdapterTests(base.TestCase):
|
||||
self.assertEqual(v, resp.headers[k])
|
||||
|
||||
def assertLastRequest(self, method='GET', body=None):
|
||||
last = self.adapter.last_request()
|
||||
self.assertEqual(self.url, last.url)
|
||||
self.assertEqual(method, last.method)
|
||||
self.assertEqual(body, last.body)
|
||||
self.assertEqual(self.url, self.adapter.last_request.url)
|
||||
self.assertEqual(method, self.adapter.last_request.method)
|
||||
self.assertEqual(body, self.adapter.last_request.body)
|
||||
|
||||
def test_content(self):
|
||||
data = 'testdata'
|
||||
data = six.b('testdata')
|
||||
|
||||
self.adapter.register_uri('GET',
|
||||
self.url,
|
||||
content=data,
|
||||
headers=self.headers)
|
||||
resp = self.session.get(self.url)
|
||||
self.assertEqual(six.b(data), resp.content)
|
||||
self.assertEqual(data, resp.content)
|
||||
self.assertHeaders(resp)
|
||||
self.assertLastRequest()
|
||||
|
||||
@ -131,6 +130,12 @@ class SessionAdapterTests(base.TestCase):
|
||||
self.assertHeaders(resp)
|
||||
self.assertLastRequest()
|
||||
|
||||
def test_no_body(self):
|
||||
self.adapter.register_uri('GET', self.url)
|
||||
resp = self.session.get(self.url)
|
||||
self.assertEqual(six.b(''), resp.content)
|
||||
self.assertEqual(200, resp.status_code)
|
||||
|
||||
def test_multiple_body_elements(self):
|
||||
self.assertRaises(RuntimeError,
|
||||
self.adapter.register_uri,
|
||||
@ -138,3 +143,19 @@ class SessionAdapterTests(base.TestCase):
|
||||
'GET',
|
||||
content=six.b('b'),
|
||||
text=six.u('u'))
|
||||
|
||||
def test_multiple_responses(self):
|
||||
inp = [{'status_code': 400, 'text': 'abcd'},
|
||||
{'status_code': 300, 'text': 'defg'},
|
||||
{'status_code': 200, 'text': 'hijk'}]
|
||||
|
||||
self.adapter.register_uri('GET', self.url, inp)
|
||||
out = [self.session.get(self.url) for i in range(0, len(inp))]
|
||||
|
||||
for i, o in zip(inp, out):
|
||||
for k, v in six.iteritems(i):
|
||||
self.assertEqual(v, getattr(o, k))
|
||||
|
||||
last = self.session.get(self.url)
|
||||
for k, v in six.iteritems(inp[-1]):
|
||||
self.assertEqual(v, getattr(last, k))
|
||||
|
@ -36,4 +36,4 @@ class MockingTests(base.TestCase):
|
||||
|
||||
resp = requests.get(test_url)
|
||||
self.assertEqual('response', resp.text)
|
||||
self.assertEqual(test_url, self.mocker.last_request().url)
|
||||
self.assertEqual(test_url, self.mocker.last_request.url)
|
||||
|
85
requests_mock/tests/test_matcher.py
Normal file
85
requests_mock/tests/test_matcher.py
Normal file
@ -0,0 +1,85 @@
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import requests
|
||||
|
||||
from requests_mock import adapter
|
||||
from requests_mock.tests import base
|
||||
|
||||
|
||||
class TestMatcher(base.TestCase):
|
||||
|
||||
def match(self, target, url, complete_qs=False):
|
||||
matcher = adapter._Matcher('GET', target, [], complete_qs)
|
||||
request = requests.Request('GET', url).prepare()
|
||||
return matcher.match(request)
|
||||
|
||||
def assertMatch(self, target, url, **kwargs):
|
||||
self.assertEqual(True, self.match(target, url, **kwargs),
|
||||
'Matcher %s failed to match %s' % (target, url))
|
||||
|
||||
def assertMatchBoth(self, target, url, **kwargs):
|
||||
self.assertMatch(target, url, **kwargs)
|
||||
self.assertMatch(url, target, **kwargs)
|
||||
|
||||
def assertNoMatch(self, target, url, **kwargs):
|
||||
self.assertEqual(False, self.match(target, url, **kwargs),
|
||||
'Matcher %s unexpectedly matched %s' % (target, url))
|
||||
|
||||
def assertNoMatchBoth(self, target, url, **kwargs):
|
||||
self.assertNoMatch(target, url, **kwargs)
|
||||
self.assertNoMatch(url, target, **kwargs)
|
||||
|
||||
def test_url_matching(self):
|
||||
self.assertMatchBoth('http://www.test.com',
|
||||
'http://www.test.com')
|
||||
self.assertMatchBoth('http://www.test.com',
|
||||
'http://www.test.com/')
|
||||
self.assertMatchBoth('http://www.test.com/abc',
|
||||
'http://www.test.com/abc')
|
||||
self.assertMatchBoth('http://www.test.com:5000/abc',
|
||||
'http://www.test.com:5000/abc')
|
||||
|
||||
self.assertNoMatchBoth('https://www.test.com',
|
||||
'http://www.test.com')
|
||||
self.assertNoMatchBoth('http://www.test.com/abc',
|
||||
'http://www.test.com')
|
||||
self.assertNoMatchBoth('http://test.com',
|
||||
'http://www.test.com')
|
||||
self.assertNoMatchBoth('http://test.com',
|
||||
'http://www.test.com')
|
||||
self.assertNoMatchBoth('http://test.com/abc',
|
||||
'http://www.test.com/abc/')
|
||||
self.assertNoMatchBoth('http://test.com/abc/',
|
||||
'http://www.test.com/abc')
|
||||
self.assertNoMatchBoth('http://test.com:5000/abc/',
|
||||
'http://www.test.com/abc')
|
||||
self.assertNoMatchBoth('http://test.com/abc/',
|
||||
'http://www.test.com:5000/abc')
|
||||
|
||||
def test_subset_match(self):
|
||||
self.assertMatch('/path', 'http://www.test.com/path')
|
||||
self.assertMatch('/path', 'http://www.test.com/path')
|
||||
self.assertMatch('//www.test.com/path', 'http://www.test.com/path')
|
||||
self.assertMatch('//www.test.com/path', 'https://www.test.com/path')
|
||||
|
||||
def test_query_string(self):
|
||||
self.assertMatch('/path?a=1&b=2',
|
||||
'http://www.test.com/path?a=1&b=2')
|
||||
self.assertMatch('/path?a=1',
|
||||
'http://www.test.com/path?a=1&b=2',
|
||||
complete_qs=False)
|
||||
self.assertNoMatch('/path?a=1',
|
||||
'http://www.test.com/path?a=1&b=2',
|
||||
complete_qs=True)
|
||||
self.assertNoMatch('/path?a=1&b=2',
|
||||
'http://www.test.com/path?a=1')
|
Loading…
x
Reference in New Issue
Block a user