solar/solar/dblayer/sql_client.py
Jedrzej Nowak fb1e946853 Adjusted dblayer sqlite behaviour to riak one
Changed default database to sqlite
Changed how solar_db config is defined
Removed lupa from test-requirements
Removed riak from requirements
Testr uses :memory: sqlite

Closes-Bug: #1526286

Change-Id: I709d19a192f800e9a67d9c7657f286ff0b343053
2015-12-17 12:18:59 +01:00

455 lines
14 KiB
Python

# -*- coding: utf-8 -*-
# Copyright 2015 Mirantis, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
# msgpack is way faster but less readable
# using json for easier debug
import json
import sys
import uuid
from peewee import BlobField
from peewee import CharField
from peewee import ForeignKeyField
from peewee import Model
from solar.dblayer.model import clear_cache
from threading import RLock
# Serialiser used for every JSON blob column.
encoder = json.dumps

# Text types json.loads can take as-is.  ``basestring`` only exists on
# Python 2; on Python 3 json.loads accepts both str and bytes.
try:
    _STRING_TYPES = (basestring, )  # noqa: F821 (Python 2 only)
except NameError:
    _STRING_TYPES = (str, bytes)


def wrapped_loads(data, *args, **kwargs):
    """Decode a JSON document read back from a blob column.

    Database drivers may hand blob values back as buffer-like objects rather
    than strings, so those are converted before parsing.

    :param data: JSON document as text, bytes, or a buffer-like object.
    :param args: forwarded to ``json.loads``.
    :param kwargs: forwarded to ``json.loads``.
    :returns: the decoded Python value.
    :raises ValueError: if the document is not valid JSON.
    """
    if isinstance(data, memoryview):
        # str(memoryview) yields "<memory at ...>", not the content.
        data = data.tobytes()
    elif isinstance(data, bytearray):
        data = bytes(data)
    elif not isinstance(data, _STRING_TYPES):
        # e.g. Python 2 ``buffer`` objects, whose str() is their content.
        data = str(data)
    return json.loads(data, *args, **kwargs)


decoder = wrapped_loads
class _DataField(BlobField):
    """Blob column holding an arbitrary JSON-serialisable value.

    Values pass through ``encoder`` on the way into the database and through
    ``decoder`` on the way back out.
    """

    def python_value(self, value):
        raw = super(_DataField, self).python_value(value)
        return decoder(raw)

    def db_value(self, value):
        serialised = encoder(value)
        return super(_DataField, self).db_value(serialised)
class _LinksField(_DataField):
    """JSON blob column holding a list of link tuples.

    JSON has no tuple type, so each link is stored as a nested list and turned
    back into a tuple when read.
    """

    def python_value(self, value):
        decoded = super(_LinksField, self).python_value(value)
        return list(map(tuple, decoded))

    def db_value(self, value):
        as_list = list(value)
        return super(_LinksField, self).db_value(as_list)
class _SqlBucket(Model):
    """Base peewee model for one object row of a bucket table.

    Concrete per-bucket subclasses (with their own ``Meta`` and a ``bucket``
    back-reference) are created dynamically.

    Extra constructor keywords:

    * ``_new`` -- marks an instance not yet in the database; its first
      ``save`` is forced to be an INSERT.
    * ``encoded_data`` -- stored under ``data['_encoded_data']``.
    """

    key = CharField(primary_key=True, null=False)
    data = _DataField(null=False)
    vclock = CharField(max_length=32, null=False)
    links = _LinksField(null=False, default=list)

    def __init__(self, *args, **kwargs):
        self._new = kwargs.pop('_new', False)
        encoded = kwargs.pop('encoded_data', None)
        if 'data' not in kwargs:
            kwargs['data'] = {}
        super(_SqlBucket, self).__init__(*args, **kwargs)
        # Must run after Model.__init__: the setter writes into ``self.data``,
        # which is the dict handed to the parent constructor above. Setting
        # it earlier would be overwritten by that constructor (or fail).
        if encoded:
            self.encoded_data = encoded

    @property
    def encoded_data(self):
        return self.data.get('_encoded_data')

    @encoded_data.setter
    def encoded_data(self, value):
        self.data['_encoded_data'] = value

    def save(self, force_insert=False, only=None):
        """Persist the row; the first save of a ``_new`` instance INSERTs."""
        if self._new:
            force_insert = True
        ret = super(_SqlBucket, self).save(force_insert, only)
        # Clear the flag only once the save went through, so a failed first
        # save is retried as an INSERT rather than as an UPDATE of a row that
        # does not exist.
        self._new = False
        return ret

    @property
    def sql_session(self):
        return self.bucket.sql_session
class FieldWrp(object):
    """Descriptor forwarding an attribute to ``instance._sql_bucket_obj``.

    Reading or assigning the attribute on an instance reads or assigns the
    attribute named ``name`` on that instance's ``_sql_bucket_obj``.
    """

    def __init__(self, name):
        self.name = name

    def __get__(self, instance, owner):
        if instance is None:
            # Class-level access (introspection, help(), hasattr on the
            # class) gets the descriptor itself instead of failing on
            # ``None._sql_bucket_obj``.
            return self
        return getattr(instance._sql_bucket_obj, self.name)

    def __set__(self, instance, value):
        setattr(instance._sql_bucket_obj, self.name, value)
class _SqlIdx(Model):
    """Base peewee model for one secondary-index entry of a bucket.

    Each row pairs an index ``name`` with a ``value``. The dynamically built
    per-bucket subclass adds a ``key`` foreign key to the owning object row.
    """

    name = CharField(index=True, null=False)
    value = CharField(index=True, null=False)
class RiakObj(object):
    """Riak-object-like view of one row of a SQL bucket.

    The Riak attributes ``key``, ``data``, ``vclock``, ``links`` and
    ``encoded_data`` are forwarded to the wrapped ``_SqlBucket`` instance.
    Secondary indexes are kept in memory as a set of ``(name, value)`` pairs
    and are written to the bucket's index table only by ``store``.
    """

    key = FieldWrp('key')
    data = FieldWrp('data')
    vclock = FieldWrp('vclock')
    links = FieldWrp('links')
    encoded_data = FieldWrp('encoded_data')

    def __init__(self, sql_bucket_obj, new=False):
        """Wrap ``sql_bucket_obj`` and load its indexes.

        :param sql_bucket_obj: a ``_SqlBucket`` (subclass) instance.
        :param new: accepted for call compatibility only; the "new" state is
            read from ``sql_bucket_obj._new``.
        """
        self._sql_bucket_obj = sql_bucket_obj
        self.new = sql_bucket_obj._new
        self.fetch_indexes()

    @property
    def sql_session(self):
        return self._sql_bucket_obj.sql_session

    @property
    def bucket(self):
        return self._sql_bucket_obj.bucket

    @property
    def indexes(self):
        """Set of ``(name, value)`` index pairs, loaded lazily from the DB."""
        self.fetch_indexes()
        return self._indexes

    @indexes.setter
    def indexes(self, value):
        assert isinstance(value, set)
        self._indexes = value

    def fetch_indexes(self):
        """Load index pairs from the index table unless already loaded."""
        if not hasattr(self, '_indexes'):
            idx_model = self.bucket._sql_idx
            rows = idx_model.select().where(idx_model.key == self.key)
            self._indexes = set((row.name, row.value) for row in rows)

    def _save_indexes(self):
        """Replace all stored index rows of this object with ``indexes``."""
        # TODO: possible optimization - update only what changed instead of
        # deleting every row first.
        idx_model = self.bucket._sql_idx
        idx_model.delete().where(idx_model.key == self.key).execute()
        for iname, ival in self.indexes:
            idx_model(key=self.key, name=iname, value=ival).save()

    def add_index(self, field, value):
        """Add the ``(field, value)`` pair in memory; returns ``self``."""
        self.indexes.add((field, value))
        return self

    def set_index(self, field, value):
        """Replace every pair of index ``field`` with ``(field, value)``."""
        to_rem = set(x for x in self.indexes if x[0] == field)
        self.indexes.difference_update(to_rem)
        return self.add_index(field, value)

    def remove_index(self, field=None, value=None):
        """Remove index pairs in memory (persisted by the next ``store``).

        * no arguments -- remove every pair;
        * ``field`` only -- remove all pairs of that index;
        * ``field`` and ``value`` -- remove that single pair.

        Passing only ``value`` removes nothing.
        """
        if field is None and value is None:
            self.indexes.clear()
        elif field is not None and value is None:
            to_rem = set(x for x in self.indexes if x[0] == field)
            self.indexes.difference_update(to_rem)
        elif field is not None and value is not None:
            to_rem = set(
                x for x in self.indexes if x[0] == field and x[1] == value)
            self.indexes.difference_update(to_rem)
        return self

    def store(self, return_body=True):
        """Save the row with a fresh vclock, then its indexes.

        ``return_body`` is accepted for Riak API compatibility and ignored.
        """
        self.vclock = uuid.uuid4().hex
        assert self._sql_bucket_obj is not None
        self._sql_bucket_obj.save()
        self._save_indexes()
        return self

    def delete(self):
        """Delete this object through its bucket; returns ``self``."""
        self.bucket.delete(self.key)
        return self

    @property
    def exists(self):
        return not self.new

    def get_link(self, tag):
        """Return the key (second element) of the first link tagged ``tag``.

        :raises StopIteration: if no link carries ``tag``.
        """
        return next(x[1] for x in self.links if x[2] == tag)

    def set_link(self, obj, tag=None):
        """Add a link to ``obj``, replacing existing links into its bucket.

        ``obj`` is either a ``(bucket_name, key, tag)`` tuple or an object
        with ``bucket.name`` and ``key`` attributes.
        """
        if isinstance(obj, tuple):
            newlink = obj
        else:
            newlink = (obj.bucket.name, obj.key, tag)
        # NOTE(review): x[0:1] compares only the bucket name, so any link into
        # the same bucket is replaced; if per-object replacement was intended
        # this should be x[0:2] -- TODO confirm.
        multi = [x for x in self.links if x[0:1] == newlink[0:1]]
        for item in multi:
            self.links.remove(item)
        self.links.append(newlink)
        return self

    def del_link(self, obj=None, tag=None):
        """Remove links tagged ``tag`` and/or links into ``obj``'s bucket.

        ``obj`` is a ``(bucket_name, key, tag)`` tuple or an object with
        ``bucket.name`` and ``key``. At least one argument is required.
        """
        assert obj is not None or tag is not None
        if tag is not None:
            links = [x for x in self.links if x[2] != tag]
        else:
            links = self.links
        if obj is not None:
            if not isinstance(obj, tuple):
                obj = (obj.bucket.name, obj.key, tag)
            # Keep only the links that do NOT point at obj (matched on the
            # bucket-name element, consistent with set_link).
            links = [x for x in links if x[0:1] != obj[0:1]]
        self.links = links
        return self
class IndexPage(object):
    """One page of secondary-index query results, shaped like Riak's.

    :param index: name of the queried index.
    :param results: iterable of row tuples; when ``return_terms`` is false
        only the first element of each row is kept.
    :param return_terms: keep whole rows instead of their first element.
    :param max_results: page size; falsy means the query is unpaged.
    :param continuation: offset this page started at (int, numeric string
        or None).
    """

    def __init__(self, index, results, return_terms, max_results,
                 continuation):
        self.max_results = max_results
        self.index = index
        self.return_terms = return_terms
        if return_terms:
            self.results = list(results)
        else:
            self.results = [row[0] for row in results]
        # Hand out a continuation only for a full page. A short (or empty)
        # page means the query is exhausted, so there is nothing more to
        # fetch.
        if max_results and len(self.results) >= int(max_results):
            offset = int(continuation or 0)
            self.continuation = str(offset + len(self.results))
        else:
            self.continuation = None

    def __len__(self):
        return len(self.results)

    def __getitem__(self, item):
        return self.results[item]
class Bucket(object):
    """SQL-backed stand-in for a Riak bucket.

    Each bucket owns two dynamically created peewee models bound to the
    client's database: ``bucket_<name>`` for the objects and ``idx_<name>``
    for their secondary-index entries, foreign-keyed on the object key.
    Note that ``self.name`` is the table name (``bucket_<name>``).
    """

    # Types treated as a bare key by ``delete``: (str, unicode) on Python 2
    # and (bytes, str) on Python 3 -- ``basestring`` is Python 2 only.
    _KEY_TYPES = (bytes, type(u''))

    def __init__(self, name, client):
        self.client = client
        table_name = "bucket_%s" % name.lower()
        self.name = table_name
        idx_table_name = 'idx_%s' % name.lower()

        class ModelMeta(object):
            db_table = table_name
            database = self.client.sql_session

        self._sql_model = type(table_name, (_SqlBucket, ),
                               {'Meta': ModelMeta, 'bucket': self})
        _idx_key = ForeignKeyField(self._sql_model, null=False, index=True)

        class IdxMeta(object):
            db_table = idx_table_name
            database = self.client.sql_session

        self._sql_idx = type(idx_table_name, (_SqlIdx, ),
                             {'Meta': IdxMeta,
                              'bucket': self,
                              'key': _idx_key})

    def search(self, q, rows=10, start=0, sort=''):
        raise NotImplementedError()

    def create_search(self, index):
        raise NotImplementedError()

    def set_property(self, name, value):
        """No-op: bucket properties are not stored by the SQL backend."""
        return

    def get_properties(self):
        return {'search_index': False}

    def _get_or_build(self, key):
        """Return ``(row, is_new)``: the stored row or a fresh ``_new`` one."""
        try:
            return self._sql_model.get(self._sql_model.key == key), False
        except self._sql_model.DoesNotExist:
            return self._sql_model(key=key, _new=True), True

    def get(self, key):
        """Fetch ``key``; a missing key yields a new, unsaved object."""
        row, new = self._get_or_build(key)
        return RiakObj(row, new)

    def delete(self, data, *args, **kwargs):
        """Delete an object and its index rows.

        :param data: a key, or an object with a ``key`` attribute.
        """
        key = data if isinstance(data, self._KEY_TYPES) else data.key
        # Index rows go first: they reference the object row by foreign key,
        # so a database enforcing that constraint would reject deleting the
        # object row while they still exist.
        self._sql_idx.delete().where(self._sql_idx.key == key).execute()
        self._sql_model.delete().where(self._sql_model.key == key).execute()
        return self

    def new(self, key, data=None, encoded_data=None, **kwargs):
        """Build an unsaved object for ``key`` (a random hex key if None).

        An existing row with the same key is reused, but its data, links and
        vclock are reset, so storing it overwrites the previous contents.
        """
        if key is None:
            key = uuid.uuid4().hex
            row, new = self._sql_model(key=key, _new=True), True
        else:
            row, new = self._get_or_build(key)
        row.key = key
        row.data = data if data is not None else {}
        if encoded_data:
            row.encoded_data = encoded_data
        row.links = []
        row.vclock = "new"
        return RiakObj(row, new)

    def _select_keys(self, return_terms):
        """Select object keys; with ``return_terms`` the key is its own term."""
        model = self._sql_model
        if return_terms:
            return model.select(model.key, model.key)
        return model.select(model.key)

    def get_index(self,
                  index,
                  startkey,
                  endkey=None,
                  return_terms=None,
                  max_results=None,
                  continuation=None,
                  timeout=None,
                  fmt=None,
                  term_regex=None):
        """Query ``index`` for values in ``[startkey, endkey]`` (Riak 2i-like).

        ``$key`` matches object keys, ``$bucket`` lists the bucket's keys, and
        any other name is looked up in the index table. Results are fully
        ordered, so ``max_results`` / ``continuation`` (an integer offset)
        paginate deterministically. ``timeout``, ``fmt`` and ``term_regex``
        are accepted and ignored.

        :returns: an ``IndexPage`` of keys, or of ``(term, key)`` tuples when
            ``return_terms`` is true.
        """
        if endkey is None:
            endkey = startkey
        if startkey > endkey:
            startkey, endkey = endkey, startkey
        model = self._sql_model
        if index in ('$key', '$bucket'):
            q = self._select_keys(return_terms)
            if index == '$key':
                q = q.where(model.key >= startkey, model.key <= endkey)
            elif startkey != '_' and endkey == '_':
                # NOTE(review): this spells out how the original
                # ``not startkey == '_' and endkey == '_'`` binds. For the
                # usual calls ('_'..'_', or the bucket name as an exact key)
                # no range filter applies, i.e. "all keys in the bucket" --
                # confirm whether any filtering is intended here.
                q = q.where(model.key >= startkey, model.key <= endkey)
            q = q.order_by(model.key)
        else:
            idx = self._sql_idx
            if return_terms:
                q = idx.select(idx.value, idx.key)
            else:
                q = idx.select(idx.key)
            q = q.where(
                idx.name == index, idx.value >= startkey,
                idx.value <= endkey).order_by(idx.value, idx.key)
        max_results = int(max_results or 0)
        continuation = int(continuation or 0)
        if max_results:
            q = q.limit(max_results)
        if continuation:
            q = q.offset(continuation)
        return IndexPage(index, q.tuples(), return_terms, max_results,
                         continuation)

    def multiget(self, keys):
        """Return a list of the existing objects among ``keys``."""
        if not keys:
            return []
        q = self._sql_model.select().where(self._sql_model.key << list(keys))
        return [RiakObj(row) for row in q]

    @property
    def sql_session(self):
        return self.client.sql_session
class SqlClient(object):
    """Riak-client-like facade over a peewee database.

    Positional and keyword arguments are passed to the database class named
    by the ``db_class`` keyword. A bare name is looked up in ``peewee``
    (default ``'SqliteDatabase'``); a dotted ``package.module.Name`` path is
    imported from that module.
    """

    # Class attribute, shared by all clients; guards lazy bucket creation.
    block = RLock()
    search_dir = None

    def __init__(self, *args, **kwargs):
        db_class_str = kwargs.pop("db_class", 'SqliteDatabase')
        # Split on the LAST dot so nested modules (for example
        # 'playhouse.pool.PooledSqliteDatabase') are imported correctly.
        mod, _, cls_name = db_class_str.rpartition('.')
        if not mod:
            mod = 'peewee'
        __import__(mod, fromlist=[cls_name])
        db_class = getattr(sys.modules[mod], cls_name)
        self._sql_session = db_class(*args, **kwargs)
        self.buckets = {}

    def bucket(self, name):
        """Return the ``Bucket`` for ``name``, creating its tables once."""
        with self.block:
            if name not in self.buckets:
                b = Bucket(name, self)
                b._sql_model.create_table(fail_silently=True)
                b._sql_idx.create_table(fail_silently=True)
                self.buckets[name] = b
            return self.buckets[name]

    @property
    def sql_session(self):
        return self._sql_session

    def session_start(self):
        """Clear the model cache and mark the DB session as started."""
        clear_cache()
        sess = self._sql_session
        # TODO: (jnowak) remove this, it's a hack
        # because of pytest nested calls
        # NOTE(review): begin() is issued only when the session is ALREADY
        # marked started, so the outermost call issues none. This looks
        # inverted, but it is tied to the hack above -- confirm before
        # changing.
        if getattr(sess, '_started', False):
            sess.begin()
        setattr(sess, '_started', True)

    def session_end(self, result=True):
        """Commit (``result`` true) or roll back, then clear the model cache."""
        sess = self._sql_session
        if result:
            sess.commit()
        else:
            sess.rollback()
        clear_cache()
        setattr(sess, '_started', False)

    def delete_all(self, cls):
        """Delete the objects of ``cls.bucket`` (only the first 100000 keys)."""
        # naive way for SQL, we could delete whole table contents
        rst = cls.bucket.get_index('$bucket',
                                   startkey='_',
                                   max_results=100000).results
        for key in rst:
            cls.bucket.delete(key)