Multiple protocol accept or content-type matching
The changes in 8710dabb652dae775dee31789e91608f832e62e6 broke protocol failover when the REST protocol is listed before others (see bug referenced below). This patch tries to solve both issues by trying to match accept over all the protocols, only giving up a 406 or 415 if all protocols fail, using the last failure as the error message. Related-Bug: #1419110 Closes-Bug: #1442710 Change-Id: I328a392151013c46207519c245213d5dec48ecc9
This commit is contained in:
parent
f66cf4c3cc
commit
e31045e57a
@ -117,7 +117,7 @@ def getprotocol(name, **options):
|
||||
|
||||
|
||||
def media_type_accept(request, content_types):
|
||||
"""Return True if the requested media type is available.
|
||||
"""Validate media types against request.method.
|
||||
|
||||
When request.method is GET or HEAD compare with the Accept header.
|
||||
When request.method is POST, PUT or PATCH compare with the Content-Type
|
||||
@ -131,7 +131,6 @@ def media_type_accept(request, content_types):
|
||||
error_message = ('Unacceptable Accept type: %s not in %s'
|
||||
% (request.accept, content_types))
|
||||
raise ClientSideError(error_message, status_code=406)
|
||||
return False
|
||||
elif request.method in ['PUT', 'POST', 'PATCH']:
|
||||
content_type = request.headers.get('Content-Type')
|
||||
if content_type:
|
||||
|
25
wsme/root.py
25
wsme/root.py
@ -149,6 +149,7 @@ class WSRoot(object):
|
||||
request.body[:512] or
|
||||
request.body) or '')
|
||||
protocol = None
|
||||
error = ClientSideError(status_code=406)
|
||||
path = str(request.path)
|
||||
assert path.startswith(self._webpath)
|
||||
path = path[len(self._webpath) + 1:]
|
||||
@ -157,9 +158,16 @@ class WSRoot(object):
|
||||
else:
|
||||
|
||||
for p in self.protocols:
|
||||
if p.accept(request):
|
||||
protocol = p
|
||||
break
|
||||
try:
|
||||
if p.accept(request):
|
||||
protocol = p
|
||||
break
|
||||
except ClientSideError as e:
|
||||
error = e
|
||||
# If we could not select a protocol, we raise the last exception
|
||||
# that we got, or the default one.
|
||||
if not protocol:
|
||||
raise error
|
||||
return protocol
|
||||
|
||||
def _do_call(self, protocol, context):
|
||||
@ -232,11 +240,6 @@ class WSRoot(object):
|
||||
msg = None
|
||||
error_status = 500
|
||||
protocol = self._select_protocol(request)
|
||||
if protocol is None:
|
||||
if request.method in ['GET', 'HEAD']:
|
||||
error_status = 406
|
||||
elif request.method in ['POST', 'PUT', 'PATCH']:
|
||||
error_status = 415
|
||||
except ClientSideError as e:
|
||||
error_status = e.code
|
||||
msg = e.faultstring
|
||||
@ -248,7 +251,7 @@ class WSRoot(object):
|
||||
error_status = 500
|
||||
|
||||
if protocol is None:
|
||||
if msg is None:
|
||||
if not msg:
|
||||
msg = ("None of the following protocols can handle this "
|
||||
"request : %s" % ','.join((
|
||||
p.name for p in self.protocols)))
|
||||
@ -296,6 +299,10 @@ class WSRoot(object):
|
||||
else:
|
||||
res.status = protocol.get_response_status(request)
|
||||
res_content_type = protocol.get_response_contenttype(request)
|
||||
except ClientSideError as e:
|
||||
request.server_errorcount += 1
|
||||
res.status = e.code
|
||||
res.text = e.faultstring
|
||||
except Exception:
|
||||
infos = wsme.api.format_exception(sys.exc_info(), self._debug)
|
||||
request.server_errorcount += 1
|
||||
|
@ -201,6 +201,7 @@ Value should be one of:"))
|
||||
app = webtest.TestApp(r.wsgiapp())
|
||||
|
||||
res = app.get('/', expect_errors=True)
|
||||
print(res.status_int)
|
||||
assert res.status_int == 406
|
||||
print(res.body)
|
||||
assert res.body.find(
|
||||
|
@ -3,9 +3,12 @@
|
||||
import unittest
|
||||
|
||||
from wsme import WSRoot
|
||||
import wsme.protocol
|
||||
import wsme.rest.protocol
|
||||
from wsme.root import default_prepare_response_body
|
||||
|
||||
from six import b, u
|
||||
from webob import Request
|
||||
|
||||
|
||||
class TestRoot(unittest.TestCase):
|
||||
@ -24,9 +27,9 @@ class TestRoot(unittest.TestCase):
|
||||
default_prepare_response_body(None, [u('a'), u('b')]) == u('a\nb')
|
||||
|
||||
def test_protocol_selection_error(self):
|
||||
import wsme.protocol
|
||||
|
||||
class P(wsme.protocol.Protocol):
|
||||
name = "test"
|
||||
|
||||
def accept(self, r):
|
||||
raise Exception('test')
|
||||
|
||||
@ -40,3 +43,80 @@ class TestRoot(unittest.TestCase):
|
||||
assert res.content_type == 'text/plain'
|
||||
assert (res.text ==
|
||||
'Unexpected error while selecting protocol: test'), req.text
|
||||
|
||||
def test_protocol_selection_accept_mismatch(self):
|
||||
"""Verify that we get a 406 error on wrong Accept header."""
|
||||
class P(wsme.protocol.Protocol):
|
||||
name = "test"
|
||||
|
||||
def accept(self, r):
|
||||
return False
|
||||
|
||||
root = WSRoot()
|
||||
root.addprotocol(wsme.rest.protocol.RestProtocol())
|
||||
root.addprotocol(P())
|
||||
|
||||
req = Request.blank('/test?check=a&check=b&name=Bob')
|
||||
req.method = 'GET'
|
||||
res = root._handle_request(req)
|
||||
assert res.status_int == 406
|
||||
assert res.content_type == 'text/plain'
|
||||
assert res.text.startswith(
|
||||
'None of the following protocols can handle this request'
|
||||
), req.text
|
||||
|
||||
def test_protocol_selection_content_type_mismatch(self):
|
||||
"""Verify that we get a 415 error on wrong Content-Type header."""
|
||||
class P(wsme.protocol.Protocol):
|
||||
name = "test"
|
||||
|
||||
def accept(self, r):
|
||||
return False
|
||||
|
||||
root = WSRoot()
|
||||
root.addprotocol(wsme.rest.protocol.RestProtocol())
|
||||
root.addprotocol(P())
|
||||
|
||||
req = Request.blank('/test?check=a&check=b&name=Bob')
|
||||
req.method = 'POST'
|
||||
req.headers['Content-Type'] = "test/unsupported"
|
||||
res = root._handle_request(req)
|
||||
assert res.status_int == 415
|
||||
assert res.content_type == 'text/plain'
|
||||
assert res.text.startswith(
|
||||
'Unacceptable Content-Type: test/unsupported not in'
|
||||
), req.text
|
||||
|
||||
def test_protocol_selection_get_method(self):
|
||||
class P(wsme.protocol.Protocol):
|
||||
name = "test"
|
||||
|
||||
def accept(self, r):
|
||||
return True
|
||||
|
||||
root = WSRoot()
|
||||
root.addprotocol(wsme.rest.protocol.RestProtocol())
|
||||
root.addprotocol(P())
|
||||
|
||||
req = Request.blank('/test?check=a&check=b&name=Bob')
|
||||
req.method = 'GET'
|
||||
req.headers['Accept'] = 'test/fake'
|
||||
p = root._select_protocol(req)
|
||||
assert p.name == "test"
|
||||
|
||||
def test_protocol_selection_post_method(self):
|
||||
class P(wsme.protocol.Protocol):
|
||||
name = "test"
|
||||
|
||||
def accept(self, r):
|
||||
return True
|
||||
|
||||
root = WSRoot()
|
||||
root.addprotocol(wsme.rest.protocol.RestProtocol())
|
||||
root.addprotocol(P())
|
||||
|
||||
req = Request.blank('/test?check=a&check=b&name=Bob')
|
||||
req.headers['Content-Type'] = 'test/fake'
|
||||
req.method = 'POST'
|
||||
p = root._select_protocol(req)
|
||||
assert p.name == "test"
|
||||
|
Loading…
x
Reference in New Issue
Block a user