diff --git a/wsmeext/sqlalchemy/__init__.py b/wsmeext/sqlalchemy/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/wsmeext/sqlalchemy/controllers.py b/wsmeext/sqlalchemy/controllers.py new file mode 100644 index 0000000..c4742a7 --- /dev/null +++ b/wsmeext/sqlalchemy/controllers.py @@ -0,0 +1,95 @@ +from wsme.rest import expose, validate +import wsme.types + +from wsmeext.sqlalchemy.types import SQLAlchemyRegistry + + +class CRUDControllerMeta(type): + def __init__(cls, name, bases, dct): + if cls.__saclass__ is not None: + if cls.__registry__ is None: + cls.__registry__ = wsme.types.registry + if cls.__wstype__ is None: + cls.__wstype__ = cls.__registry__.resolve_type( + SQLAlchemyRegistry.get( + cls.__registry__).getdatatype(cls.__saclass__)) + + cls.create = expose( + cls.__wstype__, + method='PUT', + wrap=True + )(cls.create) + cls.create = validate(cls.__wstype__)(cls.create) + + cls.read = expose( + cls.__wstype__, + method='GET', + wrap=True + )(cls.read) + cls.read = validate(cls.__wstype__)(cls.read) + + cls.update = expose( + cls.__wstype__, + method='POST', + wrap=True + )(cls.update) + cls.update = validate(cls.__wstype__)(cls.update) + + cls.delete = expose( + method='DELETE', + wrap=True + )(cls.delete) + cls.delete = validate(cls.__wstype__)(cls.delete) + + super(CRUDControllerMeta, cls).__init__(name, bases, dct) + + +class CRUDControllerBase(object): + __registry__ = None + __saclass__ = None + __wstype__ = None + __dbsession__ = None + + def _create_one(self, data): + obj = self.__saclass__() + data.to_instance(obj) + self.__dbsession__.add(obj) + return obj + + def _get_one(self, ref): + q = self.__dbsession__.query(self.__saclass__) + q = q.filter(ref.get_ref_criterion()) + return q.one() + + def _update_one(self, data): + obj = self._get_one(data) + if obj is None: + raise ValueError("No match for data=%s" % data) + data.to_instance(obj) + return obj + + def _delete(self, ref): + obj = self._get_one(ref) + self.__dbsession__.delete(obj) + + def create(self, data): + obj = self._create_one(data) + self.__dbsession__.flush() + return self.__wstype__(obj) + + def read(self, ref): + obj = self._get_one(ref) + return self.__wstype__(obj) + + def update(self, data): + obj = self._update_one(data) + self.__dbsession__.flush() + return self.__wstype__(obj) + + def delete(self, ref): + self._delete(ref) + self.__dbsession__.flush() + return None + +CRUDController = CRUDControllerMeta( + 'CRUDController', (CRUDControllerBase,), {}) diff --git a/wsmeext/sqlalchemy/types.py b/wsmeext/sqlalchemy/types.py new file mode 100644 index 0000000..ed9bfe2 --- /dev/null +++ b/wsmeext/sqlalchemy/types.py @@ -0,0 +1,201 @@ +import datetime +import decimal +import logging + +import six + +from sqlalchemy.orm import class_mapper +from sqlalchemy.orm.properties import ColumnProperty, RelationProperty + +import sqlalchemy.types + +import wsme.types + +log = logging.getLogger(__name__) + + +class SQLAlchemyRegistry(object): + @classmethod + def get(cls, registry): + if not hasattr(registry, 'sqlalchemy'): + registry.sqlalchemy = SQLAlchemyRegistry() + return registry.sqlalchemy + + def __init__(self): + self.types = {} + self.satypeclasses = { + sqlalchemy.types.Integer: int, + sqlalchemy.types.Boolean: bool, + sqlalchemy.types.Float: float, + sqlalchemy.types.Numeric: decimal.Decimal, + sqlalchemy.types.Date: datetime.date, + sqlalchemy.types.Time: datetime.time, + sqlalchemy.types.DateTime: datetime.datetime, + sqlalchemy.types.String: wsme.types.text, + sqlalchemy.types.Unicode: wsme.types.text, + } + + def getdatatype(self, sadatatype): + if sadatatype.__class__ in self.satypeclasses: + return self.satypeclasses[sadatatype.__class__] + elif sadatatype in self.types: + return self.types[sadatatype] + else: + return sadatatype.__name__ + + +def register_saclass(registry, saclass, typename=None): + """Associate a webservice type name to a SQLAlchemy mapped class. + The default typename if the saclass name itself. + """ + if typename is None: + typename = saclass.__name__ + + SQLAlchemyRegistry.get(registry).types[saclass] = typename + + +class wsattr(wsme.types.wsattr): + def __init__(self, datatype, saproperty=None, **kw): + super(wsattr, self).__init__(datatype, **kw) + self.saname = saproperty.key + self.saproperty = saproperty + self.isrelation = isinstance(saproperty, RelationProperty) + + +def make_wsattr(registry, saproperty): + datatype = None + if isinstance(saproperty, ColumnProperty): + if len(saproperty.columns) > 1: + log.warning("Cannot handle multi-column ColumnProperty") + return None + datatype = SQLAlchemyRegistry.get(registry).getdatatype( + saproperty.columns[0].type) + elif isinstance(saproperty, RelationProperty): + other_saclass = saproperty.mapper.class_ + datatype = SQLAlchemyRegistry.get(registry).getdatatype( + other_saclass) + if saproperty.uselist: + datatype = [datatype] + else: + log.warning("Don't know how to handle %s attributes" % + saproperty.__class__) + + if datatype: + return wsattr(datatype, saproperty) + + +class BaseMeta(wsme.types.BaseMeta): + def __new__(cls, name, bases, dct): + if '__registry__' not in dct: + dct['__registry__'] = wsme.types.registry + return type.__new__(cls, name, bases, dct) + + def __init__(cls, name, bases, dct): + saclass = getattr(cls, '__saclass__', None) + if saclass: + mapper = class_mapper(saclass) + cls._pkey_attrs = [] + cls._ref_attrs = [] + for prop in mapper.iterate_properties: + key = prop.key + if hasattr(cls, key): + continue + if key.startswith('_'): + continue + attr = make_wsattr(cls.__registry__, prop) + if attr is not None: + setattr(cls, key, attr) + + if attr and isinstance(prop, ColumnProperty) and \ + prop.columns[0] in mapper.primary_key: + cls._pkey_attrs.append(attr) + cls._ref_attrs.append(attr) + + register_saclass(cls.__registry__, cls.__saclass__, cls.__name__) + super(BaseMeta, cls).__init__(name, bases, dct) + + +class Base(six.with_metaclass(BaseMeta, wsme.types.Base)): + def __init__(self, instance=None, keyonly=False, attrs=None, eagerload=[]): + if instance: + self.from_instance(instance, keyonly, attrs, eagerload) + + def from_instance(self, instance, keyonly=False, attrs=None, eagerload=[]): + if keyonly: + attrs = self._pkey_attrs + self._ref_attrs + for attr in self._wsme_attributes: + if not isinstance(attr, wsattr): + continue + if attrs and not attr.isrelation and not attr.name in attrs: + continue + if attr.isrelation and not attr.name in eagerload: + continue + value = getattr(instance, attr.saname) + if attr.isrelation: + attr_keyonly = attr.name not in eagerload + attr_attrs = None + attr_eagerload = [] + if not attr_keyonly: + attr_attrs = [ + aname[len(attr.name) + 1:] + for aname in attrs + if aname.startswith(attr.name + '.') + ] + attr_eagerload = [ + aname[len(attr.name) + 1:] + for aname in eagerload + if aname.startswith(attr.name + '.') + ] + if attr.saproperty.uselist: + value = [ + attr.datatype.item_type( + o, + keyonly=attr_keyonly, + attrs=attr_attrs, + eagerload=attr_eagerload + ) + for o in value + ] + else: + value = attr.datatype( + value, + keyonly=attr_keyonly, + attrs=attr_attrs, + eagerload=attr_eagerload + ) + attr.__set__(self, value) + + def to_instance(self, instance): + for attr in self._wsme_attributes: + if isinstance(attr, wsattr): + value = attr.__get__(self, self.__class__) + if value is not wsme.types.Unset: + setattr(instance, attr.saname, value) + + def get_ref_criterion(self): + """Returns a criterion that match a database object + having the pkey/ref attribute values of this webservice object""" + criterions = [] + for attr in self._pkey_attrs + self._ref_attrs: + value = attr.__get__(self, self.__class__) + if value is not wsme.types.Unset: + criterions.append(attr.saproperty == value) + + +def generate_types(*classes, **kw): + registry = kw.pop('registry', wsme.types.registry) + prefix = kw.pop('prefix', '') + postfix = kw.pop('postfix', '') + makename = kw.pop('makename', lambda s: prefix + s + postfix) + + newtypes = {} + for c in classes: + if isinstance(c, list): + newtypes.update(generate_types(c)) + else: + name = makename(c.__name__) + newtypes[name] = BaseMeta(name, (Base, ), { + '__saclass__': c, + '__registry__': registry + }) + return newtypes diff --git a/wsmeext/tests/test_sqlalchemy_controllers.py b/wsmeext/tests/test_sqlalchemy_controllers.py new file mode 100644 index 0000000..401ada1 --- /dev/null +++ b/wsmeext/tests/test_sqlalchemy_controllers.py @@ -0,0 +1,224 @@ +import datetime + +try: + import json +except ImportError: + import simplejson as json + +from webtest import TestApp + +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy import Column, Integer, Unicode, Date, ForeignKey +from sqlalchemy.orm import relation + +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker, scoped_session + +from wsme import WSRoot +import wsme.types + +from wsmeext.sqlalchemy.types import generate_types +from wsmeext.sqlalchemy.controllers import CRUDController + +from six import u + +engine = create_engine('sqlite:///') +DBSession = scoped_session(sessionmaker(autocommit=False, + autoflush=False, + bind=engine)) +DBBase = declarative_base() + +registry = wsme.types.Registry() + + +class DBPerson(DBBase): + __tablename__ = 'person' + + id = Column(Integer, primary_key=True) + name = Column(Unicode(50)) + birthdate = Column(Date) + + addresses = relation('DBAddress') + + +class DBAddress(DBBase): + __tablename__ = 'address' + + id = Column(Integer, primary_key=True) + + _person_id = Column('person_id', ForeignKey(DBPerson.id)) + + street = Column(Unicode(50)) + city = Column(Unicode(50)) + + person = relation(DBPerson) + + +globals().update( + generate_types(DBPerson, DBAddress, + makename=lambda s: s[2:], registry=registry)) + + +class PersonController(CRUDController): + __saclass__ = DBPerson + __dbsession__ = DBSession + __registry__ = registry + + +class AddressController(CRUDController): + __saclass__ = DBAddress + __dbsession__ = DBSession + __registry__ = registry + + +class Root(WSRoot): + __registry__ = registry + + person = PersonController() + address = AddressController() + + +class TestCRUDController(): + def setUp(self): + DBBase.metadata.create_all(DBSession.bind) + + self.root = Root() + self.root.getapi() + self.root.addprotocol('restjson') + + self.app = TestApp(self.root.wsgiapp()) + + def tearDown(self): + DBBase.metadata.drop_all(DBSession.bind) + + def test_create(self): + data = dict(data=dict( + name=u('Pierre-Joseph'), + birthdate=u('1809-01-15') + )) + r = self.app.post('/person/create', json.dumps(data), + headers={ + 'Content-Type': 'application/json' + }) + r = json.loads(r.text) + print(r) + assert r['name'] == u('Pierre-Joseph') + assert r['birthdate'] == u('1809-01-15') + + def test_PUT(self): + data = dict(data=dict( + name=u('Pierre-Joseph'), + birthdate=u('1809-01-15') + )) + r = self.app.put('/person', json.dumps(data), + headers={ + 'Content-Type': 'application/json' + }) + r = json.loads(r.text) + print(r) + assert r['name'] == u('Pierre-Joseph') + assert r['birthdate'] == u('1809-01-15') + + def test_read(self): + p = DBPerson( + name=u('Pierre-Joseph'), + birthdate=datetime.date(1809, 1, 15)) + DBSession.add(p) + DBSession.flush() + pid = p.id + r = self.app.post('/person/read', '{"ref": {"id": %s}}' % pid, + headers={ + 'Content-Type': 'application/json' + }) + r = json.loads(r.text) + print(r) + assert r['name'] == u('Pierre-Joseph') + assert r['birthdate'] == u('1809-01-15') + + def test_GET(self): + p = DBPerson( + name=u('Pierre-Joseph'), + birthdate=datetime.date(1809, 1, 15)) + DBSession.add(p) + DBSession.flush() + pid = p.id + r = self.app.get('/person?ref.id=%s' % pid, + headers={ + 'Content-Type': 'application/json' + }) + r = json.loads(r.text) + print(r) + assert r['name'] == u('Pierre-Joseph') + assert r['birthdate'] == u('1809-01-15') + + def test_update(self): + p = DBPerson( + name=u('Pierre-Joseph'), + birthdate=datetime.date(1809, 1, 15)) + DBSession.add(p) + DBSession.flush() + pid = p.id + data = { + "id": pid, + "name": u('Pierre-Joseph Proudon') + } + r = self.app.post('/person/update', json.dumps(dict(data=data)), + headers={ + 'Content-Type': 'application/json' + }) + r = json.loads(r.text) + print(r) + assert r['name'] == u('Pierre-Joseph Proudon') + assert r['birthdate'] == u('1809-01-15') + + def test_POST(self): + p = DBPerson( + name=u('Pierre-Joseph'), + birthdate=datetime.date(1809, 1, 15)) + DBSession.add(p) + DBSession.flush() + pid = p.id + data = { + "id": pid, + "name": u('Pierre-Joseph Proudon') + } + r = self.app.post('/person', json.dumps(dict(data=data)), + headers={ + 'Content-Type': 'application/json' + }) + r = json.loads(r.text) + print(r) + assert r['name'] == u('Pierre-Joseph Proudon') + assert r['birthdate'] == u('1809-01-15') + + def test_delete(self): + p = DBPerson( + name=u('Pierre-Joseph'), + birthdate=datetime.date(1809, 1, 15)) + DBSession.add(p) + DBSession.flush() + pid = p.id + r = self.app.post('/person/delete', json.dumps( + dict(ref=dict(id=pid))), + headers={ + 'Content-Type': 'application/json' + }) + print(r) + assert DBSession.query(DBPerson).get(pid) is None + + def test_DELETE(self): + p = DBPerson( + name=u('Pierre-Joseph'), + birthdate=datetime.date(1809, 1, 15)) + DBSession.add(p) + DBSession.flush() + pid = p.id + r = self.app.delete('/person?ref.id=%s' % pid, + headers={ + 'Content-Type': 'application/json' + }) + print(r) + assert DBSession.query(DBPerson).get(pid) is None + + def test_nothing(self): + pass diff --git a/wsmeext/tests/test_sqlalchemy_types.py b/wsmeext/tests/test_sqlalchemy_types.py new file mode 100644 index 0000000..8512015 --- /dev/null +++ b/wsmeext/tests/test_sqlalchemy_types.py @@ -0,0 +1,72 @@ +import datetime + +import wsmeext.sqlalchemy.types + +from wsme.types import text, Unset, isarray + +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy import Column, Integer, String, Date, ForeignKey +from sqlalchemy.orm import relation + +from six import u + +SABase = declarative_base() + + +class SomeClass(SABase): + __tablename__ = 'some_table' + id = Column(Integer, primary_key=True) + name = Column(String(50)) + + adate = Column(Date) + + +def test_complextype(): + class AType(wsmeext.sqlalchemy.types.Base): + __saclass__ = SomeClass + + assert AType.id.datatype is int + assert AType.name.datatype is text + assert AType.adate.datatype is datetime.date + + a = AType() + s = SomeClass(name=u('aname'), adate=datetime.date(2012, 6, 26)) + assert s.name == u('aname') + + a.from_instance(s) + assert a.name == u('aname') + assert a.adate == datetime.date(2012, 6, 26) + + a.name = u('test') + del a.adate + assert a.adate is Unset + + a.to_instance(s) + assert s.name == u('test') + assert s.adate == datetime.date(2012, 6, 26) + + +def test_generate(): + class A(SABase): + __tablename__ = 'a' + id = Column(Integer, primary_key=True) + name = Column(String(50)) + + _b_id = Column(ForeignKey('b.id')) + + b = relation('B') + + class B(SABase): + __tablename__ = 'b' + id = Column(Integer, primary_key=True) + name = Column(String(50)) + + alist = relation(A) + + newtypes = wsmeext.sqlalchemy.types.generate_types(A, B) + + assert newtypes['A'].id.datatype is int + assert newtypes['A'].b.datatype is newtypes['B'] + assert newtypes['B'].id.datatype is int + assert isarray(newtypes['B'].alist.datatype) + assert newtypes['B'].alist.datatype.item_type is newtypes['A']