diff --git a/setup.cfg b/setup.cfg index 93e1a3e..a1a15b8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -30,8 +30,9 @@ classifier = [entry_points] wsme.protocols = - restjson = wsme.rest.json:RestJsonProtocol - restxml = wsme.rest.xml:RestXmlProtocol + rest = wsme.rest.protocol:RestProtocol + restjson = wsme.rest.protocol:RestProtocol + restxml = wsme.rest.protocol:RestProtocol [files] packages = diff --git a/wsme/protocol.py b/wsme/protocol.py index 6494224..d80ae1b 100644 --- a/wsme/protocol.py +++ b/wsme/protocol.py @@ -55,7 +55,6 @@ class ObjectDict(object): class Protocol(object): name = None displayname = None - dataformat = None content_types = [] def resolve_path(self, path): @@ -73,8 +72,6 @@ class Protocol(object): yield self.resolve_path(path), attr def accept(self, request): - if request.path.endswith('.' + self.dataformat): - return True return request.headers.get('Content-Type') in self.content_types def iter_calls(self, request): diff --git a/wsme/rest/args.py b/wsme/rest/args.py index c7ac592..5bd337c 100644 --- a/wsme/rest/args.py +++ b/wsme/rest/args.py @@ -144,39 +144,32 @@ def args_from_body(funcdef, body, mimetype): from wsme.rest import json as restjson from wsme.rest import xml as restxml - kw = {} - if funcdef.body_type is not None: - bodydata = None - if mimetype in restjson.RestJsonProtocol.content_types: - if hasattr(body, 'read'): - jsonbody = restjson.json.load(body) - else: - jsonbody = restjson.json.loads(body) - bodydata = restjson.fromjson(funcdef.body_type, jsonbody) - elif mimetype in restxml.RestXmlProtocol.content_types: - if hasattr(body, 'read'): - xmlbody = restxml.et.parse(body) - else: - xmlbody = restxml.et.fromstring(body) - bodydata = restxml.fromxml(funcdef.body_type, xmlbody) - if bodydata: - kw[funcdef.arguments[-1].name] = bodydata + datatypes = {funcdef.arguments[-1].name: funcdef.body_type} + else: + datatypes = dict(((a.name, a.datatype) for a in funcdef.arguments)) + + if mimetype in restjson.accept_content_types: + dataformat = restjson + elif mimetype in restxml.accept_content_types: + dataformat = restxml + else: + raise ValueError("Unknow mimetype: %s" % mimetype) + + kw = dataformat.parse( + body, datatypes, bodyarg=funcdef.body_type is not None + ) return (), kw def combine_args(funcdef, *akw): newargs, newkwargs = [], {} - argindexes = {} - for i, arg in enumerate(funcdef.arguments): - argindexes[arg.name] = i - newargs.append(arg.default) for args, kwargs in akw: for i, arg in enumerate(args): - newargs[i] = arg + newkwargs[funcdef.arguments[i].name] = arg for name, value in kwargs.iteritems(): - newargs[argindexes[name]] = value + newkwargs[name] = value return newargs, newkwargs diff --git a/wsme/rest/json.py b/wsme/rest/json.py index 7fa0654..898c5a0 100644 --- a/wsme/rest/json.py +++ b/wsme/rest/json.py @@ -9,7 +9,6 @@ import six from simplegeneric import generic -from wsme.rest.protocol import RestProtocol from wsme.types import Unset import wsme.types @@ -19,6 +18,14 @@ except ImportError: import json # noqa +content_type = 'application/json' +accept_content_types = [ + content_type, + 'text/javascript', + 'application/javascript' +] + + @generic def tojson(datatype, value): """ @@ -184,7 +191,7 @@ def datetime_fromjson(datatype, value): return datetime.datetime.strptime(value, '%Y-%m-%dT%H:%M:%S') -class RestJsonProtocol(RestProtocol): +class RestJson(object): """ REST+Json protocol. @@ -193,18 +200,12 @@ class RestJsonProtocol(RestProtocol): .. autoattribute:: content_types """ - name = 'restjson' - displayname = 'REST+Json' - dataformat = 'json' - content_types = [ - 'application/json', - 'application/javascript', - 'text/javascript', - ''] + name = 'json' + content_type = 'application/json' - def __init__(self, nest_result=False): - super(RestJsonProtocol, self).__init__() - self.nest_result = nest_result + #def __init__(self, nest_result=False): + # super(RestJsonProtocol, self).__init__() + # self.nest_result = nest_result def decode_arg(self, value, arg): return fromjson(arg.datatype, value) @@ -222,9 +223,6 @@ class RestJsonProtocol(RestProtocol): r = {'result': r} return json.dumps(r) - def encode_error(self, context, errordetail): - return json.dumps(errordetail) - def encode_sample_value(self, datatype, value, format=False): r = tojson(datatype, value) content = json.dumps(r, ensure_ascii=False, @@ -249,3 +247,34 @@ class RestJsonProtocol(RestProtocol): indent=4 if format else 0, sort_keys=format) return ('javascript', content) + + +def get_format(): + return RestJson() + + +def parse(s, datatypes, bodyarg): + if hasattr(s, 'read'): + jdata = json.load(s) + else: + jdata = json.loads(s) + if bodyarg: + argname = list(datatypes.keys())[0] + kw = {argname: fromjson(datatypes[argname], jdata)} + else: + kw = {} + for key, datatype in datatypes.items(): + if key in jdata: + kw[key] = fromjson(datatype, jdata[key]) + return kw + + +def tostring(value, datatype, attrname=None): + jsondata = tojson(datatype, value) + if attrname is not None: + jsondata = {attrname: jsondata} + return json.dumps(tojson(datatype, value)) + + +def encode_error(context, errordetail): + return json.dumps(errordetail) diff --git a/wsme/rest/protocol.py b/wsme/rest/protocol.py index bdcd025..2fe6c39 100644 --- a/wsme/rest/protocol.py +++ b/wsme/rest/protocol.py @@ -1,19 +1,66 @@ +import collections +import os.path import logging import six -from six import u - -from wsme.exc import ClientSideError, UnknownArgument +from wsme.exc import ClientSideError, UnknownArgument, MissingArgument from wsme.protocol import CallContext, Protocol -from wsme.rest.args import from_params -from wsme.types import Unset + +import wsme.rest +import wsme.rest.args log = logging.getLogger(__name__) class RestProtocol(Protocol): + name = 'rest' + displayname = 'REST' + dataformats = ['json', 'xml'] + content_types = ['application/json', 'text/xml'] + + def __init__(self, dataformats=None): + if dataformats is None: + dataformats = RestProtocol.dataformats + + self.dataformats = collections.OrderedDict() + self.content_types = [] + + for dataformat in dataformats: + __import__('wsme.rest.' + dataformat) + dfmod = getattr(wsme.rest, dataformat) + self.dataformats[dataformat] = dfmod + self.content_types.extend(dfmod.accept_content_types) + + def accept(self, request): + for dataformat in self.dataformats: + if request.path.endswith('.' + dataformat): + return True + return request.headers.get('Content-Type') in self.content_types + def iter_calls(self, request): - yield CallContext(request) + context = CallContext(request) + context.outformat = None + ext = os.path.splitext(request.path.split('/')[-1])[1] + inmime = request.content_type + outmime = request.accept.best_match(self.content_types) + + outformat = None + for dfname, df in self.dataformats.items(): + if ext == '.' + dfname: + outformat = df + + if outformat is None and request.accept: + for dfname, df in self.dataformats.items(): + if outmime in df.accept_content_types: + outformat = df + + if outformat is None: + for dfname, df in self.dataformats.items(): + if inmime == df.content_type: + outformat = df + + context.outformat = outformat + yield context def extract_path(self, context): path = context.request.path @@ -21,8 +68,9 @@ class RestProtocol(Protocol): path = path[len(self.root._webpath):] path = path.strip('/').split('/') - if path[-1].endswith('.' + self.dataformat): - path[-1] = path[-1][:-len(self.dataformat) - 1] + for dataformat in self.dataformats: + if path[-1].endswith('.' + dataformat): + path[-1] = path[-1][:-len(dataformat) - 1] # Check if the path is actually a function, and if not # see if the http method make a difference @@ -60,42 +108,55 @@ class RestProtocol(Protocol): raise ClientSideError( "Cannot read parameters from both a body and GET/POST params") + param_args = (), {} + body = None + if 'body' in request.params: body = request.params['body'] + body_mimetype = context.outformat.content_type + if body is None: + body = request.body + body_mimetype = request.content_type + param_args = wsme.rest.args.args_from_params( + funcdef, request.params + ) + if isinstance(body, six.binary_type): + body = body.decode('utf8') - if body is None and len(request.params): - kw = {} - hit_paths = set() - for argdef in funcdef.arguments: - value = from_params( - argdef.datatype, request.params, argdef.name, hit_paths) - if value is not Unset: - kw[argdef.name] = value - paths = set(request.params.keys()) - unknown_paths = paths - hit_paths - if unknown_paths: - raise UnknownArgument(', '.join(unknown_paths)) - return kw + if body and body_mimetype in self.content_types: + body_args = wsme.rest.args.args_from_body( + funcdef, body, body_mimetype + ) else: - if body is None: - body = request.body - if isinstance(body, six.binary_type): - body = body.decode('utf8') - if body: - parsed_args = self.parse_args(body) - else: - parsed_args = {} + body_args = ((), {}) - kw = {} + args, kw = wsme.rest.args.combine_args( + funcdef, + param_args, + body_args + ) - for arg in funcdef.arguments: - if arg.name not in parsed_args: - continue + for a in funcdef.arguments: + if a.mandatory and a.name not in kw: + raise MissingArgument(a.name) - value = parsed_args.pop(arg.name) - kw[arg.name] = self.decode_arg(value, arg) + argnames = set((a.name for a in funcdef.arguments)) + + for k in kw: + if k not in argnames: + raise UnknownArgument(k) - if parsed_args: - raise UnknownArgument(u(', ').join(parsed_args.keys())) return kw + + def encode_result(self, context, result): + out = context.outformat.tostring( + result, context.funcdef.return_type + ) + return out + + def encode_error(self, context, errordetail): + out = context.outformat.encode_error( + context, errordetail + ) + return out diff --git a/wsme/rest/xml.py b/wsme/rest/xml.py index c78eedf..00aa6f2 100644 --- a/wsme/rest/xml.py +++ b/wsme/rest/xml.py @@ -11,9 +11,15 @@ from simplegeneric import generic from wsme.rest.protocol import RestProtocol import wsme.types +from wsme.exc import UnknownArgument import re +content_type = 'text/xml' +accept_content_types = [ + content_type, +] + time_re = re.compile(r'(?P[0-2][0-9]):(?P[0-5][0-9]):(?P[0-6][0-9])') @@ -244,6 +250,11 @@ class RestXmlProtocol(RestProtocol): return et.tostring( toxml(context.funcdef.return_type, 'result', result)) + +class RestXml(object): + name = 'xml' + content_type = 'text/xml' + def encode_error(self, context, errordetail): el = et.Element('error') et.SubElement(el, 'faultcode').text = errordetail['faultcode'] @@ -274,3 +285,37 @@ class RestXmlProtocol(RestProtocol): xml_indent(r) content = et.tostring(r) return ('xml', content) + + +def get_format(): + return RestXml() + + +def parse(s, datatypes, bodyarg): + if hasattr(s, 'read'): + tree = et.parse(s) + else: + tree = et.fromstring(s) + if bodyarg: + name = list(datatypes.keys())[0] + return fromxml(datatypes[name], tree) + else: + kw = {} + for sub in tree: + if sub.tag not in datatypes: + raise UnknownArgument(sub.tag) + kw[sub.tag] = fromxml(datatypes[sub.tag], sub) + return kw + + +def tostring(value, datatype, attrname='result'): + return et.tostring(toxml(datatype, attrname, value)) + + +def encode_error(context, errordetail): + el = et.Element('error') + et.SubElement(el, 'faultcode').text = errordetail['faultcode'] + et.SubElement(el, 'faultstring').text = errordetail['faultstring'] + if 'debuginfo' in errordetail: + et.SubElement(el, 'debuginfo').text = errordetail['debuginfo'] + return et.tostring(el) diff --git a/wsme/spore.py b/wsme/spore.py index 9c4d1eb..8de0ed2 100644 --- a/wsme/spore.py +++ b/wsme/spore.py @@ -3,7 +3,7 @@ from wsme import types try: import simplejson as json except ImportError: - import json + import json # noqa def getdesc(root, host_url=''):