diff --git a/wsme/protocol.py b/wsme/protocol.py index ec40093..b0107ab 100644 --- a/wsme/protocol.py +++ b/wsme/protocol.py @@ -117,7 +117,7 @@ def getprotocol(name, **options): def media_type_accept(request, content_types): - """Return True if the requested media type is available. + """Validate media types against request.method. When request.method is GET or HEAD compare with the Accept header. When request.method is POST, PUT or PATCH compare with the Content-Type @@ -131,7 +131,6 @@ def media_type_accept(request, content_types): error_message = ('Unacceptable Accept type: %s not in %s' % (request.accept, content_types)) raise ClientSideError(error_message, status_code=406) - return False elif request.method in ['PUT', 'POST', 'PATCH']: content_type = request.headers.get('Content-Type') if content_type: diff --git a/wsme/root.py b/wsme/root.py index 4b56c4f..1ccfbdf 100644 --- a/wsme/root.py +++ b/wsme/root.py @@ -149,6 +149,7 @@ class WSRoot(object): request.body[:512] or request.body) or '') protocol = None + error = ClientSideError(status_code=406) path = str(request.path) assert path.startswith(self._webpath) path = path[len(self._webpath) + 1:] @@ -157,9 +158,16 @@ class WSRoot(object): else: for p in self.protocols: - if p.accept(request): - protocol = p - break + try: + if p.accept(request): + protocol = p + break + except ClientSideError as e: + error = e + # If we could not select a protocol, we raise the last exception + # that we got, or the default one. + if not protocol: + raise error return protocol def _do_call(self, protocol, context): @@ -232,11 +240,6 @@ class WSRoot(object): msg = None error_status = 500 protocol = self._select_protocol(request) - if protocol is None: - if request.method in ['GET', 'HEAD']: - error_status = 406 - elif request.method in ['POST', 'PUT', 'PATCH']: - error_status = 415 except ClientSideError as e: error_status = e.code msg = e.faultstring @@ -248,7 +251,7 @@ class WSRoot(object): error_status = 500 if protocol is None: - if msg is None: + if not msg: msg = ("None of the following protocols can handle this " "request : %s" % ','.join(( p.name for p in self.protocols))) @@ -296,6 +299,10 @@ class WSRoot(object): else: res.status = protocol.get_response_status(request) res_content_type = protocol.get_response_contenttype(request) + except ClientSideError as e: + request.server_errorcount += 1 + res.status = e.code + res.text = e.faultstring except Exception: infos = wsme.api.format_exception(sys.exc_info(), self._debug) request.server_errorcount += 1 diff --git a/wsme/tests/test_api.py b/wsme/tests/test_api.py index 9d02807..b99b20a 100644 --- a/wsme/tests/test_api.py +++ b/wsme/tests/test_api.py @@ -201,6 +201,7 @@ Value should be one of:")) app = webtest.TestApp(r.wsgiapp()) res = app.get('/', expect_errors=True) + print(res.status_int) assert res.status_int == 406 print(res.body) assert res.body.find( diff --git a/wsme/tests/test_root.py b/wsme/tests/test_root.py index 4c68375..8ccb0c1 100644 --- a/wsme/tests/test_root.py +++ b/wsme/tests/test_root.py @@ -3,9 +3,12 @@ import unittest from wsme import WSRoot +import wsme.protocol +import wsme.rest.protocol from wsme.root import default_prepare_response_body from six import b, u +from webob import Request class TestRoot(unittest.TestCase): @@ -24,9 +27,9 @@ class TestRoot(unittest.TestCase): default_prepare_response_body(None, [u('a'), u('b')]) == u('a\nb') def test_protocol_selection_error(self): - import wsme.protocol - class P(wsme.protocol.Protocol): + name = "test" + def accept(self, r): raise Exception('test') @@ -40,3 +43,80 @@ class TestRoot(unittest.TestCase): assert res.content_type == 'text/plain' assert (res.text == 'Unexpected error while selecting protocol: test'), req.text + + def test_protocol_selection_accept_mismatch(self): + """Verify that we get a 406 error on wrong Accept header.""" + class P(wsme.protocol.Protocol): + name = "test" + + def accept(self, r): + return False + + root = WSRoot() + root.addprotocol(wsme.rest.protocol.RestProtocol()) + root.addprotocol(P()) + + req = Request.blank('/test?check=a&check=b&name=Bob') + req.method = 'GET' + res = root._handle_request(req) + assert res.status_int == 406 + assert res.content_type == 'text/plain' + assert res.text.startswith( + 'None of the following protocols can handle this request' + ), req.text + + def test_protocol_selection_content_type_mismatch(self): + """Verify that we get a 415 error on wrong Content-Type header.""" + class P(wsme.protocol.Protocol): + name = "test" + + def accept(self, r): + return False + + root = WSRoot() + root.addprotocol(wsme.rest.protocol.RestProtocol()) + root.addprotocol(P()) + + req = Request.blank('/test?check=a&check=b&name=Bob') + req.method = 'POST' + req.headers['Content-Type'] = "test/unsupported" + res = root._handle_request(req) + assert res.status_int == 415 + assert res.content_type == 'text/plain' + assert res.text.startswith( + 'Unacceptable Content-Type: test/unsupported not in' + ), req.text + + def test_protocol_selection_get_method(self): + class P(wsme.protocol.Protocol): + name = "test" + + def accept(self, r): + return True + + root = WSRoot() + root.addprotocol(wsme.rest.protocol.RestProtocol()) + root.addprotocol(P()) + + req = Request.blank('/test?check=a&check=b&name=Bob') + req.method = 'GET' + req.headers['Accept'] = 'test/fake' + p = root._select_protocol(req) + assert p.name == "test" + + def test_protocol_selection_post_method(self): + class P(wsme.protocol.Protocol): + name = "test" + + def accept(self, r): + return True + + root = WSRoot() + root.addprotocol(wsme.rest.protocol.RestProtocol()) + root.addprotocol(P()) + + req = Request.blank('/test?check=a&check=b&name=Bob') + req.headers['Content-Type'] = 'test/fake' + req.method = 'POST' + p = root._select_protocol(req) + assert p.name == "test"