Change the scan_api interface. It now yield the original function and static args (typically the 'self' attribute). Thanks to that the lookup_function method of WSRoot can access functions outside the WSRoot.
This commit is contained in:
parent
15417d2ac0
commit
ade5325e13
@ -56,22 +56,23 @@ class validate(object):
|
||||
return func
|
||||
|
||||
|
||||
def scan_api(controller, path=[]):
|
||||
def scan_api(controller, path=[], objects=[]):
|
||||
"""
|
||||
Recursively iterate a controller api entries, while setting
|
||||
their :attr:`FunctionDefinition.path`.
|
||||
Recursively iterate a controller api entries.
|
||||
"""
|
||||
for name in dir(controller):
|
||||
if name.startswith('_'):
|
||||
continue
|
||||
a = getattr(controller, name)
|
||||
if a in objects:
|
||||
continue
|
||||
if inspect.ismethod(a):
|
||||
if wsme.api.iswsmefunction(a):
|
||||
yield path + [name], a._wsme_definition
|
||||
yield path + [name], a, []
|
||||
elif inspect.isclass(a):
|
||||
continue
|
||||
else:
|
||||
if len(path) > APIPATH_MAXLEN:
|
||||
raise ValueError("Path is too long: " + str(path))
|
||||
for i in scan_api(a, path + [name]):
|
||||
for i in scan_api(a, path + [name], objects + [a]):
|
||||
yield i
|
||||
|
42
wsme/root.py
42
wsme/root.py
@ -125,10 +125,16 @@ class WSRoot(object):
|
||||
:rtype: list of (path, :class:`FunctionDefinition`)
|
||||
"""
|
||||
if self._api is None:
|
||||
self._api = list(self._scan_api(self))
|
||||
for path, fdef in self._api:
|
||||
self._api = [
|
||||
(path, f, f._wsme_definition, args)
|
||||
for path, f, args in self._scan_api(self)
|
||||
]
|
||||
for path, f, fdef, args in self._api:
|
||||
fdef.resolve_types(self.__registry__)
|
||||
return self._api
|
||||
return [
|
||||
(path, fdef)
|
||||
for path, f, fdef, args in self._api
|
||||
]
|
||||
|
||||
def _get_protocol(self, name):
|
||||
for protocol in self.protocols:
|
||||
@ -171,8 +177,10 @@ class WSRoot(object):
|
||||
'The %s protocol was unable to extract a function '
|
||||
'path from the request') % protocol.name)
|
||||
|
||||
context.func, context.funcdef = self._lookup_function(context.path)
|
||||
context.func, context.funcdef, args = \
|
||||
self._lookup_function(context.path)
|
||||
kw = protocol.read_arguments(context)
|
||||
args = list(args)
|
||||
|
||||
for arg in context.funcdef.arguments:
|
||||
if arg.mandatory and arg.name not in kw:
|
||||
@ -180,7 +188,7 @@ class WSRoot(object):
|
||||
|
||||
txn = self.begin()
|
||||
try:
|
||||
result = context.func(**kw)
|
||||
result = context.func(*args, **kw)
|
||||
txn.commit()
|
||||
except:
|
||||
txn.abort()
|
||||
@ -309,25 +317,13 @@ class WSRoot(object):
|
||||
return res
|
||||
|
||||
def _lookup_function(self, path):
|
||||
a = self
|
||||
if not self._api:
|
||||
self.getapi()
|
||||
|
||||
isprotocol_specific = path[0] == '_protocol'
|
||||
|
||||
if isprotocol_specific:
|
||||
a = self._get_protocol(path[1])
|
||||
path = path[2:]
|
||||
|
||||
for name in path:
|
||||
a = getattr(a, name, None)
|
||||
if a is None:
|
||||
break
|
||||
|
||||
if not hasattr(a, '_wsme_definition'):
|
||||
raise UnknownFunction('/'.join(path))
|
||||
|
||||
definition = a._wsme_definition
|
||||
|
||||
return a, definition
|
||||
for fpath, f, fdef, args in self._api:
|
||||
if path == fpath:
|
||||
return f, fdef, args
|
||||
raise UnknownFunction('/'.join(path))
|
||||
|
||||
def _format_exception(self, excinfo):
|
||||
"""Extract informations that can be sent to the client."""
|
||||
|
@ -74,9 +74,10 @@ class TestController(unittest.TestCase):
|
||||
|
||||
api = list(scan_api(r))
|
||||
assert len(api) == 1
|
||||
path, fd = api[0]
|
||||
path, fd, args = api[0]
|
||||
assert path == ['ns', 'multiply']
|
||||
assert fd.name == 'multiply'
|
||||
assert fd._wsme_definition.name == 'multiply'
|
||||
assert args == []
|
||||
|
||||
def test_scan_subclass(self):
|
||||
class MyRoot(WSRoot):
|
||||
@ -90,11 +91,16 @@ class TestController(unittest.TestCase):
|
||||
|
||||
def test_scan_api_too_deep(self):
|
||||
class Loop(object):
|
||||
loop = None
|
||||
Loop.me = Loop()
|
||||
pass
|
||||
|
||||
l = Loop()
|
||||
for i in range(0, 21):
|
||||
nl = Loop()
|
||||
nl.l = l
|
||||
l = nl
|
||||
|
||||
class MyRoot(WSRoot):
|
||||
loop = Loop()
|
||||
loop = l
|
||||
|
||||
r = MyRoot()
|
||||
|
||||
@ -179,14 +185,6 @@ class TestController(unittest.TestCase):
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def test_getapi(self):
|
||||
class MyRoot(WSRoot):
|
||||
pass
|
||||
|
||||
r = MyRoot()
|
||||
api = r.getapi()
|
||||
assert r.getapi() is api
|
||||
|
||||
|
||||
class TestFunctionDefinition(unittest.TestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user