# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

import json as jsonutils

from requests.adapters import BaseAdapter, HTTPAdapter
from requests.packages.urllib3.response import HTTPResponse
import six
from six.moves.urllib import parse as urlparse

from requests_mock import exceptions

# Sentinel: pass as a matcher's method or url to match any value for that
# field (checked by identity in _Matcher._match_method/_match_url).
ANY = object()


class _Context(object):
    """Stores the data being used to process a current URL match.

    An instance is passed as the second argument to body-element callbacks
    (json/text/content/body). The raw response is built from this object's
    attributes after the callbacks run, so a callback may modify headers,
    status_code or reason to affect the response.
    """

    def __init__(self, headers, status_code, reason):
        self.headers = headers
        self.status_code = status_code
        self.reason = reason


class _RequestHistoryTracker(object):
    """Mixin that records every request passed to ``_add_to_history``."""

    def __init__(self):
        self.request_history = []

    def _add_to_history(self, request):
        self.request_history.append(request)

    @property
    def last_request(self):
        """Retrieve the latest request sent, or None if none was recorded."""
        try:
            return self.request_history[-1]
        except IndexError:
            return None

    @property
    def called(self):
        """True if at least one request has been recorded."""
        return self.call_count > 0

    @property
    def call_count(self):
        """The number of requests recorded so far."""
        return len(self.request_history)


class _MatcherResponse(object):
    """Describes one canned response and builds it on demand."""

    _BODY_ARGS = ['raw', 'body', 'content', 'text', 'json']

    def __init__(self, **kwargs):
        """
        :param int status_code: The status code to return upon a successful
            match. Defaults to 200.
        :param HTTPResponse raw: A HTTPResponse object to return upon a
            successful match.
        :param io.IOBase body: An IO object with a read() method that can
            return a body on successful match.
        :param bytes content: A byte string to return upon a successful
            match.
        :param unicode text: A text string to return upon a successful
            match.
        :param object json: A python object to be converted to a JSON string
            and returned upon a successful match.
        :param dict headers: A dictionary object containing headers that are
            returned upon a successful match. None is treated as no headers.
        :param str reason: The reason phrase to return upon a successful
            match. Defaults to None.

        json, text, content and body may each also be a callable; it is
        invoked as ``f(request, context)`` when the response is built.

        :raises TypeError: if an unexpected keyword argument is given, or if
            content/text is neither a callable nor of the expected type.
        :raises RuntimeError: if more than one body element is given.
        """
        # mutual exclusion, only 1 body method may be provided
        provided = [x for x in self._BODY_ARGS if kwargs.get(x) is not None]

        self.status_code = kwargs.pop('status_code', 200)
        self.raw = kwargs.pop('raw', None)
        self.body = kwargs.pop('body', None)
        self.content = kwargs.pop('content', None)
        self.text = kwargs.pop('text', None)
        self.json = kwargs.pop('json', None)
        self.reason = kwargs.pop('reason', None)
        self.headers = kwargs.pop('headers', None)

        # get_response() copies the headers, so an explicit headers=None
        # must be normalized rather than crash at request time.
        if self.headers is None:
            self.headers = {}

        if kwargs:
            raise TypeError('Too many arguments provided. Unexpected '
                            'arguments %s.' % ', '.join(kwargs.keys()))

        if len(provided) == 0:
            self.body = six.BytesIO(six.b(''))
        elif len(provided) > 1:
            raise RuntimeError('You may only supply one body element. You '
                               'supplied %s' % ', '.join(provided))

        # whilst in general you shouldn't do type checking in python this
        # makes sure we don't end up with differences between the way types
        # are handled between python 2 and 3.
        # Compare against None (not truthiness) so that wrongly-typed empty
        # values such as content=u'' are rejected here too.
        if self.content is not None and not (
                callable(self.content) or
                isinstance(self.content, six.binary_type)):
            raise TypeError('Content should be a callback or binary data')

        if self.text is not None and not (
                callable(self.text) or
                isinstance(self.text, six.string_types)):
            raise TypeError('Text should be a callback or string data')

    def get_response(self, request):
        """Build the raw response for ``request``.

        Body elements are funneled down one stage at a time:
        json is serialized to text, text is encoded as UTF-8 content,
        content is wrapped in a BytesIO body, and body is wrapped in an
        HTTPResponse. A raw element given directly is returned as-is.

        :returns: a tuple ``(encoding, raw)``. ``encoding`` is 'utf-8' when
            the body came from text or json, otherwise None.
        """
        encoding = None
        context = _Context(self.headers.copy(),
                           self.status_code,
                           self.reason)

        # if a body element is a callback then execute it
        def _call(f, *args, **kwargs):
            return f(request, context, *args, **kwargs) if callable(f) else f

        content = self.content
        text = self.text
        body = self.body
        raw = self.raw

        if self.json is not None:
            data = _call(self.json)
            text = jsonutils.dumps(data)
        if text is not None:
            data = _call(text)
            encoding = 'utf-8'
            content = data.encode(encoding)
        if content is not None:
            data = _call(content)
            body = six.BytesIO(data)
        if body is not None:
            data = _call(body)
            # Built after the callbacks ran so context changes are honored.
            raw = HTTPResponse(status=context.status_code,
                               body=data,
                               headers=context.headers,
                               reason=context.reason,
                               decode_content=False,
                               preload_content=False)

        return encoding, raw


class _FakeConnection(object):
    """An object that can mock the necessary parts of a socket interface."""

    def close(self):
        pass


class _Matcher(_RequestHistoryTracker):
    """Contains all the information about a provided URL to match."""

    _http_adapter = HTTPAdapter()

    def __init__(self, method, url, responses, complete_qs, request_headers):
        """
        :param method: The HTTP method to match (compared case-insensitively),
            or ANY.
        :param url: A URL string, an object with a ``search`` method (e.g. a
            compiled regular expression), or ANY.
        :param list responses: _MatcherResponse objects, consumed in order;
            the last one is reused for all further matches.
        :param bool complete_qs: Match the entire query string. By default
            URLs match if all the provided matcher query arguments are
            matched and extra query arguments are ignored. Set complete_qs to
            true to require that the entire query string needs to match.
        :param dict request_headers: Headers whose values must be present
            and equal on the request for it to match.
        """
        super(_Matcher, self).__init__()

        self._method = method
        self._url = url
        try:
            self._url_parts = urlparse.urlparse(url.lower())
        except (AttributeError, TypeError, ValueError):
            # url is not a parseable string (ANY, a regex-like object, or a
            # malformed URL); only string URLs need the parsed parts.
            self._url_parts = None

        self._responses = responses
        self._complete_qs = complete_qs
        self._request_headers = request_headers

    def _match_method(self, request):
        """Return True if the request method matches (case-insensitive)."""
        if self._method is ANY:
            return True

        if request.method.lower() == self._method.lower():
            return True

        return False

    def _match_url(self, request):
        """Return True if the request URL satisfies this matcher's url.

        Objects with a ``search`` method are applied to the unmodified
        request URL. String URLs are compared lowercased: scheme and netloc
        only when the matcher specifies them, path always (empty == '/'),
        and every matcher query value must be present in the request.
        """
        if self._url is ANY:
            return True

        # regular expression matching
        if hasattr(self._url, 'search'):
            return self._url.search(request.url) is not None

        url = urlparse.urlparse(request.url.lower())

        if self._url_parts.scheme and url.scheme != self._url_parts.scheme:
            return False

        if self._url_parts.netloc and url.netloc != self._url_parts.netloc:
            return False

        if (url.path or '/') != (self._url_parts.path or '/'):
            return False

        matcher_qs = urlparse.parse_qs(self._url_parts.query)
        request_qs = urlparse.parse_qs(url.query)

        # Remove each expected value from the request's values; any
        # leftover values are extra query arguments.
        for k, vals in six.iteritems(matcher_qs):
            for v in vals:
                try:
                    request_qs.get(k, []).remove(v)
                except ValueError:
                    return False

        if self._complete_qs:
            for v in six.itervalues(request_qs):
                if v:
                    return False

        return True

    def _match_headers(self, request):
        """Return True if every required header is present and equal."""
        for k, vals in six.iteritems(self._request_headers):
            try:
                header = request.headers[k]
            except KeyError:
                return False
            else:
                if header != vals:
                    return False

        return True

    def _match(self, request):
        return (self._match_method(request) and
                self._match_url(request) and
                self._match_headers(request))

    def __call__(self, request):
        """Return a requests Response if ``request`` matches, else None.

        A matching request is recorded in this matcher's history.
        """
        if not self._match(request):
            return None

        # Consume responses in order, but keep re-using the final one.
        if len(self._responses) > 1:
            response_matcher = self._responses.pop(0)
        else:
            response_matcher = self._responses[0]

        self._add_to_history(request)
        encoding, response = response_matcher.get_response(request)
        req_resp = self._http_adapter.build_response(request, response)
        req_resp.connection = _FakeConnection()
        req_resp.encoding = encoding
        return req_resp


class Adapter(BaseAdapter, _RequestHistoryTracker):
    """A fake adapter that can return predefined responses."""

    def __init__(self):
        super(Adapter, self).__init__()
        self._matchers = []

    def send(self, request, **kwargs):
        """Return the response of the most recently added matching matcher.

        Every request is recorded in the adapter's history, matched or not.

        :raises exceptions.NoMockAddress: if no matcher matches the request.
        """
        self._add_to_history(request)

        # Newest matchers take precedence over older ones.
        for matcher in reversed(self._matchers):
            response = matcher(request)
            if response is not None:
                return response

        raise exceptions.NoMockAddress(request)

    def close(self):
        pass

    def register_uri(self, method, url, response_list=None, **kwargs):
        """Register a new URI match and fake response.

        :param str method: The HTTP method to match.
        :param str url: The URL to match.
        :param list response_list: A list of dicts of response kwargs, used
            in order for successive matches. Mutually exclusive with passing
            response kwargs directly.
        :param bool complete_qs: Require the entire query string to match.
        :param dict request_headers: Headers the request must carry.

        Any remaining kwargs describe a single response (see
        _MatcherResponse).

        :returns: The created matcher, which tracks its own request history.
        :raises RuntimeError: if both response_list and response kwargs are
            given.
        """
        complete_qs = kwargs.pop('complete_qs', False)
        request_headers = kwargs.pop('request_headers', {})

        if response_list and kwargs:
            raise RuntimeError('You should specify either a list of '
                               'responses OR response kwargs. Not both.')
        elif not response_list:
            response_list = [kwargs]

        responses = [_MatcherResponse(**k) for k in response_list]
        matcher = _Matcher(method,
                           url,
                           responses,
                           complete_qs=complete_qs,
                           request_headers=request_headers)
        self.add_matcher(matcher)
        return matcher

    def add_matcher(self, matcher):
        """Register a custom matcher.

        A matcher is a callable that takes a `requests.Request` and returns a
        `requests.Response` if it matches or None if not.

        :param callable matcher: The matcher to execute.
        """
        self._matchers.append(matcher)


__all__ = ['Adapter']