From 8710dabb652dae775dee31789e91608f832e62e6 Mon Sep 17 00:00:00 2001 From: Chris Dent Date: Mon, 9 Feb 2015 14:52:07 +0000 Subject: [PATCH] Improve Accept and Content-Type handling Originally, if WSME received an Accept or Content-Type header that was not aligned with what it was prepared to handle it would error out with a 500 status code. This is not good behavior for a web service. In the process of trying to fix this it was discovered that the content-negotiation code within WSME (the code that, in part, looks for a suitable protocol handler for a request) and tests of that code are incorrect, violating expected HTTP behaviors. GET requests are passing Content-Type headers to declare the desired type of representation in the response. This is what Accept is for. Unfortunately the server-side code was perfectly willing to accept this behavior. These changes correct that. Closes-Bug: 1419110 Change-Id: I2b5c0075611490c047b27b1b43b0505fc5534b3b --- wsme/protocol.py | 35 ++++++++++ wsme/rest/protocol.py | 9 +-- wsme/root.py | 20 +++++- wsme/tests/test_api.py | 2 +- wsme/tests/test_restjson.py | 71 +++++++++++++++++++- wsme/tests/test_root.py | 3 +- wsmeext/tests/test_sqlalchemy_controllers.py | 16 ++++- 7 files changed, 140 insertions(+), 16 deletions(-) diff --git a/wsme/protocol.py b/wsme/protocol.py index d80ae1b..ec40093 100644 --- a/wsme/protocol.py +++ b/wsme/protocol.py @@ -2,6 +2,9 @@ import weakref import pkg_resources +from wsme.exc import ClientSideError + + __all__ = [ 'CallContext', @@ -111,3 +114,35 @@ def getprotocol(name, **options): raise ValueError("Cannot find protocol '%s'" % name) registered_protocols[name] = protocol_class return protocol_class(**options) + + +def media_type_accept(request, content_types): + """Return True if the requested media type is available. + + When request.method is GET or HEAD compare with the Accept header. + When request.method is POST, PUT or PATCH compare with the Content-Type + header. + When request.method is DELETE media type is irrelevant, so return True. + """ + if request.method in ['GET', 'HEAD']: + if request.accept: + if request.accept.best_match(content_types): + return True + error_message = ('Unacceptable Accept type: %s not in %s' + % (request.accept, content_types)) + raise ClientSideError(error_message, status_code=406) + return False + elif request.method in ['PUT', 'POST', 'PATCH']: + content_type = request.headers.get('Content-Type') + if content_type: + for ct in content_types: + if request.headers.get('Content-Type', '').startswith(ct): + return True + error_message = ('Unacceptable Content-Type: %s not in %s' + % (content_type, content_types)) + raise ClientSideError(error_message, status_code=415) + else: + raise ClientSideError('missing Content-Type header') + elif request.method in ['DELETE']: + return True + return False diff --git a/wsme/rest/protocol.py b/wsme/rest/protocol.py index 85c76fd..5201ccf 100644 --- a/wsme/rest/protocol.py +++ b/wsme/rest/protocol.py @@ -2,7 +2,7 @@ import os.path import logging from wsme.utils import OrderedDict -from wsme.protocol import CallContext, Protocol +from wsme.protocol import CallContext, Protocol, media_type_accept import wsme.rest import wsme.rest.args @@ -34,12 +34,7 @@ class RestProtocol(Protocol): for dataformat in self.dataformats: if request.path.endswith('.' + dataformat): return True - if request.headers.get('Accept') in self.content_types: - return True - for ct in self.content_types: - if request.headers['Content-Type'].startswith(ct): - return True - return False + return media_type_accept(request, self.content_types) def iter_calls(self, request): context = CallContext(request) diff --git a/wsme/root.py b/wsme/root.py index fcc2442..4b56c4f 100644 --- a/wsme/root.py +++ b/wsme/root.py @@ -230,20 +230,34 @@ class WSRoot(object): try: msg = None + error_status = 500 protocol = self._select_protocol(request) + if protocol is None: + if request.method in ['GET', 'HEAD']: + error_status = 406 + elif request.method in ['POST', 'PUT', 'PATCH']: + error_status = 415 + except ClientSideError as e: + error_status = e.code + msg = e.faultstring + protocol = None except Exception as e: - msg = ("Error while selecting protocol: %s" % str(e)) + msg = ("Unexpected error while selecting protocol: %s" % str(e)) log.exception(msg) protocol = None + error_status = 500 if protocol is None: if msg is None: msg = ("None of the following protocols can handle this " "request : %s" % ','.join(( p.name for p in self.protocols))) - res.status = 500 + res.status = error_status res.content_type = 'text/plain' - res.text = u(msg) + try: + res.text = u(msg) + except TypeError: + res.text = msg log.error(msg) return res diff --git a/wsme/tests/test_api.py b/wsme/tests/test_api.py index 9149cf9..907ce7b 100644 --- a/wsme/tests/test_api.py +++ b/wsme/tests/test_api.py @@ -200,7 +200,7 @@ Value should be one of:")) app = webtest.TestApp(r.wsgiapp()) res = app.get('/', expect_errors=True) - assert res.status_int == 500 + assert res.status_int == 406 print(res.body) assert res.body.find( b("None of the following protocols can handle this request")) != -1 diff --git a/wsme/tests/test_restjson.py b/wsme/tests/test_restjson.py index 573661e..0ee86e2 100644 --- a/wsme/tests/test_restjson.py +++ b/wsme/tests/test_restjson.py @@ -351,7 +351,7 @@ class TestRestJson(wsme.tests.protocol.RestOnlyProtocolTestCase): def test_GET(self): headers = { - 'Content-Type': 'application/json', + 'Accept': 'application/json', } res = self.app.get( '/crud?ref.id=1', @@ -362,7 +362,58 @@ class TestRestJson(wsme.tests.protocol.RestOnlyProtocolTestCase): print(result) assert result['data']['id'] == 1 assert result['data']['name'] == u("test") - assert result['message'] == "read" + + def test_GET_complex_accept(self): + headers = { + 'Accept': 'text/html,application/xml;q=0.9,*/*;q=0.8' + } + res = self.app.get( + '/crud?ref.id=1', + headers=headers, + expect_errors=False) + print("Received:", res.body) + result = json.loads(res.text) + print(result) + assert result['data']['id'] == 1 + assert result['data']['name'] == u("test") + + def test_GET_complex_choose_xml(self): + headers = { + 'Accept': 'text/html,text/xml;q=0.9,*/*;q=0.8' + } + res = self.app.get( + '/crud?ref.id=1', + headers=headers, + expect_errors=False) + print("Received:", res.body) + assert res.content_type == 'text/xml' + + def test_GET_complex_accept_no_match(self): + headers = { + 'Accept': 'text/html,application/xml;q=0.9' + } + res = self.app.get( + '/crud?ref.id=1', + headers=headers, + status=406) + print("Received:", res.body) + assert res.body == ("Unacceptable Accept type: " + "text/html, application/xml;q=0.9 not in " + "['application/json', 'text/javascript', " + "'application/javascript', 'text/xml']") + + def test_GET_bad_simple_accept(self): + headers = { + 'Accept': 'text/plain', + } + res = self.app.get( + '/crud?ref.id=1', + headers=headers, + status=406) + print("Received:", res.body) + assert res.body == ("Unacceptable Accept type: text/plain not in " + "['application/json', 'text/javascript', " + "'application/javascript', 'text/xml']") def test_POST(self): headers = { @@ -380,6 +431,20 @@ class TestRestJson(wsme.tests.protocol.RestOnlyProtocolTestCase): assert result['data']['name'] == u("test") assert result['message'] == "update" + def test_POST_bad_content_type(self): + headers = { + 'Content-Type': 'text/plain', + } + res = self.app.post( + '/crud', + json.dumps(dict(data=dict(id=1, name=u('test')))), + headers=headers, + status=415) + print("Received:", res.body) + assert res.body == ("Unacceptable Content-Type: text/plain not in " + "['application/json', 'text/javascript', " + "'application/javascript', 'text/xml']") + def test_DELETE(self): res = self.app.delete( '/crud.json?ref.id=1', @@ -393,7 +458,7 @@ class TestRestJson(wsme.tests.protocol.RestOnlyProtocolTestCase): def test_extra_arguments(self): headers = { - 'Content-Type': 'application/json', + 'Accept': 'application/json', } res = self.app.get( '/crud?ref.id=1&extraarg=foo', diff --git a/wsme/tests/test_root.py b/wsme/tests/test_root.py index d0a3485..4c68375 100644 --- a/wsme/tests/test_root.py +++ b/wsme/tests/test_root.py @@ -38,4 +38,5 @@ class TestRoot(unittest.TestCase): res = root._handle_request(req) assert res.status_int == 500 assert res.content_type == 'text/plain' - assert res.text == u('Error while selecting protocol: test'), req.text + assert (res.text == + 'Unexpected error while selecting protocol: test'), req.text diff --git a/wsmeext/tests/test_sqlalchemy_controllers.py b/wsmeext/tests/test_sqlalchemy_controllers.py index 2bcf633..1956788 100644 --- a/wsmeext/tests/test_sqlalchemy_controllers.py +++ b/wsmeext/tests/test_sqlalchemy_controllers.py @@ -136,12 +136,26 @@ class TestCRUDController(): DBSession.flush() pid = p.id r = self.app.get('/person?ref.id=%s' % pid, - headers={'Content-Type': 'application/json'}) + headers={'Accept': 'application/json'}) r = json.loads(r.text) print(r) assert r['name'] == u('Pierre-Joseph') assert r['birthdate'] == u('1809-01-15') + def test_GET_bad_accept(self): + p = DBPerson( + name=u('Pierre-Joseph'), + birthdate=datetime.date(1809, 1, 15)) + DBSession.add(p) + DBSession.flush() + pid = p.id + r = self.app.get('/person?ref.id=%s' % pid, + headers={'Accept': 'text/plain'}, + status=406) + assert r.text == ("Unacceptable Accept type: text/plain not in " + "['application/json', 'text/javascript', " + "'application/javascript', 'text/xml']") + def test_update(self): p = DBPerson( name=u('Pierre-Joseph'),