# Licensed under the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. You may obtain # a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. import json import requests from requests.adapters import BaseAdapter import six from six.moves.urllib import parse as urlparse from requests_mock import exceptions from requests_mock import response ANY = object() class _RequestObjectProxy(object): """A wrapper around a requests.Request that gives some extra information. This will be important both for matching and so that when it's save into the request_history users will be able to access these properties. """ def __init__(self, request): self._request = request self._url_parts_ = None self._qs = None def __getattr__(self, name): return getattr(self._request, name) @property def _url_parts(self): if self._url_parts_ is None: self._url_parts_ = urlparse.urlparse(self._request.url.lower()) return self._url_parts_ @property def scheme(self): return self._url_parts.scheme @property def netloc(self): return self._url_parts.netloc @property def path(self): return self._url_parts.path @property def query(self): return self._url_parts.query @property def qs(self): if self._qs is None: self._qs = urlparse.parse_qs(self.query) return self._qs @classmethod def _create(cls, *args, **kwargs): return cls(requests.Request(*args, **kwargs).prepare()) @property def text(self): body = self.body if isinstance(body, six.binary_type): body = body.decode('utf-8') return body def json(self, **kwargs): return json.loads(self.text, **kwargs) class _RequestHistoryTracker(object): def __init__(self): self.request_history = [] def _add_to_history(self, request): self.request_history.append(request) @property def last_request(self): """Retrieve the latest request sent""" try: return self.request_history[-1] except IndexError: return None @property def called(self): return self.call_count > 0 @property def call_count(self): return len(self.request_history) class _Matcher(_RequestHistoryTracker): """Contains all the information about a provided URL to match.""" def __init__(self, method, url, responses, complete_qs, request_headers): """ :param bool complete_qs: Match the entire query string. By default URLs match if all the provided matcher query arguments are matched and extra query arguments are ignored. Set complete_qs to true to require that the entire query string needs to match. """ super(_Matcher, self).__init__() self._method = method self._url = url try: self._url_parts = urlparse.urlparse(url.lower()) except: self._url_parts = None self._responses = responses self._complete_qs = complete_qs self._request_headers = request_headers def _match_method(self, request): if self._method is ANY: return True if request.method.lower() == self._method.lower(): return True return False def _match_url(self, request): if self._url is ANY: return True # regular expression matching if hasattr(self._url, 'search'): return self._url.search(request.url) is not None if self._url_parts.scheme and request.scheme != self._url_parts.scheme: return False if self._url_parts.netloc and request.netloc != self._url_parts.netloc: return False if (request.path or '/') != (self._url_parts.path or '/'): return False # construct our own qs structure as we remove items from it below request_qs = urlparse.parse_qs(request.query) matcher_qs = urlparse.parse_qs(self._url_parts.query) for k, vals in six.iteritems(matcher_qs): for v in vals: try: request_qs.get(k, []).remove(v) except ValueError: return False if self._complete_qs: for v in six.itervalues(request_qs): if v: return False return True def _match_headers(self, request): for k, vals in six.iteritems(self._request_headers): try: header = request.headers[k] except KeyError: return False else: if header != vals: return False return True def _match(self, request): return (self._match_method(request) and self._match_url(request) and self._match_headers(request)) def __call__(self, request): if not self._match(request): return None if len(self._responses) > 1: response_matcher = self._responses.pop(0) else: response_matcher = self._responses[0] self._add_to_history(request) return response_matcher.get_response(request) class Adapter(BaseAdapter, _RequestHistoryTracker): """A fake adapter than can return predefined responses. """ def __init__(self): super(Adapter, self).__init__() self._matchers = [] def send(self, request, **kwargs): request = _RequestObjectProxy(request) self._add_to_history(request) for matcher in reversed(self._matchers): resp = matcher(request) if resp is not None: resp.connection = self return resp raise exceptions.NoMockAddress(request) def close(self): pass def register_uri(self, method, url, response_list=None, **kwargs): """Register a new URI match and fake response. :param str method: The HTTP method to match. :param str url: The URL to match. """ complete_qs = kwargs.pop('complete_qs', False) request_headers = kwargs.pop('request_headers', {}) if response_list and kwargs: raise RuntimeError('You should specify either a list of ' 'responses OR response kwargs. Not both.') elif not response_list: response_list = [kwargs] responses = [response._MatcherResponse(**k) for k in response_list] matcher = _Matcher(method, url, responses, complete_qs=complete_qs, request_headers=request_headers) self.add_matcher(matcher) return matcher def add_matcher(self, matcher): """Register a custom matcher. A matcher is a callable that takes a `requests.Request` and returns a `requests.Response` if it matches or None if not. :param callable matcher: The matcher to execute. """ self._matchers.append(matcher) __all__ = ['Adapter']