DB api refactoring

This change does not create anything new.

The db api file was growing too large, so now it is a module with
separate files corresponding to each db model.

Change-Id: I4e032536d5389606d6edd7c0eb80c58ccc159447
This commit is contained in:
Nikita Konovalov 2014-03-26 14:39:40 +04:00
parent 280f90389c
commit bbd477c687
23 changed files with 643 additions and 542 deletions

View File

@ -16,7 +16,7 @@
from pecan import request
from storyboard.api.auth.token_storage import storage
from storyboard.db import api as dbapi
from storyboard.db.api import users as user_api
def guest():
@ -43,6 +43,6 @@ def superuser():
token = request.authorization[1]
token_info = token_storage.get_access_token_info(token)
user = dbapi.user_get(token_info.user_id)
user = user_api.user_get(token_info.user_id)
return user.is_superuser

View File

@ -21,7 +21,7 @@ from oauthlib.oauth2 import WebApplicationServer
from oslo.config import cfg
from storyboard.api.auth.token_storage import storage
from storyboard.db import api as db_api
from storyboard.db.api import users as user_api
CONF = cfg.CONF
LOG = logging.getLogger(__name__)
@ -106,7 +106,7 @@ class SkeletonValidator(RequestValidator):
username = request._params["openid.sreg.nickname"]
last_login = datetime.now()
user = db_api.user_get_by_openid(openid)
user = user_api.user_get_by_openid(openid)
user_dict = {"full_name": full_name,
"username": username,
"email": email,
@ -114,9 +114,9 @@ class SkeletonValidator(RequestValidator):
if not user:
user_dict.update({"openid": openid})
user = db_api.user_create(user_dict)
user = user_api.user_create(user_dict)
else:
user = db_api.user_update(user.id, user_dict)
user = user_api.user_update(user.id, user_dict)
self.token_storage.save_authorization_code(code, user_id=user.id)

View File

@ -16,7 +16,7 @@
import datetime
from storyboard.api.auth.token_storage import storage
from storyboard.db import api as db_api
from storyboard.db.api import auth as auth_api
class DBTokenStorage(storage.StorageBase):
@ -26,17 +26,17 @@ class DBTokenStorage(storage.StorageBase):
"state": authorization_code["state"],
"user_id": user_id
}
db_api.authorization_code_save(values)
auth_api.authorization_code_save(values)
def get_authorization_code_info(self, code):
return db_api.authorization_code_get(code)
return auth_api.authorization_code_get(code)
def check_authorization_code(self, code):
db_code = db_api.authorization_code_get(code)
db_code = auth_api.authorization_code_get(code)
return not db_code is None
def invalidate_authorization_code(self, code):
db_api.authorization_code_delete(code)
auth_api.authorization_code_delete(code)
def save_token(self, access_token, expires_in, refresh_token, user_id):
values = {
@ -48,22 +48,22 @@ class DBTokenStorage(storage.StorageBase):
"user_id": user_id
}
db_api.token_save(values)
auth_api.token_save(values)
def get_access_token_info(self, access_token):
return db_api.token_get(access_token)
return auth_api.token_get(access_token)
def check_access_token(self, access_token):
token_info = db_api.token_get(access_token)
token_info = auth_api.token_get(access_token)
if not token_info:
return False
if datetime.datetime.now() > token_info.expires_at:
db_api.token_update(access_token, {"is_active": False})
auth_api.token_update(access_token, {"is_active": False})
return False
return token_info.is_active
def remove_token(self, access_token):
db_api.token_delete(access_token)
auth_api.token_delete(access_token)

View File

@ -23,7 +23,7 @@ import wsmeext.pecan as wsme_pecan
from storyboard.api.auth import authorization_checks as checks
from storyboard.api.v1 import base
from storyboard.db import api as dbapi
from storyboard.db.api import comments as comments_api
class Comment(base.APIBase):
@ -64,7 +64,7 @@ class CommentsController(rest.RestController):
It will stay unused, as far as comments have their own unique ids
:param comment_id: An ID of the comment.
"""
comment = dbapi.comment_get(comment_id)
comment = comments_api.comment_get(comment_id)
if comment:
return Comment.from_db_model(comment)
@ -79,7 +79,7 @@ class CommentsController(rest.RestController):
:param story_id: filter comments by story ID.
"""
comments = dbapi.comment_get_all(story_id=story_id)
comments = comments_api.comment_get_all(story_id=story_id)
return [Comment.from_db_model(comment) for comment in comments]
@secure(checks.authenticated)
@ -92,7 +92,7 @@ class CommentsController(rest.RestController):
"""
comment.story_id = story_id
comment.author_id = request.current_user_id
created_comment = dbapi.comment_create(comment.as_dict())
created_comment = comments_api.comment_create(comment.as_dict())
return Comment.from_db_model(created_comment)
@secure(checks.authenticated)
@ -104,15 +104,15 @@ class CommentsController(rest.RestController):
:param comment_id: the id of a Comment to be updated
:param comment_body: an updated Comment
"""
comment = dbapi.comment_get(comment_id)
comment = comments_api.comment_get(comment_id)
if request.current_user_id != comment.author_id:
response.status_code = 400
response.body = "You are not allowed to update this comment."
return response
updated_comment = dbapi.comment_update(comment_id,
comment_body.as_dict())
updated_comment = comments_api.comment_update(comment_id,
comment_body.as_dict())
return Comment.from_db_model(updated_comment)
@ -124,14 +124,14 @@ class CommentsController(rest.RestController):
:param story_id: a placeholder
:param comment_id: the id of a Comment to be updated
"""
comment = dbapi.comment_get(comment_id)
comment = comments_api.comment_get(comment_id)
if request.current_user_id != comment.author_id:
response.status_code = 400
response.body = "You are not allowed to delete this comment."
return response
dbapi.comment_delete(comment_id)
comments_api.comment_delete(comment_id)
response.status_code = 204
return response

View File

@ -23,7 +23,7 @@ import wsmeext.pecan as wsme_pecan
from storyboard.api.auth import authorization_checks as checks
from storyboard.api.v1 import base
from storyboard.db import api as dbapi
from storyboard.db.api import projects as projects_api
CONF = cfg.CONF
@ -71,7 +71,7 @@ class ProjectsController(rest.RestController):
:param project_id: project ID.
"""
project = dbapi.project_get(project_id)
project = projects_api.project_get(project_id)
if project:
return Project.from_db_model(project)
@ -94,10 +94,11 @@ class ProjectsController(rest.RestController):
limit = min(CONF.page_size_maximum, max(1, limit))
# Resolve the marker record.
marker_project = dbapi.project_get(marker)
marker_project = projects_api.project_get(marker)
projects = dbapi.project_get_all(marker=marker_project, limit=limit)
project_count = dbapi.project_get_count()
projects = projects_api.project_get_all(marker=marker_project,
limit=limit)
project_count = projects_api.project_get_count()
# Apply the query response headers.
response.headers['X-Limit'] = str(limit)
@ -114,7 +115,7 @@ class ProjectsController(rest.RestController):
:param project: a project within the request body.
"""
result = dbapi.project_create(project.as_dict())
result = projects_api.project_create(project.as_dict())
return Project.from_db_model(result)
@secure(checks.superuser)
@ -125,8 +126,8 @@ class ProjectsController(rest.RestController):
:param project_id: An ID of the project.
:param project: a project within the request body.
"""
result = dbapi.project_update(project_id,
project.as_dict(omit_unset=True))
result = projects_api.project_update(project_id,
project.as_dict(omit_unset=True))
if result:
return Project.from_db_model(result)
@ -141,6 +142,6 @@ class ProjectsController(rest.RestController):
:param project_id: An ID of the project.
"""
dbapi.project_delete(project_id)
projects_api.project_delete(project_id)
response.status_code = 204

View File

@ -25,7 +25,7 @@ import wsmeext.pecan as wsme_pecan
from storyboard.api.auth import authorization_checks as checks
from storyboard.api.v1 import base
from storyboard.api.v1.comments import CommentsController
from storyboard.db import api as dbapi
from storyboard.db.api import stories as stories_api
CONF = cfg.CONF
@ -75,7 +75,7 @@ class StoriesController(rest.RestController):
:param story_id: An ID of the story.
"""
story = dbapi.story_get(story_id)
story = stories_api.story_get(story_id)
if story:
return Story.from_db_model(story)
@ -100,15 +100,15 @@ class StoriesController(rest.RestController):
limit = min(CONF.page_size_maximum, max(1, limit))
# Resolve the marker record.
marker_story = dbapi.story_get(marker)
marker_story = stories_api.story_get(marker)
if marker_story is None or marker_story.project_id != project_id:
marker_story = None
stories = dbapi.story_get_all(marker=marker_story,
limit=limit,
project_id=project_id)
story_count = dbapi.story_get_count(project_id=project_id)
stories = stories_api.story_get_all(marker=marker_story,
limit=limit,
project_id=project_id)
story_count = stories_api.story_get_count(project_id=project_id)
# Apply the query response headers.
response.headers['X-Limit'] = str(limit)
@ -129,7 +129,7 @@ class StoriesController(rest.RestController):
user_id = request.current_user_id
story_dict.update({"creator_id": user_id})
created_story = dbapi.story_create(story_dict)
created_story = stories_api.story_create(story_dict)
return Story.from_db_model(created_story)
@ -141,8 +141,9 @@ class StoriesController(rest.RestController):
:param story_id: An ID of the story.
:param story: a story within the request body.
"""
updated_story = dbapi.story_update(story_id,
story.as_dict(omit_unset=True))
updated_story = stories_api.story_update(
story_id,
story.as_dict(omit_unset=True))
if updated_story:
return Story.from_db_model(updated_story)
@ -157,7 +158,7 @@ class StoriesController(rest.RestController):
:param story_id: An ID of the story.
"""
dbapi.story_delete(story_id)
stories_api.story_delete(story_id)
response.status_code = 204

View File

@ -23,7 +23,7 @@ import wsmeext.pecan as wsme_pecan
from storyboard.api.auth import authorization_checks as checks
from storyboard.api.v1 import base
from storyboard.db import api as dbapi
from storyboard.db.api import tasks as tasks_api
CONF = cfg.CONF
@ -64,7 +64,7 @@ class TasksController(rest.RestController):
:param task_id: An ID of the task.
"""
task = dbapi.task_get(task_id)
task = tasks_api.task_get(task_id)
if task:
return Task.from_db_model(task)
@ -88,15 +88,15 @@ class TasksController(rest.RestController):
limit = min(CONF.page_size_maximum, max(1, limit))
# Resolve the marker record.
marker_task = dbapi.task_get(marker)
marker_task = tasks_api.task_get(marker)
if marker_task is None or marker_task.story_id != story_id:
marker_task = None
tasks = dbapi.task_get_all(marker=marker_task,
limit=limit,
story_id=story_id)
task_count = dbapi.task_get_count(story_id=story_id)
tasks = tasks_api.task_get_all(marker=marker_task,
limit=limit,
story_id=story_id)
task_count = tasks_api.task_get_count(story_id=story_id)
# Apply the query response headers.
response.headers['X-Limit'] = str(limit)
@ -113,7 +113,7 @@ class TasksController(rest.RestController):
:param task: a task within the request body.
"""
created_task = dbapi.task_create(task.as_dict())
created_task = tasks_api.task_create(task.as_dict())
return Task.from_db_model(created_task)
@secure(checks.authenticated)
@ -124,7 +124,7 @@ class TasksController(rest.RestController):
:param task_id: An ID of the task.
:param task: a task within the request body.
"""
updated_task = dbapi.task_update(task_id,
updated_task = tasks_api.task_update(task_id,
task.as_dict(omit_unset=True))
if updated_task:
@ -140,6 +140,6 @@ class TasksController(rest.RestController):
:param task_id: An ID of the task.
"""
dbapi.task_delete(task_id)
tasks_api.task_delete(task_id)
response.status_code = 204

View File

@ -26,7 +26,7 @@ import wsmeext.pecan as wsme_pecan
from storyboard.api.auth import authorization_checks as checks
from storyboard.api.v1 import base
from storyboard.db import api as dbapi
from storyboard.db.api import users as users_api
CONF = cfg.CONF
@ -85,10 +85,10 @@ class UsersController(rest.RestController):
limit = min(CONF.page_size_maximum, max(1, limit))
# Resolve the marker record.
marker_user = dbapi.user_get(marker)
marker_user = users_api.user_get(marker)
users = dbapi.user_get_all(marker=marker_user, limit=limit)
user_count = dbapi.user_get_count()
users = users_api.user_get_all(marker=marker_user, limit=limit)
user_count = users_api.user_get_count()
# Apply the query response headers.
response.headers['X-Limit'] = str(limit)
@ -110,7 +110,7 @@ class UsersController(rest.RestController):
if user_id == request.current_user_id:
filter_non_public = False
user = dbapi.user_get(user_id, filter_non_public)
user = users_api.user_get(user_id, filter_non_public)
if not user:
raise ClientSideError("User %s not found" % user_id,
status_code=404)
@ -124,7 +124,7 @@ class UsersController(rest.RestController):
:param user: a user within the request body.
"""
created_user = dbapi.user_create(user.as_dict())
created_user = users_api.user_create(user.as_dict())
return User.from_db_model(created_user)
@secure(checks.authenticated)
@ -147,5 +147,5 @@ class UsersController(rest.RestController):
"your identity fields."
return response
updated_user = dbapi.user_update(user_id, user_dict)
updated_user = users_api.user_update(user_id, user_dict)
return User.from_db_model(updated_user)

View File

@ -20,7 +20,7 @@ from wsme import types as wtypes
from oslo.config import cfg
from sqlalchemy.exc import SADeprecationWarning
from storyboard.db.api import get_session
from storyboard.db.api.base import get_session
import storyboard.db.models as sqlalchemy_models

View File

@ -1,453 +0,0 @@
# Copyright (c) 2014 Mirantis Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from oslo.config import cfg
import six
from storyboard.common import exception as exc
from storyboard.db import models
from storyboard.openstack.common.db import exception as db_exc
from storyboard.openstack.common.db.sqlalchemy import session as db_session
from storyboard.openstack.common.db.sqlalchemy.utils import paginate_query
from storyboard.openstack.common import log
CONF = cfg.CONF
CONF.import_group("database", "storyboard.openstack.common.db.options")
LOG = log.getLogger(__name__)
_FACADE = None
BASE = models.Base
def setup_db():
try:
engine = get_engine()
BASE.metadata.create_all(engine)
except Exception as e:
LOG.exception("Database registration exception: %s", e)
return False
return True
def drop_db():
try:
BASE.metadata.drop_all(get_engine())
except Exception as e:
LOG.exception("Database shutdown exception: %s", e)
return False
return True
def _get_facade_instance():
"""Generate an instance of the DB Facade.
"""
global _FACADE
if _FACADE is None:
_FACADE = db_session.EngineFacade(
CONF.database.connection,
**dict(CONF.database.iteritems()))
return _FACADE
def _destroy_facade_instance():
"""Destroys the db facade instance currently in use.
"""
global _FACADE
_FACADE = None
def get_engine():
"""Returns the global instance of our database engine.
"""
facade = _get_facade_instance()
return facade.get_engine()
def get_session(autocommit=True, expire_on_commit=False):
"""Returns a database session from our facade.
"""
facade = _get_facade_instance()
return facade.get_session(autocommit=autocommit,
expire_on_commit=expire_on_commit)
def cleanup():
"""Manually clean up our database engine.
"""
_destroy_facade_instance()
def model_query(model, session=None):
"""Query helper.
:param model: base model to query
"""
session = session or get_session()
query = session.query(model)
return query
def __entity_get(kls, entity_id, session):
query = model_query(kls, session)
return query.filter_by(id=entity_id, is_active=True).first()
def _entity_get(kls, entity_id, filter_non_public=False):
entity = __entity_get(kls, entity_id, get_session())
if filter_non_public:
entity = _filter_non_public_fields(entity, entity._public_fields)
return entity
def _entity_get_all(kls, filter_non_public=False, marker=None, limit=None,
**kwargs):
# Sanity check on input parameters
kwargs = dict((k, v) for k, v in kwargs.iteritems() if v)
# Construct the query
query = model_query(kls).filter_by(**kwargs)
query = paginate_query(query=query,
model=kls,
limit=limit,
sort_keys=['id'],
marker=marker,
sort_dir='asc')
# Execute the query
entities = query.all()
if len(entities) > 0 and filter_non_public:
sample_entity = entities[0] if len(entities) > 0 else None
public_fields = getattr(sample_entity, "_public_fields", [])
entities = [_filter_non_public_fields(entity, public_fields)
for entity in entities]
return entities
def _entity_get_count(kls, **kwargs):
kwargs = dict((k, v) for k, v in kwargs.iteritems() if v)
count = model_query(kls).filter_by(**kwargs).count()
return count
def _filter_non_public_fields(entity, public_list=list()):
ent_copy = copy.copy(entity)
for attr_name, val in six.iteritems(entity.__dict__):
if attr_name.startswith("_"):
continue
if attr_name not in public_list:
delattr(ent_copy, attr_name)
return ent_copy
def _entity_create(kls, values):
entity = kls()
entity.update(values.copy())
session = get_session()
with session.begin():
try:
session.add(entity)
except db_exc.DBDuplicateEntry:
raise exc.DuplicateEntry("Duplicate entry for : %s"
% kls.__name__)
return entity
def _entity_update(kls, entity_id, values):
session = get_session()
with session.begin():
entity = __entity_get(kls, entity_id, session)
if entity is None:
raise exc.NotFound("%s %s not found" % (kls.__name__, entity_id))
entity.update(values.copy())
session.add(entity)
return entity
## BEGIN Users
def user_get(user_id, filter_non_public=False):
entity = _entity_get(models.User, user_id,
filter_non_public=filter_non_public)
return entity
def user_get_all(marker=None, limit=None, filter_non_public=False):
return _entity_get_all(models.User,
marker=marker,
limit=limit,
filter_non_public=filter_non_public)
def user_get_count():
return _entity_get_count(models.User)
def user_get_by_openid(openid):
query = model_query(models.User, get_session())
return query.filter_by(openid=openid).first()
def user_create(values):
user = models.User()
user.update(values.copy())
session = get_session()
with session.begin():
try:
user.save(session=session)
except db_exc.DBDuplicateEntry as e:
raise exc.DuplicateEntry("Duplicate entry for User: %s"
% e.columns)
return user
def user_update(user_id, values):
return _entity_update(models.User, user_id, values)
## BEGIN Projects
def project_get(project_id):
return _entity_get(models.Project, project_id)
def project_get_all(marker=None, limit=None, **kwargs):
return _entity_get_all(models.Project,
is_active=True,
marker=marker,
limit=limit,
**kwargs)
def project_get_count(**kwargs):
return _entity_get_count(models.Project, is_active=True, **kwargs)
def project_create(values):
return _entity_create(models.Project, values)
def project_update(project_id, values):
return _entity_update(models.Project, project_id, values)
def project_delete(project_id):
project = project_get(project_id)
if project:
project.is_active = False
_entity_update(models.Project, project_id, project.as_dict())
# BEGIN Stories
def story_get(story_id):
return _entity_get(models.Story, story_id)
def story_get_all(marker=None, limit=None, project_id=None):
if project_id:
return _story_get_all_in_project(marker=marker,
limit=limit,
project_id=project_id)
else:
return _entity_get_all(models.Story, is_active=True,
marker=marker, limit=limit)
def story_get_count(project_id=None):
if project_id:
return _story_get_count_in_project(project_id)
else:
return _entity_get_count(models.Story, is_active=True)
def _story_get_all_in_project(project_id, marker=None, limit=None):
session = get_session()
sub_query = model_query(models.Task.story_id, session) \
.filter_by(project_id=project_id, is_active=True) \
.distinct(True) \
.subquery()
query = model_query(models.Story, session) \
.filter_by(is_active=True) \
.join(sub_query, models.Story.tasks)
query = paginate_query(query=query,
model=models.Story,
limit=limit,
sort_keys=['id'],
marker=marker,
sort_dir='asc')
return query.all()
def _story_get_count_in_project(project_id):
session = get_session()
sub_query = model_query(models.Task.story_id, session) \
.filter_by(project_id=project_id, is_active=True) \
.distinct(True) \
.subquery()
query = model_query(models.Story, session) \
.filter_by(is_active=True) \
.join(sub_query, models.Story.tasks)
return query.count()
def story_create(values):
return _entity_create(models.Story, values)
def story_update(story_id, values):
return _entity_update(models.Story, story_id, values)
def story_delete(story_id):
story = story_get(story_id)
if story:
story.is_active = False
_entity_update(models.Story, story_id, story.as_dict())
# BEGIN Comments
def comment_get(comment_id):
return _entity_get(models.Comment, comment_id)
def comment_get_all(story_id=None):
return _entity_get_all(models.Comment, story_id=story_id, is_active=True)
def comment_create(values):
return _entity_create(models.Comment, values)
def comment_update(comment_id, values):
return _entity_update(models.Comment, comment_id, values)
def comment_delete(comment_id):
comment = comment_get(comment_id)
if comment:
comment.is_active = False
_entity_update(models.Task, comment_id, comment.as_dict())
# BEGIN Tasks
def task_get(task_id):
return _entity_get(models.Task, task_id)
def task_get_all(marker=None, limit=None, story_id=None):
return _entity_get_all(models.Task,
marker=marker,
limit=limit,
story_id=story_id,
is_active=True)
def task_get_count(story_id=None):
return _entity_get_count(models.Task, story_id=story_id, is_active=True)
def task_create(values):
return _entity_create(models.Task, values)
def task_update(task_id, values):
return _entity_update(models.Task, task_id, values)
def task_delete(task_id):
task = task_get(task_id)
if task:
task.is_active = False
_entity_update(models.Task, task_id, task.as_dict())
# BEGIN authorization api
def authorization_code_get(code):
query = model_query(models.AuthorizationCode, get_session())
return query.filter_by(code=code, is_active=True).first()
def authorization_code_save(values):
return _entity_create(models.AuthorizationCode, values)
def authorization_code_delete(code):
del_code = authorization_code_get(code)
if del_code:
del_code.is_active = False
_entity_update(models.AuthorizationCode, del_code.id,
del_code.as_dict())
def token_get(access_token):
query = model_query(models.BearerToken, get_session())
# Note: is_active filtering for a reason, because we may still need to
# fetch expired token, for example to check refresh token info
return query.filter_by(access_token=access_token).first()
def token_save(values):
return _entity_create(models.BearerToken, values)
def token_update(access_token, values):
upd_token = token_get(access_token)
if upd_token:
return _entity_update(models.BearerToken, upd_token.id, values)
def token_delete(access_token):
del_token = token_get(access_token)
if del_token:
del_token.is_active = False
_entity_update(models.BearerToken, del_token.id,
del_token.as_dict())

View File

63
storyboard/db/api/auth.py Normal file
View File

@ -0,0 +1,63 @@
# Copyright (c) 2014 Mirantis Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from storyboard.db.api import base as api_base
from storyboard.db import models
def authorization_code_get(code):
query = api_base.model_query(models.AuthorizationCode,
api_base.get_session())
return query.filter_by(code=code, is_active=True).first()
def authorization_code_save(values):
return api_base.entity_create(models.AuthorizationCode, values)
def authorization_code_delete(code):
del_code = authorization_code_get(code)
if del_code:
del_code.is_active = False
api_base.entity_update(models.AuthorizationCode, del_code.id,
del_code.as_dict())
def token_get(access_token):
query = api_base.model_query(models.BearerToken, api_base.get_session())
# Note: is_active filtering for a reason, because we may still need to
# fetch expired token, for example to check refresh token info
return query.filter_by(access_token=access_token).first()
def token_save(values):
return api_base.entity_create(models.BearerToken, values)
def token_update(access_token, values):
upd_token = token_get(access_token)
if upd_token:
return api_base.entity_update(models.BearerToken, upd_token.id, values)
def token_delete(access_token):
del_token = token_get(access_token)
if del_token:
del_token.is_active = False
api_base.entity_update(models.BearerToken, del_token.id,
del_token.as_dict())

191
storyboard/db/api/base.py Normal file
View File

@ -0,0 +1,191 @@
# Copyright (c) 2014 Mirantis Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from oslo.config import cfg
import six
from storyboard.common import exception as exc
from storyboard.db import models
from storyboard.openstack.common.db import exception as db_exc
from storyboard.openstack.common.db.sqlalchemy import session as db_session
from storyboard.openstack.common.db.sqlalchemy.utils import paginate_query
from storyboard.openstack.common import log
CONF = cfg.CONF
CONF.import_group("database", "storyboard.openstack.common.db.options")
LOG = log.getLogger(__name__)
_FACADE = None
BASE = models.Base
def setup_db():
try:
engine = get_engine()
BASE.metadata.create_all(engine)
except Exception as e:
LOG.exception("Database registration exception: %s", e)
return False
return True
def drop_db():
try:
BASE.metadata.drop_all(get_engine())
except Exception as e:
LOG.exception("Database shutdown exception: %s", e)
return False
return True
def _get_facade_instance():
"""Generate an instance of the DB Facade.
"""
global _FACADE
if _FACADE is None:
_FACADE = db_session.EngineFacade(
CONF.database.connection,
**dict(CONF.database.iteritems()))
return _FACADE
def _destroy_facade_instance():
"""Destroys the db facade instance currently in use.
"""
global _FACADE
_FACADE = None
def get_engine():
"""Returns the global instance of our database engine.
"""
facade = _get_facade_instance()
return facade.get_engine()
def get_session(autocommit=True, expire_on_commit=False):
"""Returns a database session from our facade.
"""
facade = _get_facade_instance()
return facade.get_session(autocommit=autocommit,
expire_on_commit=expire_on_commit)
def cleanup():
"""Manually clean up our database engine.
"""
_destroy_facade_instance()
def model_query(model, session=None):
"""Query helper.
:param model: base model to query
"""
session = session or get_session()
query = session.query(model)
return query
def __entity_get(kls, entity_id, session):
query = model_query(kls, session)
return query.filter_by(id=entity_id, is_active=True).first()
def entity_get(kls, entity_id, filter_non_public=False):
entity = __entity_get(kls, entity_id, get_session())
if filter_non_public:
entity = _filter_non_public_fields(entity, entity._public_fields)
return entity
def entity_get_all(kls, filter_non_public=False, marker=None, limit=None,
**kwargs):
# Sanity check on input parameters
kwargs = dict((k, v) for k, v in kwargs.iteritems() if v)
# Construct the query
query = model_query(kls).filter_by(**kwargs)
query = paginate_query(query=query,
model=kls,
limit=limit,
sort_keys=['id'],
marker=marker,
sort_dir='asc')
# Execute the query
entities = query.all()
if len(entities) > 0 and filter_non_public:
sample_entity = entities[0] if len(entities) > 0 else None
public_fields = getattr(sample_entity, "_public_fields", [])
entities = [_filter_non_public_fields(entity, public_fields)
for entity in entities]
return entities
def entity_get_count(kls, **kwargs):
kwargs = dict((k, v) for k, v in kwargs.iteritems() if v)
count = model_query(kls).filter_by(**kwargs).count()
return count
def _filter_non_public_fields(entity, public_list=list()):
ent_copy = copy.copy(entity)
for attr_name, val in six.iteritems(entity.__dict__):
if attr_name.startswith("_"):
continue
if attr_name not in public_list:
delattr(ent_copy, attr_name)
return ent_copy
def entity_create(kls, values):
entity = kls()
entity.update(values.copy())
session = get_session()
with session.begin():
try:
session.add(entity)
except db_exc.DBDuplicateEntry:
raise exc.DuplicateEntry("Duplicate entry for : %s"
% kls.__name__)
return entity
def entity_update(kls, entity_id, values):
session = get_session()
with session.begin():
entity = __entity_get(kls, entity_id, session)
if entity is None:
raise exc.NotFound("%s %s not found" % (kls.__name__, entity_id))
entity.update(values.copy())
session.add(entity)
return entity

View File

@ -0,0 +1,42 @@
# Copyright (c) 2014 Mirantis Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from storyboard.db.api import base as api_base
from storyboard.db import models
def comment_get(comment_id):
return api_base.entity_get(models.Comment, comment_id)
def comment_get_all(story_id=None):
return api_base.entity_get_all(models.Comment, story_id=story_id,
is_active=True)
def comment_create(values):
return api_base.entity_create(models.Comment, values)
def comment_update(comment_id, values):
return api_base.entity_update(models.Comment, comment_id, values)
def comment_delete(comment_id):
comment = comment_get(comment_id)
if comment:
comment.is_active = False
api_base.entity_update(models.Task, comment_id, comment.as_dict())

View File

@ -0,0 +1,49 @@
# Copyright (c) 2014 Mirantis Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from storyboard.db.api import base as api_base
from storyboard.db import models
def project_get(project_id):
return api_base.entity_get(models.Project, project_id)
def project_get_all(marker=None, limit=None, **kwargs):
return api_base.entity_get_all(models.Project,
is_active=True,
marker=marker,
limit=limit,
**kwargs)
def project_get_count(**kwargs):
return api_base.entity_get_count(models.Project, is_active=True, **kwargs)
def project_create(values):
return api_base.entity_create(models.Project, values)
def project_update(project_id, values):
return api_base.entity_update(models.Project, project_id, values)
def project_delete(project_id):
project = project_get(project_id)
if project:
project.is_active = False
api_base.entity_update(models.Project, project_id, project.as_dict())

View File

@ -0,0 +1,92 @@
# Copyright (c) 2014 Mirantis Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from storyboard.db.api import base as api_base
from storyboard.db import models
def story_get(story_id):
return api_base.entity_get(models.Story, story_id)
def story_get_all(marker=None, limit=None, project_id=None):
if project_id:
return _story_get_all_in_project(marker=marker,
limit=limit,
project_id=project_id)
else:
return api_base.entity_get_all(models.Story, is_active=True,
marker=marker, limit=limit)
def story_get_count(project_id=None):
if project_id:
return _story_get_count_in_project(project_id)
else:
return api_base.entity_get_count(models.Story, is_active=True)
def _story_get_all_in_project(project_id, marker=None, limit=None):
session = api_base.get_session()
sub_query = api_base.model_query(models.Task.story_id, session) \
.filter_by(project_id=project_id, is_active=True) \
.distinct(True) \
.subquery()
query = api_base.model_query(models.Story, session) \
.filter_by(is_active=True) \
.join(sub_query, models.Story.tasks)
query = api_base.paginate_query(query=query,
model=models.Story,
limit=limit,
sort_keys=['id'],
marker=marker,
sort_dir='asc')
return query.all()
def _story_get_count_in_project(project_id):
session = api_base.get_session()
sub_query = api_base.model_query(models.Task.story_id, session) \
.filter_by(project_id=project_id, is_active=True) \
.distinct(True) \
.subquery()
query = api_base.model_query(models.Story, session) \
.filter_by(is_active=True) \
.join(sub_query, models.Story.tasks)
return query.count()
def story_create(values):
return api_base.entity_create(models.Story, values)
def story_update(story_id, values):
return api_base.entity_update(models.Story, story_id, values)
def story_delete(story_id):
story = story_get(story_id)
if story:
story.is_active = False
api_base.entity_update(models.Story, story_id, story.as_dict())

View File

@ -0,0 +1,50 @@
# Copyright (c) 2014 Mirantis Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from storyboard.db.api import base as api_base
from storyboard.db import models
def task_get(task_id):
return api_base.entity_get(models.Task, task_id)
def task_get_all(marker=None, limit=None, story_id=None):
return api_base.entity_get_all(models.Task,
marker=marker,
limit=limit,
story_id=story_id,
is_active=True)
def task_get_count(story_id=None):
return api_base.entity_get_count(models.Task, story_id=story_id,
is_active=True)
def task_create(values):
return api_base.entity_create(models.Task, values)
def task_update(task_id, values):
return api_base.entity_update(models.Task, task_id, values)
def task_delete(task_id):
task = task_get(task_id)
if task:
task.is_active = False
api_base.entity_update(models.Task, task_id, task.as_dict())

View File

@ -0,0 +1,61 @@
# Copyright (c) 2014 Mirantis Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from storyboard.common import exception as exc
from storyboard.db.api import base as api_base
from storyboard.db import models
from storyboard.openstack.common.db import exception as db_exc
def user_get(user_id, filter_non_public=False):
entity = api_base.entity_get(models.User, user_id,
filter_non_public=filter_non_public)
return entity
def user_get_all(marker=None, limit=None, filter_non_public=False):
return api_base.entity_get_all(models.User,
marker=marker,
limit=limit,
filter_non_public=filter_non_public)
def user_get_count():
return api_base.entity_get_count(models.User)
def user_get_by_openid(openid):
query = api_base.model_query(models.User, api_base.get_session())
return query.filter_by(openid=openid).first()
def user_create(values):
user = models.User()
user.update(values.copy())
session = api_base.get_session()
with session.begin():
try:
user.save(session=session)
except db_exc.DBDuplicateEntry as e:
raise exc.DuplicateEntry("Duplicate entry for User: %s"
% e.columns)
return user
def user_update(user_id, values):
return api_base.entity_update(models.User, user_id, values)

View File

@ -19,7 +19,7 @@ import yaml
from oslo.config import cfg
from sqlalchemy.exc import SADeprecationWarning
from storyboard.db import api as db_api
from storyboard.db.api import base as db_api
from storyboard.db.models import Project
from storyboard.db.models import ProjectGroup

View File

@ -18,7 +18,7 @@ import yaml
from sqlalchemy.exc import SADeprecationWarning
from storyboard.db import api as db_api
from storyboard.db.api import base as db_api
from storyboard.db.models import User
warnings.simplefilter("ignore", SADeprecationWarning)

View File

@ -24,7 +24,7 @@ import pecan.testing
import testtools
from storyboard.api.auth import authorization_checks
from storyboard.db import api as db_api
from storyboard.db.api import base as db_api_base
from storyboard.openstack.common import lockutils
from storyboard.openstack.common import log as logging
@ -95,12 +95,12 @@ class DbTestCase(TestCase):
def setup_db(self):
CONF.set_default('connection', "sqlite://", group='database')
db_api.setup_db()
db_api_base.setup_db()
self.addCleanup(self._drop_db)
def _drop_db(self):
db_api.drop_db()
db_api.cleanup()
db_api_base.drop_db()
db_api_base.cleanup()
PATH_PREFIX = '/v1'

View File

@ -15,7 +15,11 @@
from datetime import datetime
from storyboard.db import api as db_api
from storyboard.db.api import auth
from storyboard.db.api import comments
from storyboard.db.api import projects
from storyboard.db.api import stories
from storyboard.db.api import tasks
from storyboard.tests import base
@ -52,7 +56,7 @@ class ProjectsTest(BaseDbTestCase):
}
def test_save_project(self):
self._test_create(self.project_01, db_api.project_create)
self._test_create(self.project_01, projects.project_create)
def test_update_project(self):
delta = {
@ -60,7 +64,7 @@ class ProjectsTest(BaseDbTestCase):
'description': u'New Description'
}
self._test_update(self.project_01, delta,
db_api.project_create, db_api.project_update)
projects.project_create, projects.project_update)
class StoriesTest(BaseDbTestCase):
@ -74,7 +78,7 @@ class StoriesTest(BaseDbTestCase):
}
def test_create_story(self):
self._test_create(self.story_01, db_api.story_create)
self._test_create(self.story_01, stories.story_create)
def test_update_story(self):
delta = {
@ -82,7 +86,7 @@ class StoriesTest(BaseDbTestCase):
'description': u'New Description'
}
self._test_update(self.story_01, delta,
db_api.story_create, db_api.story_update)
stories.story_create, stories.story_update)
class TasksTest(BaseDbTestCase):
@ -97,7 +101,7 @@ class TasksTest(BaseDbTestCase):
}
def test_create_task(self):
self._test_create(self.task_01, db_api.task_create)
self._test_create(self.task_01, tasks.task_create)
def test_update_task(self):
delta = {
@ -106,7 +110,7 @@ class TasksTest(BaseDbTestCase):
}
self._test_update(self.task_01, delta,
db_api.task_create, db_api.task_update)
tasks.task_create, tasks.task_update)
class CommentsTest(BaseDbTestCase):
@ -120,7 +124,7 @@ class CommentsTest(BaseDbTestCase):
}
def test_create_comment(self):
self._test_create(self.comment_01, db_api.comment_create)
self._test_create(self.comment_01, comments.comment_create)
def test_update_comment(self):
delta = {
@ -128,7 +132,7 @@ class CommentsTest(BaseDbTestCase):
}
self._test_update(self.comment_01, delta,
db_api.comment_create, db_api.comment_update)
comments.comment_create, comments.comment_update)
class AuthorizationCodeTest(BaseDbTestCase):
@ -143,17 +147,17 @@ class AuthorizationCodeTest(BaseDbTestCase):
}
def test_create_code(self):
self._test_create(self.code_01, db_api.authorization_code_save)
self._test_create(self.code_01, auth.authorization_code_save)
def test_delete_code(self):
created_code = db_api.authorization_code_save(self.code_01)
created_code = auth.authorization_code_save(self.code_01)
self.assertIsNotNone(created_code,
"Could not create an Authorization code")
db_api.authorization_code_delete(created_code.code)
auth.authorization_code_delete(created_code.code)
fetched_code = db_api.authorization_code_get(created_code.code)
fetched_code = auth.authorization_code_get(created_code.code)
self.assertIsNone(fetched_code)
@ -171,16 +175,16 @@ class TokenTest(BaseDbTestCase):
}
def test_create_token(self):
self._test_create(self.token_01, db_api.token_save)
self._test_create(self.token_01, auth.token_save)
def test_delete_token(self):
created_token = db_api.token_save(self.token_01)
created_token = auth.token_save(self.token_01)
self.assertIsNotNone(created_token, "Could not create a Token")
db_api.token_delete(created_token.access_token)
auth.token_delete(created_token.access_token)
fetched_token = db_api.token_get(created_token.access_token)
fetched_token = auth.token_get(created_token.access_token)
self.assertIsNotNone(fetched_token,
"Could not fetch a non-active Token")
self.assertFalse(fetched_token.is_active,

View File

@ -16,9 +16,9 @@
import sys
import mock
from storyboard.db import api
import testscenarios
from storyboard.db.api import base as db_api_base
from storyboard.db.migration import cli
from storyboard.db import models
from storyboard.tests import base
@ -36,7 +36,7 @@ class TestLoadProjects(base.FunctionalTest):
def test_cli(self):
with mock.patch.object(sys, 'argv', self.argv):
cli.main()
session = api.get_session()
session = db_api_base.get_session()
project_groups = session.query(models.ProjectGroup).all()
projects = session.query(models.Project).all()
@ -53,7 +53,7 @@ class TestLoadProjects(base.FunctionalTest):
# call again and nothing should change
cli.main()
session = api.get_session()
session = db_api_base.get_session()
projects = session.query(models.Project).all()
self.assertIsNotNone(projects)