diff --git a/stacktach/db.py b/stacktach/db.py index d28d723..9dd525f 100644 --- a/stacktach/db.py +++ b/stacktach/db.py @@ -16,6 +16,10 @@ def _safe_get(Model, **kwargs): return object +def get_deployment(id): + return _safe_get(models.Deployment, id=id) + + def get_or_create_deployment(name): return models.Deployment.objects.get_or_create(name=name) @@ -132,4 +136,4 @@ def get_image_delete(**kwargs): def get_image_usage(**kwargs): - return _safe_get(models.ImageUsage, **kwargs) \ No newline at end of file + return _safe_get(models.ImageUsage, **kwargs) diff --git a/tests/unit/test_worker.py b/tests/unit/test_worker.py index a03a080..4e594ee 100644 --- a/tests/unit/test_worker.py +++ b/tests/unit/test_worker.py @@ -168,10 +168,10 @@ class ConsumerTestCase(StacktachBaseTestCase): "services": ["nova"], "topics": {"nova": self._test_topics()} } - self.mox.StubOutWithMock(db, 'get_or_create_deployment') + self.mox.StubOutWithMock(db, 'get_deployment') deployment = self.mox.CreateMockAnything() - db.get_or_create_deployment(config['name'])\ - .AndReturn((deployment, True)) + deployment.id = 1 + db.get_deployment(deployment.id).AndReturn(deployment) self.mox.StubOutWithMock(kombu.connection, 'BrokerConnection') params = dict(hostname=config['rabbit_host'], port=config['rabbit_port'], @@ -193,7 +193,7 @@ class ConsumerTestCase(StacktachBaseTestCase): consumer.run() worker.continue_running().AndReturn(False) self.mox.ReplayAll() - worker.run(config, exchange) + worker.run(config, deployment.id, exchange) self.mox.VerifyAll() def test_run_queue_args(self): @@ -210,10 +210,10 @@ class ConsumerTestCase(StacktachBaseTestCase): "services": ["nova"], "topics": {"nova": self._test_topics()} } - self.mox.StubOutWithMock(db, 'get_or_create_deployment') + self.mox.StubOutWithMock(db, 'get_deployment') deployment = self.mox.CreateMockAnything() - db.get_or_create_deployment(config['name'])\ - .AndReturn((deployment, True)) + deployment.id = 1 + db.get_deployment(deployment.id).AndReturn(deployment) self.mox.StubOutWithMock(kombu.connection, 'BrokerConnection') params = dict(hostname=config['rabbit_host'], port=config['rabbit_port'], @@ -236,5 +236,5 @@ class ConsumerTestCase(StacktachBaseTestCase): consumer.run() worker.continue_running().AndReturn(False) self.mox.ReplayAll() - worker.run(config, exchange) + worker.run(config, deployment.id, exchange) self.mox.VerifyAll() diff --git a/worker/start_workers.py b/worker/start_workers.py index 0558b57..a6a5e20 100644 --- a/worker/start_workers.py +++ b/worker/start_workers.py @@ -9,6 +9,9 @@ POSSIBLE_TOPDIR = os.path.normpath(os.path.join(os.path.abspath(sys.argv[0]), if os.path.exists(os.path.join(POSSIBLE_TOPDIR, 'stacktach')): sys.path.insert(0, POSSIBLE_TOPDIR) +from stacktach import db +from django.db import close_connection + import worker.worker as worker from worker import config @@ -30,8 +33,15 @@ if __name__ == '__main__': for deployment in config.deployments(): if deployment.get('enabled', True): + db_deployment, new = db.get_or_create_deployment(deployment['name']) + # NOTE (apmelton) + # Close the connection before spinning up the child process, + # otherwise the child process will attempt to use the connection + # the parent process opened up to get/create the deployment. + close_connection() for exchange in deployment.get('topics').keys(): process = Process(target=worker.run, args=(deployment, + db_deployment.id, exchange,)) process.daemon = True process.start() diff --git a/worker/worker.py b/worker/worker.py index 1d20da8..e646788 100644 --- a/worker/worker.py +++ b/worker/worker.py @@ -142,7 +142,7 @@ def exit_or_sleep(exit=False): time.sleep(5) -def run(deployment_config, exchange): +def run(deployment_config, deployment_id, exchange): name = deployment_config['name'] host = deployment_config.get('rabbit_host', 'localhost') port = deployment_config.get('rabbit_port', 5672) @@ -154,7 +154,7 @@ def run(deployment_config, exchange): exit_on_exception = deployment_config.get('exit_on_exception', False) topics = deployment_config.get('topics', {}) - deployment, new = db.get_or_create_deployment(name) + deployment = db.get_deployment(deployment_id) print "Starting worker for '%s %s'" % (name, exchange) LOG.info("%s: %s %s %s %s %s" % (name, exchange, host, port, user_id,