diff --git a/wsme/rest/__init__.py b/wsme/rest/__init__.py index 0968770..3153a49 100644 --- a/wsme/rest/__init__.py +++ b/wsme/rest/__init__.py @@ -56,22 +56,23 @@ class validate(object): return func -def scan_api(controller, path=[]): +def scan_api(controller, path=[], objects=[]): """ - Recursively iterate a controller api entries, while setting - their :attr:`FunctionDefinition.path`. + Recursively iterate a controller api entries. """ for name in dir(controller): if name.startswith('_'): continue a = getattr(controller, name) + if a in objects: + continue if inspect.ismethod(a): if wsme.api.iswsmefunction(a): - yield path + [name], a._wsme_definition + yield path + [name], a, [] elif inspect.isclass(a): continue else: if len(path) > APIPATH_MAXLEN: raise ValueError("Path is too long: " + str(path)) - for i in scan_api(a, path + [name]): + for i in scan_api(a, path + [name], objects + [a]): yield i diff --git a/wsme/root.py b/wsme/root.py index e6b4e96..e19d7c0 100644 --- a/wsme/root.py +++ b/wsme/root.py @@ -125,10 +125,16 @@ class WSRoot(object): :rtype: list of (path, :class:`FunctionDefinition`) """ if self._api is None: - self._api = list(self._scan_api(self)) - for path, fdef in self._api: + self._api = [ + (path, f, f._wsme_definition, args) + for path, f, args in self._scan_api(self) + ] + for path, f, fdef, args in self._api: fdef.resolve_types(self.__registry__) - return self._api + return [ + (path, fdef) + for path, f, fdef, args in self._api + ] def _get_protocol(self, name): for protocol in self.protocols: @@ -171,8 +177,10 @@ class WSRoot(object): 'The %s protocol was unable to extract a function ' 'path from the request') % protocol.name) - context.func, context.funcdef = self._lookup_function(context.path) + context.func, context.funcdef, args = \ + self._lookup_function(context.path) kw = protocol.read_arguments(context) + args = list(args) for arg in context.funcdef.arguments: if arg.mandatory and arg.name not in kw: @@ -180,7 +188,7 @@ class WSRoot(object): txn = self.begin() try: - result = context.func(**kw) + result = context.func(*args, **kw) txn.commit() except: txn.abort() @@ -309,25 +317,13 @@ class WSRoot(object): return res def _lookup_function(self, path): - a = self + if not self._api: + self.getapi() - isprotocol_specific = path[0] == '_protocol' - - if isprotocol_specific: - a = self._get_protocol(path[1]) - path = path[2:] - - for name in path: - a = getattr(a, name, None) - if a is None: - break - - if not hasattr(a, '_wsme_definition'): - raise UnknownFunction('/'.join(path)) - - definition = a._wsme_definition - - return a, definition + for fpath, f, fdef, args in self._api: + if path == fpath: + return f, fdef, args + raise UnknownFunction('/'.join(path)) def _format_exception(self, excinfo): """Extract informations that can be sent to the client.""" diff --git a/wsme/tests/test_api.py b/wsme/tests/test_api.py index 75a70a2..2231edf 100644 --- a/wsme/tests/test_api.py +++ b/wsme/tests/test_api.py @@ -74,9 +74,10 @@ class TestController(unittest.TestCase): api = list(scan_api(r)) assert len(api) == 1 - path, fd = api[0] + path, fd, args = api[0] assert path == ['ns', 'multiply'] - assert fd.name == 'multiply' + assert fd._wsme_definition.name == 'multiply' + assert args == [] def test_scan_subclass(self): class MyRoot(WSRoot): @@ -90,11 +91,16 @@ class TestController(unittest.TestCase): def test_scan_api_too_deep(self): class Loop(object): - loop = None - Loop.me = Loop() + pass + + l = Loop() + for i in range(0, 21): + nl = Loop() + nl.l = l + l = nl class MyRoot(WSRoot): - loop = Loop() + loop = l r = MyRoot() @@ -179,14 +185,6 @@ class TestController(unittest.TestCase): except ValueError: pass - def test_getapi(self): - class MyRoot(WSRoot): - pass - - r = MyRoot() - api = r.getapi() - assert r.getapi() is api - class TestFunctionDefinition(unittest.TestCase):