From e31045e57a102182d33252e3a6b07ddfa9488ebe Mon Sep 17 00:00:00 2001 From: Chris Dent Date: Fri, 10 Apr 2015 18:41:32 +0100 Subject: [PATCH] Multiple protocol accept or content-type matching The changes in 8710dabb652dae775dee31789e91608f832e62e6 broke protocol failover when the REST protocol is listed before others (see bug referenced below). This patch tries to solve both issues by trying to match accept over all the protocols, only giving up a 406 or 415 if all protocols fail, using the last failure as the error message. Related-Bug: #1419110 Closes-Bug: #1442710 Change-Id: I328a392151013c46207519c245213d5dec48ecc9 --- wsme/protocol.py | 3 +- wsme/root.py | 25 +++++++----- wsme/tests/test_api.py | 1 + wsme/tests/test_root.py | 84 ++++++++++++++++++++++++++++++++++++++++- 4 files changed, 100 insertions(+), 13 deletions(-) diff --git a/wsme/protocol.py b/wsme/protocol.py index ec40093..b0107ab 100644 --- a/wsme/protocol.py +++ b/wsme/protocol.py @@ -117,7 +117,7 @@ def getprotocol(name, **options): def media_type_accept(request, content_types): - """Return True if the requested media type is available. + """Validate media types against request.method. When request.method is GET or HEAD compare with the Accept header. When request.method is POST, PUT or PATCH compare with the Content-Type @@ -131,7 +131,6 @@ def media_type_accept(request, content_types): error_message = ('Unacceptable Accept type: %s not in %s' % (request.accept, content_types)) raise ClientSideError(error_message, status_code=406) - return False elif request.method in ['PUT', 'POST', 'PATCH']: content_type = request.headers.get('Content-Type') if content_type: diff --git a/wsme/root.py b/wsme/root.py index 4b56c4f..1ccfbdf 100644 --- a/wsme/root.py +++ b/wsme/root.py @@ -149,6 +149,7 @@ class WSRoot(object): request.body[:512] or request.body) or '') protocol = None + error = ClientSideError(status_code=406) path = str(request.path) assert path.startswith(self._webpath) path = path[len(self._webpath) + 1:] @@ -157,9 +158,16 @@ class WSRoot(object): else: for p in self.protocols: - if p.accept(request): - protocol = p - break + try: + if p.accept(request): + protocol = p + break + except ClientSideError as e: + error = e + # If we could not select a protocol, we raise the last exception + # that we got, or the default one. + if not protocol: + raise error return protocol def _do_call(self, protocol, context): @@ -232,11 +240,6 @@ class WSRoot(object): msg = None error_status = 500 protocol = self._select_protocol(request) - if protocol is None: - if request.method in ['GET', 'HEAD']: - error_status = 406 - elif request.method in ['POST', 'PUT', 'PATCH']: - error_status = 415 except ClientSideError as e: error_status = e.code msg = e.faultstring @@ -248,7 +251,7 @@ class WSRoot(object): error_status = 500 if protocol is None: - if msg is None: + if not msg: msg = ("None of the following protocols can handle this " "request : %s" % ','.join(( p.name for p in self.protocols))) @@ -296,6 +299,10 @@ class WSRoot(object): else: res.status = protocol.get_response_status(request) res_content_type = protocol.get_response_contenttype(request) + except ClientSideError as e: + request.server_errorcount += 1 + res.status = e.code + res.text = e.faultstring except Exception: infos = wsme.api.format_exception(sys.exc_info(), self._debug) request.server_errorcount += 1 diff --git a/wsme/tests/test_api.py b/wsme/tests/test_api.py index 9d02807..b99b20a 100644 --- a/wsme/tests/test_api.py +++ b/wsme/tests/test_api.py @@ -201,6 +201,7 @@ Value should be one of:")) app = webtest.TestApp(r.wsgiapp()) res = app.get('/', expect_errors=True) + print(res.status_int) assert res.status_int == 406 print(res.body) assert res.body.find( diff --git a/wsme/tests/test_root.py b/wsme/tests/test_root.py index 4c68375..8ccb0c1 100644 --- a/wsme/tests/test_root.py +++ b/wsme/tests/test_root.py @@ -3,9 +3,12 @@ import unittest from wsme import WSRoot +import wsme.protocol +import wsme.rest.protocol from wsme.root import default_prepare_response_body from six import b, u +from webob import Request class TestRoot(unittest.TestCase): @@ -24,9 +27,9 @@ class TestRoot(unittest.TestCase): default_prepare_response_body(None, [u('a'), u('b')]) == u('a\nb') def test_protocol_selection_error(self): - import wsme.protocol - class P(wsme.protocol.Protocol): + name = "test" + def accept(self, r): raise Exception('test') @@ -40,3 +43,80 @@ class TestRoot(unittest.TestCase): assert res.content_type == 'text/plain' assert (res.text == 'Unexpected error while selecting protocol: test'), req.text + + def test_protocol_selection_accept_mismatch(self): + """Verify that we get a 406 error on wrong Accept header.""" + class P(wsme.protocol.Protocol): + name = "test" + + def accept(self, r): + return False + + root = WSRoot() + root.addprotocol(wsme.rest.protocol.RestProtocol()) + root.addprotocol(P()) + + req = Request.blank('/test?check=a&check=b&name=Bob') + req.method = 'GET' + res = root._handle_request(req) + assert res.status_int == 406 + assert res.content_type == 'text/plain' + assert res.text.startswith( + 'None of the following protocols can handle this request' + ), req.text + + def test_protocol_selection_content_type_mismatch(self): + """Verify that we get a 415 error on wrong Content-Type header.""" + class P(wsme.protocol.Protocol): + name = "test" + + def accept(self, r): + return False + + root = WSRoot() + root.addprotocol(wsme.rest.protocol.RestProtocol()) + root.addprotocol(P()) + + req = Request.blank('/test?check=a&check=b&name=Bob') + req.method = 'POST' + req.headers['Content-Type'] = "test/unsupported" + res = root._handle_request(req) + assert res.status_int == 415 + assert res.content_type == 'text/plain' + assert res.text.startswith( + 'Unacceptable Content-Type: test/unsupported not in' + ), req.text + + def test_protocol_selection_get_method(self): + class P(wsme.protocol.Protocol): + name = "test" + + def accept(self, r): + return True + + root = WSRoot() + root.addprotocol(wsme.rest.protocol.RestProtocol()) + root.addprotocol(P()) + + req = Request.blank('/test?check=a&check=b&name=Bob') + req.method = 'GET' + req.headers['Accept'] = 'test/fake' + p = root._select_protocol(req) + assert p.name == "test" + + def test_protocol_selection_post_method(self): + class P(wsme.protocol.Protocol): + name = "test" + + def accept(self, r): + return True + + root = WSRoot() + root.addprotocol(wsme.rest.protocol.RestProtocol()) + root.addprotocol(P()) + + req = Request.blank('/test?check=a&check=b&name=Bob') + req.headers['Content-Type'] = 'test/fake' + req.method = 'POST' + p = root._select_protocol(req) + assert p.name == "test"