Change the scan_api interface. It now yield the original function and static args (typically the 'self' attribute). Thanks to that the lookup_function method of WSRoot can access functions outside the WSRoot.

This commit is contained in:
Christophe de Vienne 2012-11-26 01:25:51 +01:00
parent 15417d2ac0
commit ade5325e13
3 changed files with 36 additions and 41 deletions

View File

@ -56,22 +56,23 @@ class validate(object):
return func return func
def scan_api(controller, path=[]): def scan_api(controller, path=[], objects=[]):
""" """
Recursively iterate a controller api entries, while setting Recursively iterate a controller api entries.
their :attr:`FunctionDefinition.path`.
""" """
for name in dir(controller): for name in dir(controller):
if name.startswith('_'): if name.startswith('_'):
continue continue
a = getattr(controller, name) a = getattr(controller, name)
if a in objects:
continue
if inspect.ismethod(a): if inspect.ismethod(a):
if wsme.api.iswsmefunction(a): if wsme.api.iswsmefunction(a):
yield path + [name], a._wsme_definition yield path + [name], a, []
elif inspect.isclass(a): elif inspect.isclass(a):
continue continue
else: else:
if len(path) > APIPATH_MAXLEN: if len(path) > APIPATH_MAXLEN:
raise ValueError("Path is too long: " + str(path)) raise ValueError("Path is too long: " + str(path))
for i in scan_api(a, path + [name]): for i in scan_api(a, path + [name], objects + [a]):
yield i yield i

View File

@ -125,10 +125,16 @@ class WSRoot(object):
:rtype: list of (path, :class:`FunctionDefinition`) :rtype: list of (path, :class:`FunctionDefinition`)
""" """
if self._api is None: if self._api is None:
self._api = list(self._scan_api(self)) self._api = [
for path, fdef in self._api: (path, f, f._wsme_definition, args)
for path, f, args in self._scan_api(self)
]
for path, f, fdef, args in self._api:
fdef.resolve_types(self.__registry__) fdef.resolve_types(self.__registry__)
return self._api return [
(path, fdef)
for path, f, fdef, args in self._api
]
def _get_protocol(self, name): def _get_protocol(self, name):
for protocol in self.protocols: for protocol in self.protocols:
@ -171,8 +177,10 @@ class WSRoot(object):
'The %s protocol was unable to extract a function ' 'The %s protocol was unable to extract a function '
'path from the request') % protocol.name) 'path from the request') % protocol.name)
context.func, context.funcdef = self._lookup_function(context.path) context.func, context.funcdef, args = \
self._lookup_function(context.path)
kw = protocol.read_arguments(context) kw = protocol.read_arguments(context)
args = list(args)
for arg in context.funcdef.arguments: for arg in context.funcdef.arguments:
if arg.mandatory and arg.name not in kw: if arg.mandatory and arg.name not in kw:
@ -180,7 +188,7 @@ class WSRoot(object):
txn = self.begin() txn = self.begin()
try: try:
result = context.func(**kw) result = context.func(*args, **kw)
txn.commit() txn.commit()
except: except:
txn.abort() txn.abort()
@ -309,25 +317,13 @@ class WSRoot(object):
return res return res
def _lookup_function(self, path): def _lookup_function(self, path):
a = self if not self._api:
self.getapi()
isprotocol_specific = path[0] == '_protocol' for fpath, f, fdef, args in self._api:
if path == fpath:
if isprotocol_specific: return f, fdef, args
a = self._get_protocol(path[1]) raise UnknownFunction('/'.join(path))
path = path[2:]
for name in path:
a = getattr(a, name, None)
if a is None:
break
if not hasattr(a, '_wsme_definition'):
raise UnknownFunction('/'.join(path))
definition = a._wsme_definition
return a, definition
def _format_exception(self, excinfo): def _format_exception(self, excinfo):
"""Extract informations that can be sent to the client.""" """Extract informations that can be sent to the client."""

View File

@ -74,9 +74,10 @@ class TestController(unittest.TestCase):
api = list(scan_api(r)) api = list(scan_api(r))
assert len(api) == 1 assert len(api) == 1
path, fd = api[0] path, fd, args = api[0]
assert path == ['ns', 'multiply'] assert path == ['ns', 'multiply']
assert fd.name == 'multiply' assert fd._wsme_definition.name == 'multiply'
assert args == []
def test_scan_subclass(self): def test_scan_subclass(self):
class MyRoot(WSRoot): class MyRoot(WSRoot):
@ -90,11 +91,16 @@ class TestController(unittest.TestCase):
def test_scan_api_too_deep(self): def test_scan_api_too_deep(self):
class Loop(object): class Loop(object):
loop = None pass
Loop.me = Loop()
l = Loop()
for i in range(0, 21):
nl = Loop()
nl.l = l
l = nl
class MyRoot(WSRoot): class MyRoot(WSRoot):
loop = Loop() loop = l
r = MyRoot() r = MyRoot()
@ -179,14 +185,6 @@ class TestController(unittest.TestCase):
except ValueError: except ValueError:
pass pass
def test_getapi(self):
class MyRoot(WSRoot):
pass
r = MyRoot()
api = r.getapi()
assert r.getapi() is api
class TestFunctionDefinition(unittest.TestCase): class TestFunctionDefinition(unittest.TestCase):