diff --git a/storyboard/plugin/event_worker.py b/storyboard/plugin/event_worker.py index 663cb627..e211dc6c 100644 --- a/storyboard/plugin/event_worker.py +++ b/storyboard/plugin/event_worker.py @@ -13,16 +13,15 @@ # limitations under the License. import abc -import signal - from multiprocessing import Process -from oslo_log import log +import signal from threading import Timer from oslo.config import cfg +from oslo_log import log import storyboard.db.api.base as db_api -import storyboard.db.models as models +from storyboard.notifications.notification_hook import class_mappings from storyboard.notifications.subscriber import subscribe from storyboard.openstack.common.gettextutils import _LI, _LW # noqa from storyboard.plugin.base import PluginBase @@ -172,10 +171,7 @@ class WorkerTaskBase(PluginBase): session = db_api.get_session(in_request=False, autocommit=False) with session.begin(subtransactions=True): - - author = db_api.entity_get(models.User, - author_id, - session=session) + author = self.resolve_resource_by_name(session, 'user', author_id) self.handle(session=session, author=author, @@ -189,6 +185,13 @@ class WorkerTaskBase(PluginBase): resource_before=resource_before, resource_after=resource_after) + def resolve_resource_by_name(self, session, resource_name, resource_id): + if resource_name not in class_mappings: + return None + + klass = class_mappings[resource_name][0] + return db_api.entity_get(klass, resource_id, session=session) + @abc.abstractmethod def handle(self, session, author, method, path, status, resource, resource_id, sub_resource=None, sub_resource_id=None, diff --git a/storyboard/tests/mock_data.py b/storyboard/tests/mock_data.py index 446dbbca..00952e55 100644 --- a/storyboard/tests/mock_data.py +++ b/storyboard/tests/mock_data.py @@ -26,6 +26,7 @@ from storyboard.db.models import ProjectGroup from storyboard.db.models import Story from storyboard.db.models import Subscription from storyboard.db.models import Task +from storyboard.db.models import Team from storyboard.db.models import TimeLineEvent from storyboard.db.models import User @@ -319,15 +320,29 @@ def load(): # Load some milestones load_data([ Milestone( + id=1, name='test_milestone_01', branch_id=1 ), Milestone( + id=2, name='test_milestone_02', branch_id=2 ) ]) + # Load some teams + load_data([ + Team( + id=1, + name='test_team_1' + ), + Team( + id=2, + name='test_team_2' + ) + ]) + def load_data(data): """Pre load test data into the database. diff --git a/storyboard/tests/plugin/test_event_worker.py b/storyboard/tests/plugin/test_event_worker.py new file mode 100644 index 00000000..a183852f --- /dev/null +++ b/storyboard/tests/plugin/test_event_worker.py @@ -0,0 +1,74 @@ +# Copyright (c) 2014 Hewlett-Packard Development Company, L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing permissions and +# limitations under the License. + +import storyboard.db.api.base as db_api_base +import storyboard.plugin.event_worker as plugin_base +import storyboard.tests.base as base + + +class TestWorkerTaskBase(base.FunctionalTest): + def setUp(self): + super(TestWorkerTaskBase, self).setUp() + + def test_resolve_by_name(self): + '''Assert that resolve_resource_by_name works.''' + + worker = TestWorkerPlugin({}) + + with base.HybridSessionManager(): + session = db_api_base.get_session() + + task = worker.resolve_resource_by_name(session, 'task', 1) + self.assertIsNotNone(task) + self.assertEqual(1, task.id) + + project_group = worker.resolve_resource_by_name(session, + 'project_group', 1) + self.assertIsNotNone(project_group) + self.assertEqual(1, project_group.id) + + project = worker.resolve_resource_by_name(session, 'project', 1) + self.assertIsNotNone(project) + self.assertEqual(1, project.id) + + user = worker.resolve_resource_by_name(session, 'user', 1) + self.assertIsNotNone(user) + self.assertEqual(1, user.id) + + team = worker.resolve_resource_by_name(session, 'team', 1) + self.assertIsNotNone(team) + self.assertEqual(1, team.id) + + story = worker.resolve_resource_by_name(session, 'story', 1) + self.assertIsNotNone(story) + self.assertEqual(1, story.id) + + branch = worker.resolve_resource_by_name(session, 'branch', 1) + self.assertIsNotNone(branch) + self.assertEqual(1, branch.id) + + milestone = worker.resolve_resource_by_name(session, + 'milestone', 1) + self.assertIsNotNone(milestone) + self.assertEqual(1, milestone.id) + + +class TestWorkerPlugin(plugin_base.WorkerTaskBase): + def handle(self, session, author, method, path, status, resource, + resource_id, sub_resource=None, sub_resource_id=None, + resource_before=None, resource_after=None): + pass + + def enabled(self): + return True