# Licensed under the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. You may obtain # a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. import copy import json import weakref import requests from requests.adapters import BaseAdapter import six from six.moves.urllib import parse as urlparse from requests_mock import exceptions from requests_mock import response ANY = object() class _RequestObjectProxy(object): """A wrapper around a requests.Request that gives some extra information. This will be important both for matching and so that when it's save into the request_history users will be able to access these properties. """ def __init__(self, request, **kwargs): self._request = request self._matcher = None self._url_parts_ = None self._qs = None # All of these params should always exist but we use a default # to make the test setup easier. self._timeout = kwargs.pop('timeout', None) self._allow_redirects = kwargs.pop('allow_redirects', None) self._verify = kwargs.pop('verify', None) self._cert = kwargs.pop('cert', None) self._proxies = copy.deepcopy(kwargs.pop('proxies', {})) def __getattr__(self, name): return getattr(self._request, name) @property def _url_parts(self): if self._url_parts_ is None: self._url_parts_ = urlparse.urlparse(self._request.url.lower()) return self._url_parts_ @property def scheme(self): return self._url_parts.scheme @property def netloc(self): return self._url_parts.netloc @property def path(self): return self._url_parts.path @property def query(self): return self._url_parts.query @property def qs(self): if self._qs is None: self._qs = urlparse.parse_qs(self.query) return self._qs @property def timeout(self): return self._timeout @property def allow_redirects(self): return self._allow_redirects @property def verify(self): return self._verify @property def cert(self): return self._cert @property def proxies(self): return self._proxies @classmethod def _create(cls, *args, **kwargs): return cls(requests.Request(*args, **kwargs).prepare()) @property def text(self): body = self.body if isinstance(body, six.binary_type): body = body.decode('utf-8') return body def json(self, **kwargs): return json.loads(self.text, **kwargs) @property def matcher(self): """The matcher that this request was handled by. The matcher object is handled by a weakref. It will return the matcher object if it is still available - so if the mock is still in place. If the matcher is not available it will return None. """ return self._matcher() class _RequestHistoryTracker(object): def __init__(self): self.request_history = [] def _add_to_history(self, request): self.request_history.append(request) @property def last_request(self): """Retrieve the latest request sent""" try: return self.request_history[-1] except IndexError: return None @property def called(self): return self.call_count > 0 @property def call_count(self): return len(self.request_history) class _Matcher(_RequestHistoryTracker): """Contains all the information about a provided URL to match.""" def __init__(self, method, url, responses, complete_qs, request_headers): """ :param bool complete_qs: Match the entire query string. By default URLs match if all the provided matcher query arguments are matched and extra query arguments are ignored. Set complete_qs to true to require that the entire query string needs to match. """ super(_Matcher, self).__init__() self._method = method self._url = url try: self._url_parts = urlparse.urlparse(url.lower()) except: self._url_parts = None self._responses = responses self._complete_qs = complete_qs self._request_headers = request_headers def _match_method(self, request): if self._method is ANY: return True if request.method.lower() == self._method.lower(): return True return False def _match_url(self, request): if self._url is ANY: return True # regular expression matching if hasattr(self._url, 'search'): return self._url.search(request.url) is not None if self._url_parts.scheme and request.scheme != self._url_parts.scheme: return False if self._url_parts.netloc and request.netloc != self._url_parts.netloc: return False if (request.path or '/') != (self._url_parts.path or '/'): return False # construct our own qs structure as we remove items from it below request_qs = urlparse.parse_qs(request.query) matcher_qs = urlparse.parse_qs(self._url_parts.query) for k, vals in six.iteritems(matcher_qs): for v in vals: try: request_qs.get(k, []).remove(v) except ValueError: return False if self._complete_qs: for v in six.itervalues(request_qs): if v: return False return True def _match_headers(self, request): for k, vals in six.iteritems(self._request_headers): try: header = request.headers[k] except KeyError: # NOTE(jamielennox): This seems to be a requests 1.2/2 # difference, in 2 they are just whatever the user inputted in # 1 they are bytes. Let's optionally handle both and look at # removing this when we depend on requests 2. if not isinstance(k, six.text_type): return False try: header = request.headers[k.encode('utf-8')] except KeyError: return False if header != vals: return False return True def _match(self, request): return (self._match_method(request) and self._match_url(request) and self._match_headers(request)) def __call__(self, request): if not self._match(request): return None if len(self._responses) > 1: response_matcher = self._responses.pop(0) else: response_matcher = self._responses[0] self._add_to_history(request) return response_matcher.get_response(request) class Adapter(BaseAdapter, _RequestHistoryTracker): """A fake adapter than can return predefined responses. """ def __init__(self): super(Adapter, self).__init__() self._matchers = [] def send(self, request, **kwargs): request = _RequestObjectProxy(request, **kwargs) self._add_to_history(request) for matcher in reversed(self._matchers): try: resp = matcher(request) except Exception: request._matcher = weakref.ref(matcher) raise if resp is not None: request._matcher = weakref.ref(matcher) resp.connection = self return resp raise exceptions.NoMockAddress(request) def close(self): pass def register_uri(self, method, url, response_list=None, **kwargs): """Register a new URI match and fake response. :param str method: The HTTP method to match. :param str url: The URL to match. """ complete_qs = kwargs.pop('complete_qs', False) request_headers = kwargs.pop('request_headers', {}) if response_list and kwargs: raise RuntimeError('You should specify either a list of ' 'responses OR response kwargs. Not both.') elif not response_list: response_list = [kwargs] responses = [response._MatcherResponse(**k) for k in response_list] matcher = _Matcher(method, url, responses, complete_qs=complete_qs, request_headers=request_headers) self.add_matcher(matcher) return matcher def add_matcher(self, matcher): """Register a custom matcher. A matcher is a callable that takes a `requests.Request` and returns a `requests.Response` if it matches or None if not. :param callable matcher: The matcher to execute. """ self._matchers.append(matcher) __all__ = ['Adapter']