commit 0084e5d1244e9d1c7a47dfdc8a0f3adf497ec96f Author: John Bresnahan Date: Wed May 15 11:40:57 2013 -1000 Initial Commit. WSGI is not yet in place. diff --git a/HACKING.rst b/HACKING.rst new file mode 100644 index 0000000..f37ffb3 --- /dev/null +++ b/HACKING.rst @@ -0,0 +1,256 @@ +Staccato Style Commandments +=========================== + +- Step 1: Read http://www.python.org/dev/peps/pep-0008/ +- Step 2: Read http://www.python.org/dev/peps/pep-0008/ again +- Step 3: Read on + + +General +------- +- Put two newlines between top-level code (funcs, classes, etc) +- Put one newline between methods in classes and anywhere else +- Do not write "except:", use "except Exception:" at the very least +- Include your name with TODOs as in "#TODO(termie)" +- Do not name anything the same name as a built-in or reserved word +- Use the "is not" operator when testing for unequal identities. Example:: + + if not X is Y: # BAD, intended behavior is ambiguous + pass + + if X is not Y: # OKAY, intuitive + pass + +- Use the "not in" operator for evaluating membership in a collection. Example:: + + if not X in Y: # BAD, intended behavior is ambiguous + pass + + if X not in Y: # OKAY, intuitive + pass + + if not (X in Y or X in Z): # OKAY, still better than all those 'not's + pass + + +Imports +------- +- Do not make relative imports +- Order your imports by the full module path +- Organize your imports according to the following template + +Example:: + + # vim: tabstop=4 shiftwidth=4 softtabstop=4 + {{stdlib imports in human alphabetical order}} + \n + {{third-party lib imports in human alphabetical order}} + \n + {{staccato imports in human alphabetical order}} + \n + \n + {{begin your code}} + + +Human Alphabetical Order Examples +--------------------------------- +Example:: + + import httplib + import logging + import random + import StringIO + import time + import unittest + + import eventlet + import webob.exc + + import staccato.api.middleware + from staccato.api import images + from staccato.auth import users + import staccato.common + from staccato.endpoint import cloud + from staccato import test + + +Docstrings +---------- + +Docstrings are required for all functions and methods. + +Docstrings should ONLY use triple-double-quotes (``"""``) + +Single-line docstrings should NEVER have extraneous whitespace +between enclosing triple-double-quotes. + +**INCORRECT** :: + + """ There is some whitespace between the enclosing quotes :( """ + +**CORRECT** :: + + """There is no whitespace between the enclosing quotes :)""" + +Docstrings that span more than one line should look like this: + +Example:: + + """ + Start the docstring on the line following the opening triple-double-quote + + If you are going to describe parameters and return values, use Sphinx, the + appropriate syntax is as follows. + + :param foo: the foo parameter + :param bar: the bar parameter + :returns: return_type -- description of the return value + :returns: description of the return value + :raises: AttributeError, KeyError + """ + +**DO NOT** leave an extra newline before the closing triple-double-quote. + + +Dictionaries/Lists +------------------ +If a dictionary (dict) or list object is longer than 80 characters, its items +should be split with newlines. Embedded iterables should have their items +indented. Additionally, the last item in the dictionary should have a trailing +comma. This increases readability and simplifies future diffs. + +Example:: + + my_dictionary = { + "image": { + "name": "Just a Snapshot", + "size": 2749573, + "properties": { + "user_id": 12, + "arch": "x86_64", + }, + "things": [ + "thing_one", + "thing_two", + ], + "status": "ACTIVE", + }, + } + + +Calling Methods +--------------- +Calls to methods 80 characters or longer should format each argument with +newlines. This is not a requirement, but a guideline:: + + unnecessarily_long_function_name('string one', + 'string two', + kwarg1=constants.ACTIVE, + kwarg2=['a', 'b', 'c']) + + +Rather than constructing parameters inline, it is better to break things up:: + + list_of_strings = [ + 'what_a_long_string', + 'not as long', + ] + + dict_of_numbers = { + 'one': 1, + 'two': 2, + 'twenty four': 24, + } + + object_one.call_a_method('string three', + 'string four', + kwarg1=list_of_strings, + kwarg2=dict_of_numbers) + + +Internationalization (i18n) Strings +----------------------------------- +In order to support multiple languages, we have a mechanism to support +automatic translations of exception and log strings. + +Example:: + + msg = _("An error occurred") + raise HTTPBadRequest(explanation=msg) + +If you have a variable to place within the string, first internationalize the +template string then do the replacement. + +Example:: + + msg = _("Missing parameter: %s") % ("flavor",) + LOG.error(msg) + +If you have multiple variables to place in the string, use keyword parameters. +This helps our translators reorder parameters when needed. + +Example:: + + msg = _("The server with id %(s_id)s has no key %(m_key)s") + LOG.error(msg % {"s_id": "1234", "m_key": "imageId"}) + + +Creating Unit Tests +------------------- +For every new feature, unit tests should be created that both test and +(implicitly) document the usage of said feature. If submitting a patch for a +bug that had no unit test, a new passing unit test should be added. If a +submitted bug fix does have a unit test, be sure to add a new one that fails +without the patch and passes with the patch. + + +Commit Messages +--------------- +Using a common format for commit messages will help keep our git history +readable. Follow these guidelines: + + First, provide a brief summary of 50 characters or less. Summaries + of greater then 72 characters will be rejected by the gate. + + The first line of the commit message should provide an accurate + description of the change, not just a reference to a bug or + blueprint. It must be followed by a single blank line. + + Following your brief summary, provide a more detailed description of + the patch, manually wrapping the text at 72 characters. This + description should provide enough detail that one does not have to + refer to external resources to determine its high-level functionality. + + Once you use 'git review', two lines will be appended to the commit + message: a blank line followed by a 'Change-Id'. This is important + to correlate this commit with a specific review in Gerrit, and it + should not be modified. + +For further information on constructing high quality commit messages, +and how to split up commits into a series of changes, consult the +project wiki: + + http://wiki.openstack.org/GitCommitMessages + + +openstack-common +---------------- + +A number of modules from openstack-common are imported into the project. + +These modules are "incubating" in openstack-common and are kept in sync +with the help of openstack-common's update.py script. See: + + http://wiki.openstack.org/CommonLibrary#Incubation + +The copy of the code should never be directly modified here. Please +always update openstack-common first and then run the script to copy +the changes across. + + +Logging +------- +Use __name__ as the name of your logger and name your module-level logger +objects 'LOG':: + + LOG = logging.getLogger(__name__) diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..68c771a --- /dev/null +++ b/LICENSE @@ -0,0 +1,176 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + diff --git a/README.rst b/README.rst new file mode 100644 index 0000000..89f13a1 --- /dev/null +++ b/README.rst @@ -0,0 +1,16 @@ +======== +Staccato +======== + +Staccato is the name given to lightning which appears as a single very +bright, short-duration stroke that often has considerable branching. +This service is intended to transfer images from cloud to cloud quickly +and with the option of multicast. + +Staccato is a data transfer service. Its primary purpose is to transfer +images from storage repository to compute node, for booting, or storage +to storage. + + +* `Official Staccato documentation <> + diff --git a/etc/staccato-api-paste.ini b/etc/staccato-api-paste.ini new file mode 100644 index 0000000..4e5a08c --- /dev/null +++ b/etc/staccato-api-paste.ini @@ -0,0 +1,17 @@ +[pipeline:staccato-api] +pipeline = unauthenticated-context rootapp + +[app:rootapp] +use = egg:Paste#urlmap +/: apiversions +/v1: apiv1app + +[app:apiversions] +paste.app_factory = staccato.api.versions:create_resource + +[app:apiv1app] +paste.app_factory = staccato.api.v1.xfer:create_resource + + +[filter:unauthenticated-context] +paste.filter_factory = staccato.wsgi:Middleware.factory diff --git a/etc/staccato-api.conf b/etc/staccato-api.conf new file mode 100644 index 0000000..e01a843 --- /dev/null +++ b/etc/staccato-api.conf @@ -0,0 +1,62 @@ +[DEFAULT] +# Show more verbose log output (sets INFO log level output) +#verbose = False + +# Show debugging output in logs (sets DEBUG log level output) +#debug = False + +#known_protocols = staccato.protocol.file, +# staccato.protocol.http, + +# Address to bind the API server +bind_host = 0.0.0.0 + +# Port the bind the API server to +bind_port = 9292 + +# Log to this file. Make sure you do not set the same log +# file for both the API and registry servers! +log_file = /var/log/staccato/api.log + +# Backlog requests when creating socket +backlog = 4096 + +# SQLAlchemy connection string for the reference implementation +# registry server. Any valid SQLAlchemy connection string is fine. +# See: http://www.sqlalchemy.org/docs/05/reference/sqlalchemy/connections.html#sqlalchemy.create_engine +sql_connection = sqlite:///glance.sqlite + +# Period in seconds after which SQLAlchemy should reestablish its connection +# to the database. +# +# MySQL uses a default `wait_timeout` of 8 hours, after which it will drop +# idle connections. This can result in 'MySQL Gone Away' exceptions. If you +# notice this, you can lower this value to ensure that SQLAlchemy reconnects +# before MySQL can drop the connection. +sql_idle_timeout = 3600 + +# Role used to identify an authenticated user as administrator +#admin_role = admin + +# Allow unauthenticated users to access the API with read-only +# privileges. This only applies when using ContextMiddleware. +#allow_anonymous_access = False + +# ================= SSL Options =============================== + +# Certificate file to use when starting API server securely +#cert_file = /path/to/certfile + +# Private key file to use when starting API server securely +#key_file = /path/to/keyfile + +# CA certificate file to use to verify connecting clients +#ca_file = /path/to/cafile + +# ================= Security Options ========================== + +# AES key for encrypting store 'location' metadata, including +# -- if used -- Swift or S3 credentials +# Should be set to a random string of length 16, 24 or 32 bytes +#metadata_encryption_key = <16, 24 or 32 char registry metadata key> + diff --git a/etc/staccato-protocols.json b/etc/staccato-protocols.json new file mode 100644 index 0000000..7890de6 --- /dev/null +++ b/etc/staccato-protocols.json @@ -0,0 +1,3 @@ +{ + "file": [{"module": "staccato.protocols.file.FileProtocol"}] +} diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..54ae538 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,33 @@ +[build_sphinx] +all_files = 1 +build-dir = doc/build +source-dir = doc/source + +[egg_info] +tag_build = +tag_date = 0 +tag_svn_revision = 0 + +[compile_catalog] +directory = staccato/locale +domain = staccato + +[update_catalog] +domain = staccato +output_dir = staccato/locale +input_file = staccato/locale/staccato.pot + +[extract_messages] +keywords = _ gettext ngettext l_ lazy_gettext +mapping_file = babel.cfg +output_file = staccato/locale/staccato.pot + +[nosetests] +# NOTE(jkoelker) To run the test suite under nose install the following +# coverage http://pypi.python.org/pypi/coverage +# tissue http://pypi.python.org/pypi/tissue (pep8 checker) +# openstack-nose https://github.com/jkoelker/openstack-nose +verbosity=2 +cover-package = staccato +cover-html = true +cover-erase = true diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..81696fc --- /dev/null +++ b/setup.py @@ -0,0 +1,52 @@ +#!/usr/bin/python +# Copyright (c) 2010 OpenStack, LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import setuptools + +from staccato.openstack.common import setup + +requires = setup.parse_requirements() +depend_links = setup.parse_dependency_links() +project = 'staccato' + +setuptools.setup( + name=project, + version='0.1', + description='The Staccato project provides data transfer services ' + 'to OpenStack services and users. It is primarily used ' + 'for VM image propagation.', + license='Apache License (2.0)', + author='OpenStack', + author_email='openstack@lists.launchpad.net', + url='http://staccato.openstack.org/', + packages=setuptools.find_packages(exclude=['']), + test_suite='nose.collector', + cmdclass=setup.get_cmdclass(), + include_package_data=True, + install_requires=requires, + dependency_links=depend_links, + classifiers=[ + 'Development Status :: 4 - Beta', + 'License :: OSI Approved :: Apache Software License', + 'Operating System :: POSIX :: Linux', + 'Programming Language :: Python :: 2.6', + 'Environment :: No Input/Output (Daemon)', + 'Environment :: OpenStack', + ], + entry_points={'console_scripts': + ['staccato-api=staccato.cmd.api:main', + ]}, + py_modules=[]) diff --git a/staccato/__init__.py b/staccato/__init__.py new file mode 100644 index 0000000..aac579a --- /dev/null +++ b/staccato/__init__.py @@ -0,0 +1,20 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010-2011 OpenStack LLC. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import gettext + +gettext.install('staccato', unicode=1) diff --git a/staccato/api/__init__.py b/staccato/api/__init__.py new file mode 100644 index 0000000..0b23214 --- /dev/null +++ b/staccato/api/__init__.py @@ -0,0 +1,5 @@ +import paste.urlmap + + +def root_app_factory(loader, global_conf, **local_conf): + return paste.urlmap.urlmap_factory(loader, global_conf, **local_conf) diff --git a/staccato/api/v1/__init__.py b/staccato/api/v1/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/staccato/api/v1/router.py b/staccato/api/v1/router.py new file mode 100644 index 0000000..e69de29 diff --git a/staccato/api/v1/xfer.py b/staccato/api/v1/xfer.py new file mode 100644 index 0000000..87025d1 --- /dev/null +++ b/staccato/api/v1/xfer.py @@ -0,0 +1,40 @@ +import httplib +import json + +import webob.dec +import webob.exc + +from staccato.common import wsgi + + +class XferApp(object): + """ + A single WSGI application that just returns version information + """ + def __init__(self, conf): + self.conf = conf + + def xfer(self, req): + required_params = ['srcurl', 'dsturl'] + optional_params = [] + + @webob.dec.wsgify(RequestClass=wsgi.Request) + def __call__(self, req): + version_info = { + 'id': self.conf.id, + 'version': self.conf.version, + 'status': 'active' + } + version_objs = [version_info] + + response = webob.Response(request=req, + status=httplib.MULTIPLE_CHOICES, + content_type='application/json') + response.body = json.dumps(dict(versions=version_objs)) + return response + + +def create_resource(conf): + # TODO: figure out what this has to be this way + config_obj = conf['CONF']['conf'] + return XferApp(conf=config_obj) diff --git a/staccato/api/versions.py b/staccato/api/versions.py new file mode 100644 index 0000000..6518ea2 --- /dev/null +++ b/staccato/api/versions.py @@ -0,0 +1,34 @@ +import httplib +import json +import webob + +from staccato.common import wsgi + + +class VersionApp(object): + """ + A single WSGI application that just returns version information + """ + def __init__(self, conf): + self.conf = conf + + @webob.dec.wsgify(RequestClass=wsgi.Request) + def __call__(self, req): + version_info = { + 'id': self.conf.id, + 'version': self.conf.version, + 'status': 'active' + } + version_objs = [version_info] + + response = webob.Response(request=req, + status=httplib.MULTIPLE_CHOICES, + content_type='application/json') + response.body = json.dumps(dict(versions=version_objs)) + return response + + +def create_resource(conf): + # TODO: figure out what this has to be this way + config_obj = conf['CONF']['conf'] + return VersionApp(conf=config_obj) diff --git a/staccato/cmd/__init__.py b/staccato/cmd/__init__.py new file mode 100644 index 0000000..c65c52d --- /dev/null +++ b/staccato/cmd/__init__.py @@ -0,0 +1,16 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2013 OpenStack LLC. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. diff --git a/staccato/cmd/api.py b/staccato/cmd/api.py new file mode 100755 index 0000000..7308ff6 --- /dev/null +++ b/staccato/cmd/api.py @@ -0,0 +1,47 @@ +import eventlet +import gettext +import sys + +from staccato.common import utils +from staccato import wsgi + +# Monkey patch socket and time +eventlet.patcher.monkey_patch(all=False, socket=True, time=True) + +gettext.install('staccato', unicode=1) + + +def fail(returncode, e): + sys.stderr.write("ERROR: %s\n" % e) + sys.exit(returncode) + + +class CONF(object): + def __init__(self): + self.bind_host = "0.0.0.0" + self.bind_port = 9876 + self.cert_file = None + self.key_file = None + self.backlog = -1 + self.tcp_keepidle = False + self.id = "deadbeef" + self.version = "v1" + + +def main(): + try: + #config.parse_args(sys.argv) + conf = CONF() + wsgi_app = utils.load_paste_app( + 'staccato-api', + '/home/jbresnah/Dev/OpenStack/staccato/etc/glance-api-paste.ini', + {'conf': conf}) + + server = wsgi.Server(CONF=conf) + server.start(wsgi_app, default_port=9292) + server.wait() + except RuntimeError as e: + fail(1, e) + + +main() diff --git a/staccato/common/__init__.py b/staccato/common/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/staccato/common/config.py b/staccato/common/config.py new file mode 100644 index 0000000..7a96135 --- /dev/null +++ b/staccato/common/config.py @@ -0,0 +1,74 @@ +import logging +from oslo.config import cfg +import json + +from staccato.version import version_info as version + +common_opts = [ + cfg.ListOpt('protocol_plugins', + default=['staccato.protocols.file.FileProtocol', + ]), + cfg.StrOpt('sql_connection', + default='sqlite:///staccato.sqlite', + secret=True, + metavar='CONNECTION', + help='A valid SQLAlchemy connection string for the registry ' + 'database. Default: %(default)s'), + cfg.IntOpt('sql_idle_timeout', default=3600, + help=_('Period in seconds after which SQLAlchemy should ' + 'reestablish its connection to the database.')), + cfg.IntOpt('sql_max_retries', default=60, + help=_('The number of times to retry a connection to the SQL' + 'server.')), + cfg.IntOpt('sql_retry_interval', default=1, + help=_('The amount of time to wait (in seconds) before ' + 'attempting to retry the SQL connection.')), + cfg.BoolOpt('db_auto_create', default=False, + help=_('A boolean that determines if the database will be ' + 'automatically created.')), + cfg.BoolOpt('db_auto_create', default=False, + help=_('A boolean that determines if the database will be ' + 'automatically created.')), + + cfg.StrOpt('log_level', + default='INFO', + help='', + dest='str_log_level'), + cfg.StrOpt('protocol_policy', default='staccato-protocols.json', + help=''), +] + + +def _log_string_to_val(conf): + str_lvl = conf.str_log_level.lower() + + val = logging.INFO + if str_lvl == 'error': + val = logging.ERROR + elif str_lvl == 'warn' or str_lvl == 'warning': + val = logging.WARN + elif str_lvl == "DEBUG": + val = logging.DEBUG + setattr(conf, 'log_level', val) + + +def get_config_object(args=None, usage=None, default_config_files=None): + conf = cfg.ConfigOpts() + conf.register_opts(common_opts) + conf(args=args, + project='staccato', + version=version.cached_version_string(), + usage=usage, + default_config_files=default_config_files) + _log_string_to_val(conf) + + return conf + + +def get_protocol_policy(conf): + protocol_conf_file = conf.protocol_policy + if protocol_conf_file is None: + # TODO log a warning + return {} + policy = json.load(open(protocol_conf_file, 'r')) + return policy diff --git a/staccato/common/exceptions.py b/staccato/common/exceptions.py new file mode 100644 index 0000000..270d992 --- /dev/null +++ b/staccato/common/exceptions.py @@ -0,0 +1,36 @@ + + +class StaccatoBaseException(Exception): + pass + + +class StaccatoNotImplementedException(StaccatoBaseException): + pass + + +class StaccatoProtocolConnectionException(StaccatoBaseException): + pass + + +class StaccatoCancelException(StaccatoBaseException): + pass + + +class StaccatoIOException(StaccatoBaseException): + pass + + +class StaccatoParameterError(StaccatoBaseException): + pass + + +class StaccatoMisconfigurationException(StaccatoBaseException): + pass + + +class StaccatoDataBaseException(StaccatoBaseException): + pass + + +class StaccatoEventException(StaccatoBaseException): + pass diff --git a/staccato/common/state_machine.py b/staccato/common/state_machine.py new file mode 100644 index 0000000..401e80e --- /dev/null +++ b/staccato/common/state_machine.py @@ -0,0 +1,68 @@ +from staccato.common import exceptions + + +class StateMachine(object): + + def __init__(self): + # set up the transition table + self._transitions = {} + self._state_funcs = {} + + def set_state_func(self, state, func): + self._state_funcs[state] = func + + def set_mapping(self, state, event, next_state, func=None): + if state not in self._transitions: + self._transitions[state] = {} + + event_dict = self._transitions[state] + if event not in event_dict: + event_dict[event] = {} + + if func is None: + func = self._state_funcs[next_state] + + self._transitions[state][event] = (next_state, func) + + def _state_changed(self, current_state, event, new_state, **kwvals): + raise Exception("this needs to be implemented") + + def _get_current_state(self, **kwvals): + raise Exception("This needs to be implemented") + + def event_occurred(self, event, **kwvals): + + current_state = self._get_current_state(**kwvals) + if current_state not in self._transitions: + raise exceptions.StaccatoParameterError( + "Undefined event %s at state %s" % (event, current_state)) + state_ent = self._transitions[current_state] + if event not in state_ent: + raise exceptions.StaccatoParameterError( + "Undefined event %s at state %s" % (event, current_state)) + + next_state, function = state_ent[event] + + self._state_changed(current_state, event, next_state, **kwvals) + # log the change + if function: + try: + function(current_state, event, next_state, **kwvals) + except Exception, ex: + # TODO: deal with the exception in a sane way. we likely need + # to trigger an event signifying and error occured but we + # may not want to recurse + raise + + def mapping_to_digraph(self): + print 'digraph {' + for start_state in self._transitions: + for event in self._transitions[start_state]: + ent = self._transitions[start_state][event] + if ent is not None: + p_end_state = ent[0].replace("STATE_", '') + p_start_state = start_state.replace("STATE_", '') + p_event = event.replace("EVENT_", '') + print '%s -> %s [ label = "%s" ];'\ + % (p_start_state, p_end_state, p_event) + print '}' diff --git a/staccato/common/utils.py b/staccato/common/utils.py new file mode 100644 index 0000000..16098a9 --- /dev/null +++ b/staccato/common/utils.py @@ -0,0 +1,69 @@ +import logging +from paste import deploy +import re +from staccato.common import exceptions +from staccato.openstack.common import importutils + + +def not_implemented_decorator(func): + def call(self, *args, **kwargs): + def raise_error(func): + raise exceptions.StaccatoNotImplementedException( + "function %s must be implemented" % (func.func_name)) + return raise_error(func) + return call + + +def load_paste_app(app_name, conf_file, conf): + try: + logger = logging.getLogger(__name__) + logger.debug(_("Loading %(app_name)s from %(conf_file)s"), + {'conf_file': conf_file, 'app_name': app_name}) + + app = deploy.loadapp("config:%s" % conf_file, + name=app_name, + global_conf={'CONF': conf}) + + return app + except (LookupError, ImportError) as e: + msg = _("Unable to load %(app_name)s from " + "configuration file %(conf_file)s." + "\nGot: %(e)r") % locals() + logger.error(msg) + raise RuntimeError(msg) + + +def find_protocol_module_name(lookup_dict, url_parts): + if url_parts.scheme not in lookup_dict: + raise exceptions.StaccatoParameterError( + '%s protocol not found' % url_parts.scheme) + p_list = lookup_dict[url_parts.scheme] + + for entry in p_list: + match_keys = ['netloc', 'path', 'params', 'query'] + ndx = 0 + found = True + for k in match_keys: + ndx = ndx + 1 + if k in entry: + needle = url_parts[ndx] + haystack = entry[k] + found = re.match(haystack, needle) + if found: + return entry['module'] + + raise exceptions.StaccatoParameterError( + 'The url %s is not supported' % url_parts.geturl()) + + +def load_protocol_module(module_name, CONF): + try: + protocol_cls = importutils.import_class(module_name) + except ImportError, ie: + raise exceptions.StaccatoParameterError( + "The protocol module %s could not be loaded. %s" % + (module_name, ie)) + + protocol_instance = protocol_cls(CONF) + + return protocol_instance diff --git a/staccato/common/wsgi.py b/staccato/common/wsgi.py new file mode 100644 index 0000000..a403cef --- /dev/null +++ b/staccato/common/wsgi.py @@ -0,0 +1,23 @@ +import webob + + +class Request(webob.Request): + """Add some Openstack API-specific logic to the base webob.Request.""" + + def best_match_content_type(self): + """Determine the requested response content-type.""" + supported = ('application/json',) + bm = self.accept.best_match(supported) + return bm or 'application/json' + + def get_content_type(self, allowed_content_types): + """Determine content type of the request body.""" + if "Content-Type" not in self.headers: + raise exception.InvalidContentType(content_type=None) + + content_type = self.content_type + + if content_type not in allowed_content_types: + raise exception.InvalidContentType(content_type=content_type) + else: + return content_type diff --git a/staccato/db/__init__.py b/staccato/db/__init__.py new file mode 100644 index 0000000..2605293 --- /dev/null +++ b/staccato/db/__init__.py @@ -0,0 +1,174 @@ +import logging +import time + +import sqlalchemy +import sqlalchemy.orm as sa_orm +import sqlalchemy.sql.expression as sql_expression + +from staccato.db import migration, models +import staccato.openstack.common.log as os_logging +from staccato.common import exceptions +from staccato.db import models +import staccato.xfer.constants as constants + + +LOG = os_logging.getLogger(__name__) + + +class StaccatoDB(object): + + def __init__(self, CONF, autocommit=True, expire_on_commit=False): + self.CONF = CONF + self.engine = _get_db_object(CONF) + self.maker = sa_orm.sessionmaker(bind=self.engine, + autocommit=autocommit, + expire_on_commit=expire_on_commit) + + def get_sessions(self): + return self.maker() + + def get_new_xfer(self, + srcurl, + dsturl, + src_module_name, + dst_module_name, + start_ndx=0, + end_ndx=-1, + read_info=None, + write_info=None, + session=None): + + if session is None: + session = self.get_sessions() + + with session.begin(): + xfer_request = models.XferRequest() + xfer_request.srcurl = srcurl + xfer_request.dsturl = dsturl + xfer_request.src_module_name = src_module_name + xfer_request.dst_module_name = dst_module_name + xfer_request.start_ndx = start_ndx + xfer_request.next_ndx = start_ndx + xfer_request.end_ndx = end_ndx + xfer_request.state = "STATE_NEW" + + session.add(xfer_request) + session.flush() + + return xfer_request + + def save_db_obj(self, db_obj, session=None): + if session is None: + session = self.get_sessions() + + with session.begin(): + session.add(db_obj) + session.flush() + + def lookup_xfer_request_by_id(self, xfer_id, session=None): + if session is None: + session = self.get_sessions() + + with session.begin(): + query = session.query(models.XferRequest)\ + .filter(models.XferRequest.id == xfer_id) + xfer_request = query.one() + + return xfer_request + + def get_all_ready(self, limit=None, session=None): + if session is None: + session = self.get_sessions() + + with session.begin(): + query = session.query(models.XferRequest)\ + .filter( + sql_expression.or_(models.XferRequest.state == constants.State.STATE_NEW, + models.XferRequest.state == constants.State.STATE_ERROR)) + if limit is not None: + query = query.limit(limit) + xfer_requests = query.all() + return xfer_requests + + def delete_db_obj(self, db_obj, session=None): + if session is None: + session = self.get_sessions() + + with session.begin(): + session.delete(db_obj) + session.flush() + + +def _get_db_object(CONF): + sa_logger = logging.getLogger('sqlalchemy.engine') + sa_logger.setLevel(CONF.log_level) + + sqlalchemy.engine.url.make_url(CONF.sql_connection) + engine_args = { + 'pool_recycle': CONF.sql_idle_timeout, + 'echo': False, + 'convert_unicode': True} + + try: + engine = sqlalchemy.create_engine(CONF.sql_connection, **engine_args) + engine.connect = wrap_db_error(engine.connect, CONF) + engine.connect() + except Exception, err: + msg = _("Error configuring registry database with supplied " + "sql_connection '%s'. " + "Got error:\n%s") % (CONF.sql_connection, err) + LOG.error(msg) + raise + + if CONF.db_auto_create: + LOG.info(_('auto-creating staccato registry DB')) + models.register_models(engine) + try: + migration.version_control(CONF) + except exceptions.StaccatoDataBaseException: + # only arises when the DB exists and is under version control + pass + else: + LOG.info(_('not auto-creating staccato registry DB')) + + return engine + + +def is_db_connection_error(args): + """Return True if error in connecting to db.""" + # NOTE(adam_g): This is currently MySQL specific and needs to be extended + # to support Postgres and others. + conn_err_codes = ('2002', '2003', '2006') + for err_code in conn_err_codes: + if args.find(err_code) != -1: + return True + return False + + +def wrap_db_error(f, CONF): + """Retry DB connection. Copied from nova and modified.""" + def _wrap(*args, **kwargs): + try: + return f(*args, **kwargs) + except sqlalchemy.exc.OperationalError, e: + if not is_db_connection_error(e.args[0]): + raise + + remaining_attempts = CONF.sql_max_retries + while True: + LOG.warning(_('SQL connection failed. %d attempts left.'), + remaining_attempts) + remaining_attempts -= 1 + time.sleep(CONF.sql_retry_interval) + try: + return f(*args, **kwargs) + except sqlalchemy.exc.OperationalError, e: + if (remaining_attempts == 0 or + not is_db_connection_error(e.args[0])): + raise + except sqlalchemy.exc.DBAPIError: + raise + except sqlalchemy.exc.DBAPIError: + raise + _wrap.func_name = f.func_name + return _wrap diff --git a/staccato/db/migrate_repo/__init__.py b/staccato/db/migrate_repo/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/staccato/db/migrate_repo/versions/__init__.py b/staccato/db/migrate_repo/versions/__init__.py new file mode 100644 index 0000000..7fcc62c --- /dev/null +++ b/staccato/db/migrate_repo/versions/__init__.py @@ -0,0 +1 @@ +__author__ = 'jbresnah' diff --git a/staccato/db/migration.py b/staccato/db/migration.py new file mode 100644 index 0000000..d4d64bf --- /dev/null +++ b/staccato/db/migration.py @@ -0,0 +1,117 @@ +import os + +from migrate.versioning import api as versioning_api +try: + from migrate.versioning import exceptions as versioning_exceptions +except ImportError: + from migrate import exceptions as versioning_exceptions +from migrate.versioning import repository as versioning_repository + +from staccato.common import exceptions +import staccato.openstack.common.log as logging + +LOG = logging.getLogger(__name__) + + +def db_version(CONF): + """ + Return the database's current migration number + + :retval version number + """ + repo_path = get_migrate_repo_path() + sql_connection = CONF.sql_connection + try: + return versioning_api.db_version(sql_connection, repo_path) + except versioning_exceptions.DatabaseNotControlledError, e: + msg = (_("database '%(sql_connection)s' is not under " + "migration control") % locals()) + raise exceptions.StaccatoDataBaseException(msg) + + +def upgrade(CONF, version=None): + """ + Upgrade the database's current migration level + + :param version: version to upgrade (defaults to latest) + :retval version number + """ + db_version() # Ensure db is under migration control + repo_path = get_migrate_repo_path() + sql_connection = CONF.sql_connection + version_str = version or 'latest' + LOG.info(_("Upgrading %(sql_connection)s to version %(version_str)s") % + locals()) + return versioning_api.upgrade(sql_connection, repo_path, version) + + +def downgrade(CONF, version): + """ + Downgrade the database's current migration level + + :param version: version to downgrade to + :retval version number + """ + db_version() # Ensure db is under migration control + repo_path = get_migrate_repo_path() + sql_connection = CONF.sql_connection + LOG.info(_("Downgrading %(sql_connection)s to version %(version)s") % + locals()) + return versioning_api.downgrade(sql_connection, repo_path, version) + + +def version_control(CONF, version=None): + """ + Place a database under migration control + """ + sql_connection = CONF.sql_connection + try: + _version_control(CONF, version) + except versioning_exceptions.DatabaseAlreadyControlledError, e: + msg = (_("database '%(sql_connection)s' is already under migration " + "control") % locals()) + raise exceptions.StaccatoDataBaseException(msg) + + +def _version_control(CONF, version): + """ + Place a database under migration control + + This will only set the specific version of a database, it won't + run any migrations. + """ + repo_path = get_migrate_repo_path() + sql_connection = CONF.sql_connection + if version is None: + version = versioning_repository.Repository(repo_path).latest + return versioning_api.version_control(sql_connection, repo_path, version) + + +def db_sync(CONF, version=None, current_version=None): + """ + Place a database under migration control and perform an upgrade + + :retval version number + """ + sql_connection = CONF.sql_connection + try: + _version_control(current_version or 0) + except versioning_exceptions.DatabaseAlreadyControlledError, e: + pass + + if current_version is None: + current_version = int(db_version()) + if version is not None and int(version) < current_version: + downgrade(version=version) + elif version is None or int(version) > current_version: + upgrade(version=version) + + +def get_migrate_repo_path(): + """Get the path for the migrate repository.""" + path = os.path.join(os.path.abspath(os.path.dirname(__file__)), + 'migrate_repo') + if not os.path.exists(path): + raise exceptions.StaccatoMisconfigurationException( + "The path % should exist." % path) + return path diff --git a/staccato/db/models.py b/staccato/db/models.py new file mode 100644 index 0000000..d3f7e0a --- /dev/null +++ b/staccato/db/models.py @@ -0,0 +1,51 @@ +""" +SQLAlchemy models for staccato data +""" + +from sqlalchemy import Column +from sqlalchemy import DateTime +from sqlalchemy import Integer +from sqlalchemy import String +from sqlalchemy.ext.declarative import declarative_base + +from staccato.openstack.common import timeutils +from staccato.openstack.common import uuidutils + + +BASE = declarative_base() + + +class ModelBase(object): + """Base class for Nova and Glance Models""" + __table_args__ = {'mysql_engine': 'InnoDB'} + __table_initialized__ = False + __protected_attributes__ = set([ + "created_at", "updated_at"]) + + created_at = Column(DateTime, default=timeutils.utcnow, + nullable=False) + updated_at = Column(DateTime, default=timeutils.utcnow, + nullable=False, onupdate=timeutils.utcnow) + + +class XferRequest(BASE, ModelBase): + __tablename__ = 'xfer_requests' + + id = Column(String(36), primary_key=True, default=uuidutils.generate_uuid) + srcurl = Column(String(2048), nullable=False) + dsturl = Column(String(2048), nullable=False) + src_module_name = Column(String(512), nullable=False) + dst_module_name = Column(String(512), nullable=False) + state = Column(Integer(), nullable=False) + start_ndx = Column(Integer(), nullable=False, default=0) + next_ndx = Column(Integer(), nullable=False) + end_ndx = Column(Integer(), nullable=False, default=-1) + # TODO add protocol specific json documents + write_info = Column(String(512)) + read_info = Column(String(512)) + + +def register_models(engine): + models = (XferRequest,) + for model in models: + model.metadata.create_all(engine) diff --git a/staccato/openstack/__init__.py b/staccato/openstack/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/staccato/openstack/common/__init__.py b/staccato/openstack/common/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/staccato/openstack/common/context.py b/staccato/openstack/common/context.py new file mode 100644 index 0000000..f1e3f96 --- /dev/null +++ b/staccato/openstack/common/context.py @@ -0,0 +1,82 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 OpenStack Foundation. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +Simple class that stores security context information in the web request. + +Projects should subclass this class if they wish to enhance the request +context or provide additional information in their specific WSGI pipeline. +""" + +import itertools + +from staccato.openstack.common import uuidutils + + +def generate_request_id(): + return 'req-%s' % uuidutils.generate_uuid() + + +class RequestContext(object): + + """ + Stores information about the security context under which the user + accesses the system, as well as additional request information. + """ + + def __init__(self, auth_token=None, user=None, tenant=None, is_admin=False, + read_only=False, show_deleted=False, request_id=None): + self.auth_token = auth_token + self.user = user + self.tenant = tenant + self.is_admin = is_admin + self.read_only = read_only + self.show_deleted = show_deleted + if not request_id: + request_id = generate_request_id() + self.request_id = request_id + + def to_dict(self): + return {'user': self.user, + 'tenant': self.tenant, + 'is_admin': self.is_admin, + 'read_only': self.read_only, + 'show_deleted': self.show_deleted, + 'auth_token': self.auth_token, + 'request_id': self.request_id} + + +def get_admin_context(show_deleted="no"): + context = RequestContext(None, + tenant=None, + is_admin=True, + show_deleted=show_deleted) + return context + + +def get_context_from_function_and_args(function, args, kwargs): + """Find an arg of type RequestContext and return it. + + This is useful in a couple of decorators where we don't + know much about the function we're wrapping. + """ + + for arg in itertools.chain(kwargs.values(), args): + if isinstance(arg, RequestContext): + return arg + + return None diff --git a/staccato/openstack/common/eventlet_backdoor.py b/staccato/openstack/common/eventlet_backdoor.py new file mode 100644 index 0000000..57b89ae --- /dev/null +++ b/staccato/openstack/common/eventlet_backdoor.py @@ -0,0 +1,89 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright (c) 2012 OpenStack Foundation. +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import print_function + +import gc +import pprint +import sys +import traceback + +import eventlet +import eventlet.backdoor +import greenlet +from oslo.config import cfg + +eventlet_backdoor_opts = [ + cfg.IntOpt('backdoor_port', + default=None, + help='port for eventlet backdoor to listen') +] + +CONF = cfg.CONF +CONF.register_opts(eventlet_backdoor_opts) + + +def _dont_use_this(): + print("Don't use this, just disconnect instead") + + +def _find_objects(t): + return filter(lambda o: isinstance(o, t), gc.get_objects()) + + +def _print_greenthreads(): + for i, gt in enumerate(_find_objects(greenlet.greenlet)): + print(i, gt) + traceback.print_stack(gt.gr_frame) + print() + + +def _print_nativethreads(): + for threadId, stack in sys._current_frames().items(): + print(threadId) + traceback.print_stack(stack) + print() + + +def initialize_if_enabled(): + backdoor_locals = { + 'exit': _dont_use_this, # So we don't exit the entire process + 'quit': _dont_use_this, # So we don't exit the entire process + 'fo': _find_objects, + 'pgt': _print_greenthreads, + 'pnt': _print_nativethreads, + } + + if CONF.backdoor_port is None: + return None + + # NOTE(johannes): The standard sys.displayhook will print the value of + # the last expression and set it to __builtin__._, which overwrites + # the __builtin__._ that gettext sets. Let's switch to using pprint + # since it won't interact poorly with gettext, and it's easier to + # read the output too. + def displayhook(val): + if val is not None: + pprint.pprint(val) + sys.displayhook = displayhook + + sock = eventlet.listen(('localhost', CONF.backdoor_port)) + port = sock.getsockname()[1] + eventlet.spawn_n(eventlet.backdoor.backdoor_server, sock, + locals=backdoor_locals) + return port diff --git a/staccato/openstack/common/excutils.py b/staccato/openstack/common/excutils.py new file mode 100644 index 0000000..b48a73f --- /dev/null +++ b/staccato/openstack/common/excutils.py @@ -0,0 +1,51 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 OpenStack Foundation. +# Copyright 2012, Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +Exception related utilities. +""" + +import contextlib +import logging +import sys +import traceback + +from staccato.openstack.common.gettextutils import _ + + +@contextlib.contextmanager +def save_and_reraise_exception(): + """Save current exception, run some code and then re-raise. + + In some cases the exception context can be cleared, resulting in None + being attempted to be re-raised after an exception handler is run. This + can happen when eventlet switches greenthreads or when running an + exception handler, code raises and catches an exception. In both + cases the exception context will be cleared. + + To work around this, we save the exception state, run handler code, and + then re-raise the original exception. If another exception occurs, the + saved exception is logged and the new exception is re-raised. + """ + type_, value, tb = sys.exc_info() + try: + yield + except Exception: + logging.error(_('Original exception being dropped: %s'), + traceback.format_exception(type_, value, tb)) + raise + raise type_, value, tb diff --git a/staccato/openstack/common/gettextutils.py b/staccato/openstack/common/gettextutils.py new file mode 100644 index 0000000..af7ad2d --- /dev/null +++ b/staccato/openstack/common/gettextutils.py @@ -0,0 +1,50 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2012 Red Hat, Inc. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +gettext for openstack-common modules. + +Usual usage in an openstack.common module: + + from staccato.openstack.common.gettextutils import _ +""" + +import gettext +import os + +_localedir = os.environ.get('staccato'.upper() + '_LOCALEDIR') +_t = gettext.translation('staccato', localedir=_localedir, fallback=True) + + +def _(msg): + return _t.ugettext(msg) + + +def install(domain): + """Install a _() function using the given translation domain. + + Given a translation domain, install a _() function using gettext's + install() function. + + The main difference from gettext.install() is that we allow + overriding the default localedir (e.g. /usr/share/locale) using + a translation-domain-specific environment variable (e.g. + NOVA_LOCALEDIR). + """ + gettext.install(domain, + localedir=os.environ.get(domain.upper() + '_LOCALEDIR'), + unicode=True) diff --git a/staccato/openstack/common/importutils.py b/staccato/openstack/common/importutils.py new file mode 100644 index 0000000..3bd277f --- /dev/null +++ b/staccato/openstack/common/importutils.py @@ -0,0 +1,67 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 OpenStack Foundation. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +Import related utilities and helper functions. +""" + +import sys +import traceback + + +def import_class(import_str): + """Returns a class from a string including module and class""" + mod_str, _sep, class_str = import_str.rpartition('.') + try: + __import__(mod_str) + return getattr(sys.modules[mod_str], class_str) + except (ValueError, AttributeError): + raise ImportError('Class %s cannot be found (%s)' % + (class_str, + traceback.format_exception(*sys.exc_info()))) + + +def import_object(import_str, *args, **kwargs): + """Import a class and return an instance of it.""" + return import_class(import_str)(*args, **kwargs) + + +def import_object_ns(name_space, import_str, *args, **kwargs): + """ + Import a class and return an instance of it, first by trying + to find the class in a default namespace, then failing back to + a full path if not found in the default namespace. + """ + import_value = "%s.%s" % (name_space, import_str) + try: + return import_class(import_value)(*args, **kwargs) + except ImportError: + return import_class(import_str)(*args, **kwargs) + + +def import_module(import_str): + """Import a module.""" + __import__(import_str) + return sys.modules[import_str] + + +def try_import(import_str, default=None): + """Try to import a module and if it fails return default.""" + try: + return import_module(import_str) + except ImportError: + return default diff --git a/staccato/openstack/common/jsonutils.py b/staccato/openstack/common/jsonutils.py new file mode 100644 index 0000000..06df677 --- /dev/null +++ b/staccato/openstack/common/jsonutils.py @@ -0,0 +1,169 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# Copyright 2011 Justin Santa Barbara +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +''' +JSON related utilities. + +This module provides a few things: + + 1) A handy function for getting an object down to something that can be + JSON serialized. See to_primitive(). + + 2) Wrappers around loads() and dumps(). The dumps() wrapper will + automatically use to_primitive() for you if needed. + + 3) This sets up anyjson to use the loads() and dumps() wrappers if anyjson + is available. +''' + + +import datetime +import functools +import inspect +import itertools +import json +import types +import xmlrpclib + +import six + +from staccato.openstack.common import timeutils + + +_nasty_type_tests = [inspect.ismodule, inspect.isclass, inspect.ismethod, + inspect.isfunction, inspect.isgeneratorfunction, + inspect.isgenerator, inspect.istraceback, inspect.isframe, + inspect.iscode, inspect.isbuiltin, inspect.isroutine, + inspect.isabstract] + +_simple_types = (types.NoneType, int, basestring, bool, float, long) + + +def to_primitive(value, convert_instances=False, convert_datetime=True, + level=0, max_depth=3): + """Convert a complex object into primitives. + + Handy for JSON serialization. We can optionally handle instances, + but since this is a recursive function, we could have cyclical + data structures. + + To handle cyclical data structures we could track the actual objects + visited in a set, but not all objects are hashable. Instead we just + track the depth of the object inspections and don't go too deep. + + Therefore, convert_instances=True is lossy ... be aware. + + """ + # handle obvious types first - order of basic types determined by running + # full tests on nova project, resulting in the following counts: + # 572754 + # 460353 + # 379632 + # 274610 + # 199918 + # 114200 + # 51817 + # 26164 + # 6491 + # 283 + # 19 + if isinstance(value, _simple_types): + return value + + if isinstance(value, datetime.datetime): + if convert_datetime: + return timeutils.strtime(value) + else: + return value + + # value of itertools.count doesn't get caught by nasty_type_tests + # and results in infinite loop when list(value) is called. + if type(value) == itertools.count: + return six.text_type(value) + + # FIXME(vish): Workaround for LP bug 852095. Without this workaround, + # tests that raise an exception in a mocked method that + # has a @wrap_exception with a notifier will fail. If + # we up the dependency to 0.5.4 (when it is released) we + # can remove this workaround. + if getattr(value, '__module__', None) == 'mox': + return 'mock' + + if level > max_depth: + return '?' + + # The try block may not be necessary after the class check above, + # but just in case ... + try: + recursive = functools.partial(to_primitive, + convert_instances=convert_instances, + convert_datetime=convert_datetime, + level=level, + max_depth=max_depth) + if isinstance(value, dict): + return dict((k, recursive(v)) for k, v in value.iteritems()) + elif isinstance(value, (list, tuple)): + return [recursive(lv) for lv in value] + + # It's not clear why xmlrpclib created their own DateTime type, but + # for our purposes, make it a datetime type which is explicitly + # handled + if isinstance(value, xmlrpclib.DateTime): + value = datetime.datetime(*tuple(value.timetuple())[:6]) + + if convert_datetime and isinstance(value, datetime.datetime): + return timeutils.strtime(value) + elif hasattr(value, 'iteritems'): + return recursive(dict(value.iteritems()), level=level + 1) + elif hasattr(value, '__iter__'): + return recursive(list(value)) + elif convert_instances and hasattr(value, '__dict__'): + # Likely an instance of something. Watch for cycles. + # Ignore class member vars. + return recursive(value.__dict__, level=level + 1) + else: + if any(test(value) for test in _nasty_type_tests): + return six.text_type(value) + return value + except TypeError: + # Class objects are tricky since they may define something like + # __iter__ defined but it isn't callable as list(). + return six.text_type(value) + + +def dumps(value, default=to_primitive, **kwargs): + return json.dumps(value, default=default, **kwargs) + + +def loads(s): + return json.loads(s) + + +def load(s): + return json.load(s) + + +try: + import anyjson +except ImportError: + pass +else: + anyjson._modules.append((__name__, 'dumps', TypeError, + 'loads', ValueError, 'load')) + anyjson.force_implementation(__name__) diff --git a/staccato/openstack/common/local.py b/staccato/openstack/common/local.py new file mode 100644 index 0000000..f1bfc82 --- /dev/null +++ b/staccato/openstack/common/local.py @@ -0,0 +1,48 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 OpenStack Foundation. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Greenthread local storage of variables using weak references""" + +import weakref + +from eventlet import corolocal + + +class WeakLocal(corolocal.local): + def __getattribute__(self, attr): + rval = corolocal.local.__getattribute__(self, attr) + if rval: + # NOTE(mikal): this bit is confusing. What is stored is a weak + # reference, not the value itself. We therefore need to lookup + # the weak reference and return the inner value here. + rval = rval() + return rval + + def __setattr__(self, attr, value): + value = weakref.ref(value) + return corolocal.local.__setattr__(self, attr, value) + + +# NOTE(mikal): the name "store" should be deprecated in the future +store = WeakLocal() + +# A "weak" store uses weak references and allows an object to fall out of scope +# when it falls out of scope in the code that uses the thread local storage. A +# "strong" store will hold a reference to the object so that it never falls out +# of scope. +weak_store = WeakLocal() +strong_store = corolocal.local diff --git a/staccato/openstack/common/log.py b/staccato/openstack/common/log.py new file mode 100644 index 0000000..e552b23 --- /dev/null +++ b/staccato/openstack/common/log.py @@ -0,0 +1,566 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 OpenStack Foundation. +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Openstack logging handler. + +This module adds to logging functionality by adding the option to specify +a context object when calling the various log methods. If the context object +is not specified, default formatting is used. Additionally, an instance uuid +may be passed as part of the log message, which is intended to make it easier +for admins to find messages related to a specific instance. + +It also allows setting of formatting information through conf. + +""" + +import ConfigParser +import cStringIO +import inspect +import itertools +import logging +import logging.config +import logging.handlers +import os +import sys +import traceback + +from oslo.config import cfg + +from staccato.openstack.common.gettextutils import _ +from staccato.openstack.common import jsonutils +from staccato.openstack.common import local +from staccato.openstack.common import notifier + + +_DEFAULT_LOG_DATE_FORMAT = "%Y-%m-%d %H:%M:%S" + +common_cli_opts = [ + cfg.BoolOpt('debug', + short='d', + default=False, + help='Print debugging output (set logging level to ' + 'DEBUG instead of default WARNING level).'), + cfg.BoolOpt('verbose', + short='v', + default=False, + help='Print more verbose output (set logging level to ' + 'INFO instead of default WARNING level).'), +] + +logging_cli_opts = [ + cfg.StrOpt('log-config', + metavar='PATH', + help='If this option is specified, the logging configuration ' + 'file specified is used and overrides any other logging ' + 'options specified. Please see the Python logging module ' + 'documentation for details on logging configuration ' + 'files.'), + cfg.StrOpt('log-format', + default=None, + metavar='FORMAT', + help='A logging.Formatter log message format string which may ' + 'use any of the available logging.LogRecord attributes. ' + 'This option is deprecated. Please use ' + 'logging_context_format_string and ' + 'logging_default_format_string instead.'), + cfg.StrOpt('log-date-format', + default=_DEFAULT_LOG_DATE_FORMAT, + metavar='DATE_FORMAT', + help='Format string for %%(asctime)s in log records. ' + 'Default: %(default)s'), + cfg.StrOpt('log-file', + metavar='PATH', + deprecated_name='logfile', + help='(Optional) Name of log file to output to. ' + 'If no default is set, logging will go to stdout.'), + cfg.StrOpt('log-dir', + deprecated_name='logdir', + help='(Optional) The base directory used for relative ' + '--log-file paths'), + cfg.BoolOpt('use-syslog', + default=False, + help='Use syslog for logging.'), + cfg.StrOpt('syslog-log-facility', + default='LOG_USER', + help='syslog facility to receive log lines') +] + +generic_log_opts = [ + cfg.BoolOpt('use_stderr', + default=True, + help='Log output to standard error') +] + +log_opts = [ + cfg.StrOpt('logging_context_format_string', + default='%(asctime)s.%(msecs)03d %(process)d %(levelname)s ' + '%(name)s [%(request_id)s %(user)s %(tenant)s] ' + '%(instance)s%(message)s', + help='format string to use for log messages with context'), + cfg.StrOpt('logging_default_format_string', + default='%(asctime)s.%(msecs)03d %(process)d %(levelname)s ' + '%(name)s [-] %(instance)s%(message)s', + help='format string to use for log messages without context'), + cfg.StrOpt('logging_debug_format_suffix', + default='%(funcName)s %(pathname)s:%(lineno)d', + help='data to append to log format when level is DEBUG'), + cfg.StrOpt('logging_exception_prefix', + default='%(asctime)s.%(msecs)03d %(process)d TRACE %(name)s ' + '%(instance)s', + help='prefix each line of exception output with this format'), + cfg.ListOpt('default_log_levels', + default=[ + 'amqplib=WARN', + 'sqlalchemy=WARN', + 'boto=WARN', + 'suds=INFO', + 'keystone=INFO', + 'eventlet.wsgi.server=WARN' + ], + help='list of logger=LEVEL pairs'), + cfg.BoolOpt('publish_errors', + default=False, + help='publish error events'), + cfg.BoolOpt('fatal_deprecations', + default=False, + help='make deprecations fatal'), + + # NOTE(mikal): there are two options here because sometimes we are handed + # a full instance (and could include more information), and other times we + # are just handed a UUID for the instance. + cfg.StrOpt('instance_format', + default='[instance: %(uuid)s] ', + help='If an instance is passed with the log message, format ' + 'it like this'), + cfg.StrOpt('instance_uuid_format', + default='[instance: %(uuid)s] ', + help='If an instance UUID is passed with the log message, ' + 'format it like this'), +] + +CONF = cfg.CONF +CONF.register_cli_opts(common_cli_opts) +CONF.register_cli_opts(logging_cli_opts) +CONF.register_opts(generic_log_opts) +CONF.register_opts(log_opts) + +# our new audit level +# NOTE(jkoelker) Since we synthesized an audit level, make the logging +# module aware of it so it acts like other levels. +logging.AUDIT = logging.INFO + 1 +logging.addLevelName(logging.AUDIT, 'AUDIT') + + +try: + NullHandler = logging.NullHandler +except AttributeError: # NOTE(jkoelker) NullHandler added in Python 2.7 + class NullHandler(logging.Handler): + def handle(self, record): + pass + + def emit(self, record): + pass + + def createLock(self): + self.lock = None + + +def _dictify_context(context): + if context is None: + return None + if not isinstance(context, dict) and getattr(context, 'to_dict', None): + context = context.to_dict() + return context + + +def _get_binary_name(): + return os.path.basename(inspect.stack()[-1][1]) + + +def _get_log_file_path(binary=None): + logfile = CONF.log_file + logdir = CONF.log_dir + + if logfile and not logdir: + return logfile + + if logfile and logdir: + return os.path.join(logdir, logfile) + + if logdir: + binary = binary or _get_binary_name() + return '%s.log' % (os.path.join(logdir, binary),) + + +class BaseLoggerAdapter(logging.LoggerAdapter): + + def audit(self, msg, *args, **kwargs): + self.log(logging.AUDIT, msg, *args, **kwargs) + + +class LazyAdapter(BaseLoggerAdapter): + def __init__(self, name='unknown', version='unknown'): + self._logger = None + self.extra = {} + self.name = name + self.version = version + + @property + def logger(self): + if not self._logger: + self._logger = getLogger(self.name, self.version) + return self._logger + + +class ContextAdapter(BaseLoggerAdapter): + warn = logging.LoggerAdapter.warning + + def __init__(self, logger, project_name, version_string): + self.logger = logger + self.project = project_name + self.version = version_string + + @property + def handlers(self): + return self.logger.handlers + + def deprecated(self, msg, *args, **kwargs): + stdmsg = _("Deprecated: %s") % msg + if CONF.fatal_deprecations: + self.critical(stdmsg, *args, **kwargs) + raise DeprecatedConfig(msg=stdmsg) + else: + self.warn(stdmsg, *args, **kwargs) + + def process(self, msg, kwargs): + if 'extra' not in kwargs: + kwargs['extra'] = {} + extra = kwargs['extra'] + + context = kwargs.pop('context', None) + if not context: + context = getattr(local.store, 'context', None) + if context: + extra.update(_dictify_context(context)) + + instance = kwargs.pop('instance', None) + instance_extra = '' + if instance: + instance_extra = CONF.instance_format % instance + else: + instance_uuid = kwargs.pop('instance_uuid', None) + if instance_uuid: + instance_extra = (CONF.instance_uuid_format + % {'uuid': instance_uuid}) + extra.update({'instance': instance_extra}) + + extra.update({"project": self.project}) + extra.update({"version": self.version}) + extra['extra'] = extra.copy() + return msg, kwargs + + +class JSONFormatter(logging.Formatter): + def __init__(self, fmt=None, datefmt=None): + # NOTE(jkoelker) we ignore the fmt argument, but its still there + # since logging.config.fileConfig passes it. + self.datefmt = datefmt + + def formatException(self, ei, strip_newlines=True): + lines = traceback.format_exception(*ei) + if strip_newlines: + lines = [itertools.ifilter( + lambda x: x, + line.rstrip().splitlines()) for line in lines] + lines = list(itertools.chain(*lines)) + return lines + + def format(self, record): + message = {'message': record.getMessage(), + 'asctime': self.formatTime(record, self.datefmt), + 'name': record.name, + 'msg': record.msg, + 'args': record.args, + 'levelname': record.levelname, + 'levelno': record.levelno, + 'pathname': record.pathname, + 'filename': record.filename, + 'module': record.module, + 'lineno': record.lineno, + 'funcname': record.funcName, + 'created': record.created, + 'msecs': record.msecs, + 'relative_created': record.relativeCreated, + 'thread': record.thread, + 'thread_name': record.threadName, + 'process_name': record.processName, + 'process': record.process, + 'traceback': None} + + if hasattr(record, 'extra'): + message['extra'] = record.extra + + if record.exc_info: + message['traceback'] = self.formatException(record.exc_info) + + return jsonutils.dumps(message) + + +class PublishErrorsHandler(logging.Handler): + def emit(self, record): + if ('staccato.openstack.common.notifier.log_notifier' in + CONF.notification_driver): + return + notifier.api.notify(None, 'error.publisher', + 'error_notification', + notifier.api.ERROR, + dict(error=record.msg)) + + +def _create_logging_excepthook(product_name): + def logging_excepthook(type, value, tb): + extra = {} + if CONF.verbose: + extra['exc_info'] = (type, value, tb) + getLogger(product_name).critical(str(value), **extra) + return logging_excepthook + + +class LogConfigError(Exception): + + message = _('Error loading logging config %(log_config)s: %(err_msg)s') + + def __init__(self, log_config, err_msg): + self.log_config = log_config + self.err_msg = err_msg + + def __str__(self): + return self.message % dict(log_config=self.log_config, + err_msg=self.err_msg) + + +def _load_log_config(log_config): + try: + logging.config.fileConfig(log_config) + except ConfigParser.Error as exc: + raise LogConfigError(log_config, str(exc)) + + +def setup(product_name): + """Setup logging.""" + if CONF.log_config: + _load_log_config(CONF.log_config) + else: + _setup_logging_from_conf() + sys.excepthook = _create_logging_excepthook(product_name) + + +def set_defaults(logging_context_format_string): + cfg.set_defaults(log_opts, + logging_context_format_string= + logging_context_format_string) + + +def _find_facility_from_conf(): + facility_names = logging.handlers.SysLogHandler.facility_names + facility = getattr(logging.handlers.SysLogHandler, + CONF.syslog_log_facility, + None) + + if facility is None and CONF.syslog_log_facility in facility_names: + facility = facility_names.get(CONF.syslog_log_facility) + + if facility is None: + valid_facilities = facility_names.keys() + consts = ['LOG_AUTH', 'LOG_AUTHPRIV', 'LOG_CRON', 'LOG_DAEMON', + 'LOG_FTP', 'LOG_KERN', 'LOG_LPR', 'LOG_MAIL', 'LOG_NEWS', + 'LOG_AUTH', 'LOG_SYSLOG', 'LOG_USER', 'LOG_UUCP', + 'LOG_LOCAL0', 'LOG_LOCAL1', 'LOG_LOCAL2', 'LOG_LOCAL3', + 'LOG_LOCAL4', 'LOG_LOCAL5', 'LOG_LOCAL6', 'LOG_LOCAL7'] + valid_facilities.extend(consts) + raise TypeError(_('syslog facility must be one of: %s') % + ', '.join("'%s'" % fac + for fac in valid_facilities)) + + return facility + + +def _setup_logging_from_conf(): + log_root = getLogger(None).logger + for handler in log_root.handlers: + log_root.removeHandler(handler) + + if CONF.use_syslog: + facility = _find_facility_from_conf() + syslog = logging.handlers.SysLogHandler(address='/dev/log', + facility=facility) + log_root.addHandler(syslog) + + logpath = _get_log_file_path() + if logpath: + filelog = logging.handlers.WatchedFileHandler(logpath) + log_root.addHandler(filelog) + + if CONF.use_stderr: + streamlog = ColorHandler() + log_root.addHandler(streamlog) + + elif not CONF.log_file: + # pass sys.stdout as a positional argument + # python2.6 calls the argument strm, in 2.7 it's stream + streamlog = logging.StreamHandler(sys.stdout) + log_root.addHandler(streamlog) + + if CONF.publish_errors: + log_root.addHandler(PublishErrorsHandler(logging.ERROR)) + + datefmt = CONF.log_date_format + for handler in log_root.handlers: + # NOTE(alaski): CONF.log_format overrides everything currently. This + # should be deprecated in favor of context aware formatting. + if CONF.log_format: + handler.setFormatter(logging.Formatter(fmt=CONF.log_format, + datefmt=datefmt)) + log_root.info('Deprecated: log_format is now deprecated and will ' + 'be removed in the next release') + else: + handler.setFormatter(ContextFormatter(datefmt=datefmt)) + + if CONF.debug: + log_root.setLevel(logging.DEBUG) + elif CONF.verbose: + log_root.setLevel(logging.INFO) + else: + log_root.setLevel(logging.WARNING) + + for pair in CONF.default_log_levels: + mod, _sep, level_name = pair.partition('=') + level = logging.getLevelName(level_name) + logger = logging.getLogger(mod) + logger.setLevel(level) + +_loggers = {} + + +def getLogger(name='unknown', version='unknown'): + if name not in _loggers: + _loggers[name] = ContextAdapter(logging.getLogger(name), + name, + version) + return _loggers[name] + + +def getLazyLogger(name='unknown', version='unknown'): + """ + create a pass-through logger that does not create the real logger + until it is really needed and delegates all calls to the real logger + once it is created + """ + return LazyAdapter(name, version) + + +class WritableLogger(object): + """A thin wrapper that responds to `write` and logs.""" + + def __init__(self, logger, level=logging.INFO): + self.logger = logger + self.level = level + + def write(self, msg): + self.logger.log(self.level, msg) + + +class ContextFormatter(logging.Formatter): + """A context.RequestContext aware formatter configured through flags. + + The flags used to set format strings are: logging_context_format_string + and logging_default_format_string. You can also specify + logging_debug_format_suffix to append extra formatting if the log level is + debug. + + For information about what variables are available for the formatter see: + http://docs.python.org/library/logging.html#formatter + + """ + + def format(self, record): + """Uses contextstring if request_id is set, otherwise default.""" + # NOTE(sdague): default the fancier formating params + # to an empty string so we don't throw an exception if + # they get used + for key in ('instance', 'color'): + if key not in record.__dict__: + record.__dict__[key] = '' + + if record.__dict__.get('request_id', None): + self._fmt = CONF.logging_context_format_string + else: + self._fmt = CONF.logging_default_format_string + + if (record.levelno == logging.DEBUG and + CONF.logging_debug_format_suffix): + self._fmt += " " + CONF.logging_debug_format_suffix + + # Cache this on the record, Logger will respect our formated copy + if record.exc_info: + record.exc_text = self.formatException(record.exc_info, record) + return logging.Formatter.format(self, record) + + def formatException(self, exc_info, record=None): + """Format exception output with CONF.logging_exception_prefix.""" + if not record: + return logging.Formatter.formatException(self, exc_info) + + stringbuffer = cStringIO.StringIO() + traceback.print_exception(exc_info[0], exc_info[1], exc_info[2], + None, stringbuffer) + lines = stringbuffer.getvalue().split('\n') + stringbuffer.close() + + if CONF.logging_exception_prefix.find('%(asctime)') != -1: + record.asctime = self.formatTime(record, self.datefmt) + + formatted_lines = [] + for line in lines: + pl = CONF.logging_exception_prefix % record.__dict__ + fl = '%s%s' % (pl, line) + formatted_lines.append(fl) + return '\n'.join(formatted_lines) + + +class ColorHandler(logging.StreamHandler): + LEVEL_COLORS = { + logging.DEBUG: '\033[00;32m', # GREEN + logging.INFO: '\033[00;36m', # CYAN + logging.AUDIT: '\033[01;36m', # BOLD CYAN + logging.WARN: '\033[01;33m', # BOLD YELLOW + logging.ERROR: '\033[01;31m', # BOLD RED + logging.CRITICAL: '\033[01;31m', # BOLD RED + } + + def format(self, record): + record.color = self.LEVEL_COLORS[record.levelno] + return logging.StreamHandler.format(self, record) + + +class DeprecatedConfig(Exception): + message = _("Fatal call to deprecated config: %(msg)s") + + def __init__(self, msg): + super(Exception, self).__init__(self.message % dict(msg=msg)) diff --git a/staccato/openstack/common/loopingcall.py b/staccato/openstack/common/loopingcall.py new file mode 100644 index 0000000..0b2daa5 --- /dev/null +++ b/staccato/openstack/common/loopingcall.py @@ -0,0 +1,147 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# Copyright 2011 Justin Santa Barbara +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import sys + +from eventlet import event +from eventlet import greenthread + +from staccato.openstack.common.gettextutils import _ +from staccato.openstack.common import log as logging +from staccato.openstack.common import timeutils + +LOG = logging.getLogger(__name__) + + +class LoopingCallDone(Exception): + """Exception to break out and stop a LoopingCall. + + The poll-function passed to LoopingCall can raise this exception to + break out of the loop normally. This is somewhat analogous to + StopIteration. + + An optional return-value can be included as the argument to the exception; + this return-value will be returned by LoopingCall.wait() + + """ + + def __init__(self, retvalue=True): + """:param retvalue: Value that LoopingCall.wait() should return.""" + self.retvalue = retvalue + + +class LoopingCallBase(object): + def __init__(self, f=None, *args, **kw): + self.args = args + self.kw = kw + self.f = f + self._running = False + self.done = None + + def stop(self): + self._running = False + + def wait(self): + return self.done.wait() + + +class FixedIntervalLoopingCall(LoopingCallBase): + """A fixed interval looping call.""" + + def start(self, interval, initial_delay=None): + self._running = True + done = event.Event() + + def _inner(): + if initial_delay: + greenthread.sleep(initial_delay) + + try: + while self._running: + start = timeutils.utcnow() + self.f(*self.args, **self.kw) + end = timeutils.utcnow() + if not self._running: + break + delay = interval - timeutils.delta_seconds(start, end) + if delay <= 0: + LOG.warn(_('task run outlasted interval by %s sec') % + -delay) + greenthread.sleep(delay if delay > 0 else 0) + except LoopingCallDone as e: + self.stop() + done.send(e.retvalue) + except Exception: + LOG.exception(_('in fixed duration looping call')) + done.send_exception(*sys.exc_info()) + return + else: + done.send(True) + + self.done = done + + greenthread.spawn_n(_inner) + return self.done + + +# TODO(mikal): this class name is deprecated in Havana and should be removed +# in the I release +LoopingCall = FixedIntervalLoopingCall + + +class DynamicLoopingCall(LoopingCallBase): + """A looping call which sleeps until the next known event. + + The function called should return how long to sleep for before being + called again. + """ + + def start(self, initial_delay=None, periodic_interval_max=None): + self._running = True + done = event.Event() + + def _inner(): + if initial_delay: + greenthread.sleep(initial_delay) + + try: + while self._running: + idle = self.f(*self.args, **self.kw) + if not self._running: + break + + if periodic_interval_max is not None: + idle = min(idle, periodic_interval_max) + LOG.debug(_('Dynamic looping call sleeping for %.02f ' + 'seconds'), idle) + greenthread.sleep(idle) + except LoopingCallDone as e: + self.stop() + done.send(e.retvalue) + except Exception: + LOG.exception(_('in dynamic looping call')) + done.send_exception(*sys.exc_info()) + return + else: + done.send(True) + + self.done = done + + greenthread.spawn(_inner) + return self.done diff --git a/staccato/openstack/common/network_utils.py b/staccato/openstack/common/network_utils.py new file mode 100644 index 0000000..2f988ca --- /dev/null +++ b/staccato/openstack/common/network_utils.py @@ -0,0 +1,69 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2012 OpenStack Foundation. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +Network-related utilities and helper functions. +""" + +from staccato.openstack.common import log as logging + + +LOG = logging.getLogger(__name__) + + +def parse_host_port(address, default_port=None): + """ + Interpret a string as a host:port pair. + An IPv6 address MUST be escaped if accompanied by a port, + because otherwise ambiguity ensues: 2001:db8:85a3::8a2e:370:7334 + means both [2001:db8:85a3::8a2e:370:7334] and + [2001:db8:85a3::8a2e:370]:7334. + + >>> parse_host_port('server01:80') + ('server01', 80) + >>> parse_host_port('server01') + ('server01', None) + >>> parse_host_port('server01', default_port=1234) + ('server01', 1234) + >>> parse_host_port('[::1]:80') + ('::1', 80) + >>> parse_host_port('[::1]') + ('::1', None) + >>> parse_host_port('[::1]', default_port=1234) + ('::1', 1234) + >>> parse_host_port('2001:db8:85a3::8a2e:370:7334', default_port=1234) + ('2001:db8:85a3::8a2e:370:7334', 1234) + + """ + if address[0] == '[': + # Escaped ipv6 + _host, _port = address[1:].split(']') + host = _host + if ':' in _port: + port = _port.split(':')[1] + else: + port = default_port + else: + if address.count(':') == 1: + host, port = address.split(':') + else: + # 0 means ipv4, >1 means ipv6. + # We prohibit unescaped ipv6 addresses with port. + host = address + port = default_port + + return (host, None if port is None else int(port)) diff --git a/staccato/openstack/common/notifier/__init__.py b/staccato/openstack/common/notifier/__init__.py new file mode 100644 index 0000000..45c3b46 --- /dev/null +++ b/staccato/openstack/common/notifier/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2011 OpenStack Foundation. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. diff --git a/staccato/openstack/common/notifier/api.py b/staccato/openstack/common/notifier/api.py new file mode 100644 index 0000000..195a04c --- /dev/null +++ b/staccato/openstack/common/notifier/api.py @@ -0,0 +1,182 @@ +# Copyright 2011 OpenStack Foundation. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import uuid + +from oslo.config import cfg + +from staccato.openstack.common import context +from staccato.openstack.common.gettextutils import _ +from staccato.openstack.common import importutils +from staccato.openstack.common import jsonutils +from staccato.openstack.common import log as logging +from staccato.openstack.common import timeutils + + +LOG = logging.getLogger(__name__) + +notifier_opts = [ + cfg.MultiStrOpt('notification_driver', + default=[], + help='Driver or drivers to handle sending notifications'), + cfg.StrOpt('default_notification_level', + default='INFO', + help='Default notification level for outgoing notifications'), + cfg.StrOpt('default_publisher_id', + default='$host', + help='Default publisher_id for outgoing notifications'), +] + +CONF = cfg.CONF +CONF.register_opts(notifier_opts) + +WARN = 'WARN' +INFO = 'INFO' +ERROR = 'ERROR' +CRITICAL = 'CRITICAL' +DEBUG = 'DEBUG' + +log_levels = (DEBUG, WARN, INFO, ERROR, CRITICAL) + + +class BadPriorityException(Exception): + pass + + +def notify_decorator(name, fn): + """ decorator for notify which is used from utils.monkey_patch() + + :param name: name of the function + :param function: - object of the function + :returns: function -- decorated function + + """ + def wrapped_func(*args, **kwarg): + body = {} + body['args'] = [] + body['kwarg'] = {} + for arg in args: + body['args'].append(arg) + for key in kwarg: + body['kwarg'][key] = kwarg[key] + + ctxt = context.get_context_from_function_and_args(fn, args, kwarg) + notify(ctxt, + CONF.default_publisher_id, + name, + CONF.default_notification_level, + body) + return fn(*args, **kwarg) + return wrapped_func + + +def publisher_id(service, host=None): + if not host: + host = CONF.host + return "%s.%s" % (service, host) + + +def notify(context, publisher_id, event_type, priority, payload): + """Sends a notification using the specified driver + + :param publisher_id: the source worker_type.host of the message + :param event_type: the literal type of event (ex. Instance Creation) + :param priority: patterned after the enumeration of Python logging + levels in the set (DEBUG, WARN, INFO, ERROR, CRITICAL) + :param payload: A python dictionary of attributes + + Outgoing message format includes the above parameters, and appends the + following: + + message_id + a UUID representing the id for this notification + + timestamp + the GMT timestamp the notification was sent at + + The composite message will be constructed as a dictionary of the above + attributes, which will then be sent via the transport mechanism defined + by the driver. + + Message example:: + + {'message_id': str(uuid.uuid4()), + 'publisher_id': 'compute.host1', + 'timestamp': timeutils.utcnow(), + 'priority': 'WARN', + 'event_type': 'compute.create_instance', + 'payload': {'instance_id': 12, ... }} + + """ + if priority not in log_levels: + raise BadPriorityException( + _('%s not in valid priorities') % priority) + + # Ensure everything is JSON serializable. + payload = jsonutils.to_primitive(payload, convert_instances=True) + + msg = dict(message_id=str(uuid.uuid4()), + publisher_id=publisher_id, + event_type=event_type, + priority=priority, + payload=payload, + timestamp=str(timeutils.utcnow())) + + for driver in _get_drivers(): + try: + driver.notify(context, msg) + except Exception as e: + LOG.exception(_("Problem '%(e)s' attempting to " + "send to notification system. " + "Payload=%(payload)s") + % dict(e=e, payload=payload)) + + +_drivers = None + + +def _get_drivers(): + """Instantiate, cache, and return drivers based on the CONF.""" + global _drivers + if _drivers is None: + _drivers = {} + for notification_driver in CONF.notification_driver: + add_driver(notification_driver) + + return _drivers.values() + + +def add_driver(notification_driver): + """Add a notification driver at runtime.""" + # Make sure the driver list is initialized. + _get_drivers() + if isinstance(notification_driver, basestring): + # Load and add + try: + driver = importutils.import_module(notification_driver) + _drivers[notification_driver] = driver + except ImportError: + LOG.exception(_("Failed to load notifier %s. " + "These notifications will not be sent.") % + notification_driver) + else: + # Driver is already loaded; just add the object. + _drivers[notification_driver] = notification_driver + + +def _reset_drivers(): + """Used by unit tests to reset the drivers.""" + global _drivers + _drivers = None diff --git a/staccato/openstack/common/notifier/log_notifier.py b/staccato/openstack/common/notifier/log_notifier.py new file mode 100644 index 0000000..f4cf0c3 --- /dev/null +++ b/staccato/openstack/common/notifier/log_notifier.py @@ -0,0 +1,35 @@ +# Copyright 2011 OpenStack Foundation. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from oslo.config import cfg + +from staccato.openstack.common import jsonutils +from staccato.openstack.common import log as logging + + +CONF = cfg.CONF + + +def notify(_context, message): + """Notifies the recipient of the desired event given the model. + Log notifications using openstack's default logging system""" + + priority = message.get('priority', + CONF.default_notification_level) + priority = priority.lower() + logger = logging.getLogger( + 'staccato.openstack.common.notification.%s' % + message['event_type']) + getattr(logger, priority)(jsonutils.dumps(message)) diff --git a/staccato/openstack/common/notifier/no_op_notifier.py b/staccato/openstack/common/notifier/no_op_notifier.py new file mode 100644 index 0000000..bc7a56c --- /dev/null +++ b/staccato/openstack/common/notifier/no_op_notifier.py @@ -0,0 +1,19 @@ +# Copyright 2011 OpenStack Foundation. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + + +def notify(_context, message): + """Notifies the recipient of the desired event given the model""" + pass diff --git a/staccato/openstack/common/notifier/rpc_notifier.py b/staccato/openstack/common/notifier/rpc_notifier.py new file mode 100644 index 0000000..a08b1ea --- /dev/null +++ b/staccato/openstack/common/notifier/rpc_notifier.py @@ -0,0 +1,46 @@ +# Copyright 2011 OpenStack Foundation. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from oslo.config import cfg + +from staccato.openstack.common import context as req_context +from staccato.openstack.common.gettextutils import _ +from staccato.openstack.common import log as logging +from staccato.openstack.common import rpc + +LOG = logging.getLogger(__name__) + +notification_topic_opt = cfg.ListOpt( + 'notification_topics', default=['notifications', ], + help='AMQP topic used for openstack notifications') + +CONF = cfg.CONF +CONF.register_opt(notification_topic_opt) + + +def notify(context, message): + """Sends a notification via RPC""" + if not context: + context = req_context.get_admin_context() + priority = message.get('priority', + CONF.default_notification_level) + priority = priority.lower() + for topic in CONF.notification_topics: + topic = '%s.%s' % (topic, priority) + try: + rpc.notify(context, topic, message) + except Exception: + LOG.exception(_("Could not send notification to %(topic)s. " + "Payload=%(message)s"), locals()) diff --git a/staccato/openstack/common/notifier/rpc_notifier2.py b/staccato/openstack/common/notifier/rpc_notifier2.py new file mode 100644 index 0000000..3a66ae3 --- /dev/null +++ b/staccato/openstack/common/notifier/rpc_notifier2.py @@ -0,0 +1,52 @@ +# Copyright 2011 OpenStack Foundation. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +'''messaging based notification driver, with message envelopes''' + +from oslo.config import cfg + +from staccato.openstack.common import context as req_context +from staccato.openstack.common.gettextutils import _ +from staccato.openstack.common import log as logging +from staccato.openstack.common import rpc + +LOG = logging.getLogger(__name__) + +notification_topic_opt = cfg.ListOpt( + 'topics', default=['notifications', ], + help='AMQP topic(s) used for openstack notifications') + +opt_group = cfg.OptGroup(name='rpc_notifier2', + title='Options for rpc_notifier2') + +CONF = cfg.CONF +CONF.register_group(opt_group) +CONF.register_opt(notification_topic_opt, opt_group) + + +def notify(context, message): + """Sends a notification via RPC""" + if not context: + context = req_context.get_admin_context() + priority = message.get('priority', + CONF.default_notification_level) + priority = priority.lower() + for topic in CONF.rpc_notifier2.topics: + topic = '%s.%s' % (topic, priority) + try: + rpc.notify(context, topic, message, envelope=True) + except Exception: + LOG.exception(_("Could not send notification to %(topic)s. " + "Payload=%(message)s"), locals()) diff --git a/staccato/openstack/common/notifier/test_notifier.py b/staccato/openstack/common/notifier/test_notifier.py new file mode 100644 index 0000000..96c1746 --- /dev/null +++ b/staccato/openstack/common/notifier/test_notifier.py @@ -0,0 +1,22 @@ +# Copyright 2011 OpenStack Foundation. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + + +NOTIFICATIONS = [] + + +def notify(_context, message): + """Test notifier, stores notifications in memory for unittests.""" + NOTIFICATIONS.append(message) diff --git a/staccato/openstack/common/policy.py b/staccato/openstack/common/policy.py new file mode 100644 index 0000000..fb96cdc --- /dev/null +++ b/staccato/openstack/common/policy.py @@ -0,0 +1,780 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright (c) 2012 OpenStack Foundation. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +Common Policy Engine Implementation + +Policies can be expressed in one of two forms: A list of lists, or a +string written in the new policy language. + +In the list-of-lists representation, each check inside the innermost +list is combined as with an "and" conjunction--for that check to pass, +all the specified checks must pass. These innermost lists are then +combined as with an "or" conjunction. This is the original way of +expressing policies, but there now exists a new way: the policy +language. + +In the policy language, each check is specified the same way as in the +list-of-lists representation: a simple "a:b" pair that is matched to +the correct code to perform that check. However, conjunction +operators are available, allowing for more expressiveness in crafting +policies. + +As an example, take the following rule, expressed in the list-of-lists +representation:: + + [["role:admin"], ["project_id:%(project_id)s", "role:projectadmin"]] + +In the policy language, this becomes:: + + role:admin or (project_id:%(project_id)s and role:projectadmin) + +The policy language also has the "not" operator, allowing a richer +policy rule:: + + project_id:%(project_id)s and not role:dunce + +Finally, two special policy checks should be mentioned; the policy +check "@" will always accept an access, and the policy check "!" will +always reject an access. (Note that if a rule is either the empty +list ("[]") or the empty string, this is equivalent to the "@" policy +check.) Of these, the "!" policy check is probably the most useful, +as it allows particular rules to be explicitly disabled. +""" + +import abc +import re +import urllib + +import six +import urllib2 + +from staccato.openstack.common.gettextutils import _ +from staccato.openstack.common import jsonutils +from staccato.openstack.common import log as logging + + +LOG = logging.getLogger(__name__) + + +_rules = None +_checks = {} + + +class Rules(dict): + """ + A store for rules. Handles the default_rule setting directly. + """ + + @classmethod + def load_json(cls, data, default_rule=None): + """ + Allow loading of JSON rule data. + """ + + # Suck in the JSON data and parse the rules + rules = dict((k, parse_rule(v)) for k, v in + jsonutils.loads(data).items()) + + return cls(rules, default_rule) + + def __init__(self, rules=None, default_rule=None): + """Initialize the Rules store.""" + + super(Rules, self).__init__(rules or {}) + self.default_rule = default_rule + + def __missing__(self, key): + """Implements the default rule handling.""" + + # If the default rule isn't actually defined, do something + # reasonably intelligent + if not self.default_rule or self.default_rule not in self: + raise KeyError(key) + + return self[self.default_rule] + + def __str__(self): + """Dumps a string representation of the rules.""" + + # Start by building the canonical strings for the rules + out_rules = {} + for key, value in self.items(): + # Use empty string for singleton TrueCheck instances + if isinstance(value, TrueCheck): + out_rules[key] = '' + else: + out_rules[key] = str(value) + + # Dump a pretty-printed JSON representation + return jsonutils.dumps(out_rules, indent=4) + + +# Really have to figure out a way to deprecate this +def set_rules(rules): + """Set the rules in use for policy checks.""" + + global _rules + + _rules = rules + + +# Ditto +def reset(): + """Clear the rules used for policy checks.""" + + global _rules + + _rules = None + + +def check(rule, target, creds, exc=None, *args, **kwargs): + """ + Checks authorization of a rule against the target and credentials. + + :param rule: The rule to evaluate. + :param target: As much information about the object being operated + on as possible, as a dictionary. + :param creds: As much information about the user performing the + action as possible, as a dictionary. + :param exc: Class of the exception to raise if the check fails. + Any remaining arguments passed to check() (both + positional and keyword arguments) will be passed to + the exception class. If exc is not provided, returns + False. + + :return: Returns False if the policy does not allow the action and + exc is not provided; otherwise, returns a value that + evaluates to True. Note: for rules using the "case" + expression, this True value will be the specified string + from the expression. + """ + + # Allow the rule to be a Check tree + if isinstance(rule, BaseCheck): + result = rule(target, creds) + elif not _rules: + # No rules to reference means we're going to fail closed + result = False + else: + try: + # Evaluate the rule + result = _rules[rule](target, creds) + except KeyError: + # If the rule doesn't exist, fail closed + result = False + + # If it is False, raise the exception if requested + if exc and result is False: + raise exc(*args, **kwargs) + + return result + + +class BaseCheck(object): + """ + Abstract base class for Check classes. + """ + + __metaclass__ = abc.ABCMeta + + @abc.abstractmethod + def __str__(self): + """ + Retrieve a string representation of the Check tree rooted at + this node. + """ + + pass + + @abc.abstractmethod + def __call__(self, target, cred): + """ + Perform the check. Returns False to reject the access or a + true value (not necessary True) to accept the access. + """ + + pass + + +class FalseCheck(BaseCheck): + """ + A policy check that always returns False (disallow). + """ + + def __str__(self): + """Return a string representation of this check.""" + + return "!" + + def __call__(self, target, cred): + """Check the policy.""" + + return False + + +class TrueCheck(BaseCheck): + """ + A policy check that always returns True (allow). + """ + + def __str__(self): + """Return a string representation of this check.""" + + return "@" + + def __call__(self, target, cred): + """Check the policy.""" + + return True + + +class Check(BaseCheck): + """ + A base class to allow for user-defined policy checks. + """ + + def __init__(self, kind, match): + """ + :param kind: The kind of the check, i.e., the field before the + ':'. + :param match: The match of the check, i.e., the field after + the ':'. + """ + + self.kind = kind + self.match = match + + def __str__(self): + """Return a string representation of this check.""" + + return "%s:%s" % (self.kind, self.match) + + +class NotCheck(BaseCheck): + """ + A policy check that inverts the result of another policy check. + Implements the "not" operator. + """ + + def __init__(self, rule): + """ + Initialize the 'not' check. + + :param rule: The rule to negate. Must be a Check. + """ + + self.rule = rule + + def __str__(self): + """Return a string representation of this check.""" + + return "not %s" % self.rule + + def __call__(self, target, cred): + """ + Check the policy. Returns the logical inverse of the wrapped + check. + """ + + return not self.rule(target, cred) + + +class AndCheck(BaseCheck): + """ + A policy check that requires that a list of other checks all + return True. Implements the "and" operator. + """ + + def __init__(self, rules): + """ + Initialize the 'and' check. + + :param rules: A list of rules that will be tested. + """ + + self.rules = rules + + def __str__(self): + """Return a string representation of this check.""" + + return "(%s)" % ' and '.join(str(r) for r in self.rules) + + def __call__(self, target, cred): + """ + Check the policy. Requires that all rules accept in order to + return True. + """ + + for rule in self.rules: + if not rule(target, cred): + return False + + return True + + def add_check(self, rule): + """ + Allows addition of another rule to the list of rules that will + be tested. Returns the AndCheck object for convenience. + """ + + self.rules.append(rule) + return self + + +class OrCheck(BaseCheck): + """ + A policy check that requires that at least one of a list of other + checks returns True. Implements the "or" operator. + """ + + def __init__(self, rules): + """ + Initialize the 'or' check. + + :param rules: A list of rules that will be tested. + """ + + self.rules = rules + + def __str__(self): + """Return a string representation of this check.""" + + return "(%s)" % ' or '.join(str(r) for r in self.rules) + + def __call__(self, target, cred): + """ + Check the policy. Requires that at least one rule accept in + order to return True. + """ + + for rule in self.rules: + if rule(target, cred): + return True + + return False + + def add_check(self, rule): + """ + Allows addition of another rule to the list of rules that will + be tested. Returns the OrCheck object for convenience. + """ + + self.rules.append(rule) + return self + + +def _parse_check(rule): + """ + Parse a single base check rule into an appropriate Check object. + """ + + # Handle the special checks + if rule == '!': + return FalseCheck() + elif rule == '@': + return TrueCheck() + + try: + kind, match = rule.split(':', 1) + except Exception: + LOG.exception(_("Failed to understand rule %(rule)s") % locals()) + # If the rule is invalid, we'll fail closed + return FalseCheck() + + # Find what implements the check + if kind in _checks: + return _checks[kind](kind, match) + elif None in _checks: + return _checks[None](kind, match) + else: + LOG.error(_("No handler for matches of kind %s") % kind) + return FalseCheck() + + +def _parse_list_rule(rule): + """ + Provided for backwards compatibility. Translates the old + list-of-lists syntax into a tree of Check objects. + """ + + # Empty rule defaults to True + if not rule: + return TrueCheck() + + # Outer list is joined by "or"; inner list by "and" + or_list = [] + for inner_rule in rule: + # Elide empty inner lists + if not inner_rule: + continue + + # Handle bare strings + if isinstance(inner_rule, basestring): + inner_rule = [inner_rule] + + # Parse the inner rules into Check objects + and_list = [_parse_check(r) for r in inner_rule] + + # Append the appropriate check to the or_list + if len(and_list) == 1: + or_list.append(and_list[0]) + else: + or_list.append(AndCheck(and_list)) + + # If we have only one check, omit the "or" + if len(or_list) == 0: + return FalseCheck() + elif len(or_list) == 1: + return or_list[0] + + return OrCheck(or_list) + + +# Used for tokenizing the policy language +_tokenize_re = re.compile(r'\s+') + + +def _parse_tokenize(rule): + """ + Tokenizer for the policy language. + + Most of the single-character tokens are specified in the + _tokenize_re; however, parentheses need to be handled specially, + because they can appear inside a check string. Thankfully, those + parentheses that appear inside a check string can never occur at + the very beginning or end ("%(variable)s" is the correct syntax). + """ + + for tok in _tokenize_re.split(rule): + # Skip empty tokens + if not tok or tok.isspace(): + continue + + # Handle leading parens on the token + clean = tok.lstrip('(') + for i in range(len(tok) - len(clean)): + yield '(', '(' + + # If it was only parentheses, continue + if not clean: + continue + else: + tok = clean + + # Handle trailing parens on the token + clean = tok.rstrip(')') + trail = len(tok) - len(clean) + + # Yield the cleaned token + lowered = clean.lower() + if lowered in ('and', 'or', 'not'): + # Special tokens + yield lowered, clean + elif clean: + # Not a special token, but not composed solely of ')' + if len(tok) >= 2 and ((tok[0], tok[-1]) in + [('"', '"'), ("'", "'")]): + # It's a quoted string + yield 'string', tok[1:-1] + else: + yield 'check', _parse_check(clean) + + # Yield the trailing parens + for i in range(trail): + yield ')', ')' + + +class ParseStateMeta(type): + """ + Metaclass for the ParseState class. Facilitates identifying + reduction methods. + """ + + def __new__(mcs, name, bases, cls_dict): + """ + Create the class. Injects the 'reducers' list, a list of + tuples matching token sequences to the names of the + corresponding reduction methods. + """ + + reducers = [] + + for key, value in cls_dict.items(): + if not hasattr(value, 'reducers'): + continue + for reduction in value.reducers: + reducers.append((reduction, key)) + + cls_dict['reducers'] = reducers + + return super(ParseStateMeta, mcs).__new__(mcs, name, bases, cls_dict) + + +def reducer(*tokens): + """ + Decorator for reduction methods. Arguments are a sequence of + tokens, in order, which should trigger running this reduction + method. + """ + + def decorator(func): + # Make sure we have a list of reducer sequences + if not hasattr(func, 'reducers'): + func.reducers = [] + + # Add the tokens to the list of reducer sequences + func.reducers.append(list(tokens)) + + return func + + return decorator + + +class ParseState(object): + """ + Implement the core of parsing the policy language. Uses a greedy + reduction algorithm to reduce a sequence of tokens into a single + terminal, the value of which will be the root of the Check tree. + + Note: error reporting is rather lacking. The best we can get with + this parser formulation is an overall "parse failed" error. + Fortunately, the policy language is simple enough that this + shouldn't be that big a problem. + """ + + __metaclass__ = ParseStateMeta + + def __init__(self): + """Initialize the ParseState.""" + + self.tokens = [] + self.values = [] + + def reduce(self): + """ + Perform a greedy reduction of the token stream. If a reducer + method matches, it will be executed, then the reduce() method + will be called recursively to search for any more possible + reductions. + """ + + for reduction, methname in self.reducers: + if (len(self.tokens) >= len(reduction) and + self.tokens[-len(reduction):] == reduction): + # Get the reduction method + meth = getattr(self, methname) + + # Reduce the token stream + results = meth(*self.values[-len(reduction):]) + + # Update the tokens and values + self.tokens[-len(reduction):] = [r[0] for r in results] + self.values[-len(reduction):] = [r[1] for r in results] + + # Check for any more reductions + return self.reduce() + + def shift(self, tok, value): + """Adds one more token to the state. Calls reduce().""" + + self.tokens.append(tok) + self.values.append(value) + + # Do a greedy reduce... + self.reduce() + + @property + def result(self): + """ + Obtain the final result of the parse. Raises ValueError if + the parse failed to reduce to a single result. + """ + + if len(self.values) != 1: + raise ValueError("Could not parse rule") + return self.values[0] + + @reducer('(', 'check', ')') + @reducer('(', 'and_expr', ')') + @reducer('(', 'or_expr', ')') + def _wrap_check(self, _p1, check, _p2): + """Turn parenthesized expressions into a 'check' token.""" + + return [('check', check)] + + @reducer('check', 'and', 'check') + def _make_and_expr(self, check1, _and, check2): + """ + Create an 'and_expr' from two checks joined by the 'and' + operator. + """ + + return [('and_expr', AndCheck([check1, check2]))] + + @reducer('and_expr', 'and', 'check') + def _extend_and_expr(self, and_expr, _and, check): + """ + Extend an 'and_expr' by adding one more check. + """ + + return [('and_expr', and_expr.add_check(check))] + + @reducer('check', 'or', 'check') + def _make_or_expr(self, check1, _or, check2): + """ + Create an 'or_expr' from two checks joined by the 'or' + operator. + """ + + return [('or_expr', OrCheck([check1, check2]))] + + @reducer('or_expr', 'or', 'check') + def _extend_or_expr(self, or_expr, _or, check): + """ + Extend an 'or_expr' by adding one more check. + """ + + return [('or_expr', or_expr.add_check(check))] + + @reducer('not', 'check') + def _make_not_expr(self, _not, check): + """Invert the result of another check.""" + + return [('check', NotCheck(check))] + + +def _parse_text_rule(rule): + """ + Translates a policy written in the policy language into a tree of + Check objects. + """ + + # Empty rule means always accept + if not rule: + return TrueCheck() + + # Parse the token stream + state = ParseState() + for tok, value in _parse_tokenize(rule): + state.shift(tok, value) + + try: + return state.result + except ValueError: + # Couldn't parse the rule + LOG.exception(_("Failed to understand rule %(rule)r") % locals()) + + # Fail closed + return FalseCheck() + + +def parse_rule(rule): + """ + Parses a policy rule into a tree of Check objects. + """ + + # If the rule is a string, it's in the policy language + if isinstance(rule, basestring): + return _parse_text_rule(rule) + return _parse_list_rule(rule) + + +def register(name, func=None): + """ + Register a function or Check class as a policy check. + + :param name: Gives the name of the check type, e.g., 'rule', + 'role', etc. If name is None, a default check type + will be registered. + :param func: If given, provides the function or class to register. + If not given, returns a function taking one argument + to specify the function or class to register, + allowing use as a decorator. + """ + + # Perform the actual decoration by registering the function or + # class. Returns the function or class for compliance with the + # decorator interface. + def decorator(func): + _checks[name] = func + return func + + # If the function or class is given, do the registration + if func: + return decorator(func) + + return decorator + + +@register("rule") +class RuleCheck(Check): + def __call__(self, target, creds): + """ + Recursively checks credentials based on the defined rules. + """ + + try: + return _rules[self.match](target, creds) + except KeyError: + # We don't have any matching rule; fail closed + return False + + +@register("role") +class RoleCheck(Check): + def __call__(self, target, creds): + """Check that there is a matching role in the cred dict.""" + + return self.match.lower() in [x.lower() for x in creds['roles']] + + +@register('http') +class HttpCheck(Check): + def __call__(self, target, creds): + """ + Check http: rules by calling to a remote server. + + This example implementation simply verifies that the response + is exactly 'True'. + """ + + url = ('http:' + self.match) % target + data = {'target': jsonutils.dumps(target), + 'credentials': jsonutils.dumps(creds)} + post_data = urllib.urlencode(data) + f = urllib2.urlopen(url, post_data) + return f.read() == "True" + + +@register(None) +class GenericCheck(Check): + def __call__(self, target, creds): + """ + Check an individual match. + + Matches look like: + + tenant:%(tenant_id)s + role:compute:admin + """ + + # TODO(termie): do dict inspection via dot syntax + match = self.match % target + if self.kind in creds: + return match == six.text_type(creds[self.kind]) + return False diff --git a/staccato/openstack/common/processutils.py b/staccato/openstack/common/processutils.py new file mode 100644 index 0000000..9f1139c --- /dev/null +++ b/staccato/openstack/common/processutils.py @@ -0,0 +1,247 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 OpenStack Foundation. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +System-level utilities and helper functions. +""" + +import os +import random +import shlex +import signal + +from eventlet.green import subprocess +from eventlet import greenthread + +from staccato.openstack.common.gettextutils import _ +from staccato.openstack.common import log as logging + + +LOG = logging.getLogger(__name__) + + +class InvalidArgumentError(Exception): + def __init__(self, message=None): + super(InvalidArgumentError, self).__init__(message) + + +class UnknownArgumentError(Exception): + def __init__(self, message=None): + super(UnknownArgumentError, self).__init__(message) + + +class ProcessExecutionError(Exception): + def __init__(self, stdout=None, stderr=None, exit_code=None, cmd=None, + description=None): + self.exit_code = exit_code + self.stderr = stderr + self.stdout = stdout + self.cmd = cmd + self.description = description + + if description is None: + description = "Unexpected error while running command." + if exit_code is None: + exit_code = '-' + message = ("%s\nCommand: %s\nExit code: %s\nStdout: %r\nStderr: %r" + % (description, cmd, exit_code, stdout, stderr)) + super(ProcessExecutionError, self).__init__(message) + + +class NoRootWrapSpecified(Exception): + def __init__(self, message=None): + super(NoRootWrapSpecified, self).__init__(message) + + +def _subprocess_setup(): + # Python installs a SIGPIPE handler by default. This is usually not what + # non-Python subprocesses expect. + signal.signal(signal.SIGPIPE, signal.SIG_DFL) + + +def execute(*cmd, **kwargs): + """ + Helper method to shell out and execute a command through subprocess with + optional retry. + + :param cmd: Passed to subprocess.Popen. + :type cmd: string + :param process_input: Send to opened process. + :type proces_input: string + :param check_exit_code: Single bool, int, or list of allowed exit + codes. Defaults to [0]. Raise + :class:`ProcessExecutionError` unless + program exits with one of these code. + :type check_exit_code: boolean, int, or [int] + :param delay_on_retry: True | False. Defaults to True. If set to True, + wait a short amount of time before retrying. + :type delay_on_retry: boolean + :param attempts: How many times to retry cmd. + :type attempts: int + :param run_as_root: True | False. Defaults to False. If set to True, + the command is prefixed by the command specified + in the root_helper kwarg. + :type run_as_root: boolean + :param root_helper: command to prefix to commands called with + run_as_root=True + :type root_helper: string + :param shell: whether or not there should be a shell used to + execute this command. Defaults to false. + :type shell: boolean + :returns: (stdout, stderr) from process execution + :raises: :class:`UnknownArgumentError` on + receiving unknown arguments + :raises: :class:`ProcessExecutionError` + """ + + process_input = kwargs.pop('process_input', None) + check_exit_code = kwargs.pop('check_exit_code', [0]) + ignore_exit_code = False + delay_on_retry = kwargs.pop('delay_on_retry', True) + attempts = kwargs.pop('attempts', 1) + run_as_root = kwargs.pop('run_as_root', False) + root_helper = kwargs.pop('root_helper', '') + shell = kwargs.pop('shell', False) + + if isinstance(check_exit_code, bool): + ignore_exit_code = not check_exit_code + check_exit_code = [0] + elif isinstance(check_exit_code, int): + check_exit_code = [check_exit_code] + + if len(kwargs): + raise UnknownArgumentError(_('Got unknown keyword args ' + 'to utils.execute: %r') % kwargs) + + if run_as_root and os.geteuid() != 0: + if not root_helper: + raise NoRootWrapSpecified( + message=('Command requested root, but did not specify a root ' + 'helper.')) + cmd = shlex.split(root_helper) + list(cmd) + + cmd = map(str, cmd) + + while attempts > 0: + attempts -= 1 + try: + LOG.debug(_('Running cmd (subprocess): %s'), ' '.join(cmd)) + _PIPE = subprocess.PIPE # pylint: disable=E1101 + + if os.name == 'nt': + preexec_fn = None + close_fds = False + else: + preexec_fn = _subprocess_setup + close_fds = True + + obj = subprocess.Popen(cmd, + stdin=_PIPE, + stdout=_PIPE, + stderr=_PIPE, + close_fds=close_fds, + preexec_fn=preexec_fn, + shell=shell) + result = None + if process_input is not None: + result = obj.communicate(process_input) + else: + result = obj.communicate() + obj.stdin.close() # pylint: disable=E1101 + _returncode = obj.returncode # pylint: disable=E1101 + if _returncode: + LOG.debug(_('Result was %s') % _returncode) + if not ignore_exit_code and _returncode not in check_exit_code: + (stdout, stderr) = result + raise ProcessExecutionError(exit_code=_returncode, + stdout=stdout, + stderr=stderr, + cmd=' '.join(cmd)) + return result + except ProcessExecutionError: + if not attempts: + raise + else: + LOG.debug(_('%r failed. Retrying.'), cmd) + if delay_on_retry: + greenthread.sleep(random.randint(20, 200) / 100.0) + finally: + # NOTE(termie): this appears to be necessary to let the subprocess + # call clean something up in between calls, without + # it two execute calls in a row hangs the second one + greenthread.sleep(0) + + +def trycmd(*args, **kwargs): + """ + A wrapper around execute() to more easily handle warnings and errors. + + Returns an (out, err) tuple of strings containing the output of + the command's stdout and stderr. If 'err' is not empty then the + command can be considered to have failed. + + :discard_warnings True | False. Defaults to False. If set to True, + then for succeeding commands, stderr is cleared + + """ + discard_warnings = kwargs.pop('discard_warnings', False) + + try: + out, err = execute(*args, **kwargs) + failed = False + except ProcessExecutionError, exn: + out, err = '', str(exn) + failed = True + + if not failed and discard_warnings and err: + # Handle commands that output to stderr but otherwise succeed + err = '' + + return out, err + + +def ssh_execute(ssh, cmd, process_input=None, + addl_env=None, check_exit_code=True): + LOG.debug(_('Running cmd (SSH): %s'), cmd) + if addl_env: + raise InvalidArgumentError(_('Environment not supported over SSH')) + + if process_input: + # This is (probably) fixable if we need it... + raise InvalidArgumentError(_('process_input not supported over SSH')) + + stdin_stream, stdout_stream, stderr_stream = ssh.exec_command(cmd) + channel = stdout_stream.channel + + # NOTE(justinsb): This seems suspicious... + # ...other SSH clients have buffering issues with this approach + stdout = stdout_stream.read() + stderr = stderr_stream.read() + stdin_stream.close() + + exit_status = channel.recv_exit_status() + + # exit_status == -1 if no exit code was returned + if exit_status != -1: + LOG.debug(_('Result was %s') % exit_status) + if check_exit_code and exit_status != 0: + raise ProcessExecutionError(exit_code=exit_status, + stdout=stdout, + stderr=stderr, + cmd=cmd) + + return (stdout, stderr) diff --git a/staccato/openstack/common/rpc/__init__.py b/staccato/openstack/common/rpc/__init__.py new file mode 100644 index 0000000..91e116d --- /dev/null +++ b/staccato/openstack/common/rpc/__init__.py @@ -0,0 +1,307 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# Copyright 2011 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +A remote procedure call (rpc) abstraction. + +For some wrappers that add message versioning to rpc, see: + rpc.dispatcher + rpc.proxy +""" + +import inspect + +from oslo.config import cfg + +from staccato.openstack.common.gettextutils import _ +from staccato.openstack.common import importutils +from staccato.openstack.common import local +from staccato.openstack.common import log as logging + + +LOG = logging.getLogger(__name__) + + +rpc_opts = [ + cfg.StrOpt('rpc_backend', + default='%s.impl_kombu' % __package__, + help="The messaging module to use, defaults to kombu."), + cfg.IntOpt('rpc_thread_pool_size', + default=64, + help='Size of RPC thread pool'), + cfg.IntOpt('rpc_conn_pool_size', + default=30, + help='Size of RPC connection pool'), + cfg.IntOpt('rpc_response_timeout', + default=60, + help='Seconds to wait for a response from call or multicall'), + cfg.IntOpt('rpc_cast_timeout', + default=30, + help='Seconds to wait before a cast expires (TTL). ' + 'Only supported by impl_zmq.'), + cfg.ListOpt('allowed_rpc_exception_modules', + default=['staccato.openstack.common.exception', + 'nova.exception', + 'cinder.exception', + 'exceptions', + ], + help='Modules of exceptions that are permitted to be recreated' + 'upon receiving exception data from an rpc call.'), + cfg.BoolOpt('fake_rabbit', + default=False, + help='If passed, use a fake RabbitMQ provider'), + cfg.StrOpt('control_exchange', + default='openstack', + help='AMQP exchange to connect to if using RabbitMQ or Qpid'), +] + +CONF = cfg.CONF +CONF.register_opts(rpc_opts) + + +def set_defaults(control_exchange): + cfg.set_defaults(rpc_opts, + control_exchange=control_exchange) + + +def create_connection(new=True): + """Create a connection to the message bus used for rpc. + + For some example usage of creating a connection and some consumers on that + connection, see nova.service. + + :param new: Whether or not to create a new connection. A new connection + will be created by default. If new is False, the + implementation is free to return an existing connection from a + pool. + + :returns: An instance of openstack.common.rpc.common.Connection + """ + return _get_impl().create_connection(CONF, new=new) + + +def _check_for_lock(): + if not CONF.debug: + return None + + if ((hasattr(local.strong_store, 'locks_held') + and local.strong_store.locks_held)): + stack = ' :: '.join([frame[3] for frame in inspect.stack()]) + LOG.warn(_('A RPC is being made while holding a lock. The locks ' + 'currently held are %(locks)s. This is probably a bug. ' + 'Please report it. Include the following: [%(stack)s].'), + {'locks': local.strong_store.locks_held, + 'stack': stack}) + return True + + return False + + +def call(context, topic, msg, timeout=None, check_for_lock=False): + """Invoke a remote method that returns something. + + :param context: Information that identifies the user that has made this + request. + :param topic: The topic to send the rpc message to. This correlates to the + topic argument of + openstack.common.rpc.common.Connection.create_consumer() + and only applies when the consumer was created with + fanout=False. + :param msg: This is a dict in the form { "method" : "method_to_invoke", + "args" : dict_of_kwargs } + :param timeout: int, number of seconds to use for a response timeout. + If set, this overrides the rpc_response_timeout option. + :param check_for_lock: if True, a warning is emitted if a RPC call is made + with a lock held. + + :returns: A dict from the remote method. + + :raises: openstack.common.rpc.common.Timeout if a complete response + is not received before the timeout is reached. + """ + if check_for_lock: + _check_for_lock() + return _get_impl().call(CONF, context, topic, msg, timeout) + + +def cast(context, topic, msg): + """Invoke a remote method that does not return anything. + + :param context: Information that identifies the user that has made this + request. + :param topic: The topic to send the rpc message to. This correlates to the + topic argument of + openstack.common.rpc.common.Connection.create_consumer() + and only applies when the consumer was created with + fanout=False. + :param msg: This is a dict in the form { "method" : "method_to_invoke", + "args" : dict_of_kwargs } + + :returns: None + """ + return _get_impl().cast(CONF, context, topic, msg) + + +def fanout_cast(context, topic, msg): + """Broadcast a remote method invocation with no return. + + This method will get invoked on all consumers that were set up with this + topic name and fanout=True. + + :param context: Information that identifies the user that has made this + request. + :param topic: The topic to send the rpc message to. This correlates to the + topic argument of + openstack.common.rpc.common.Connection.create_consumer() + and only applies when the consumer was created with + fanout=True. + :param msg: This is a dict in the form { "method" : "method_to_invoke", + "args" : dict_of_kwargs } + + :returns: None + """ + return _get_impl().fanout_cast(CONF, context, topic, msg) + + +def multicall(context, topic, msg, timeout=None, check_for_lock=False): + """Invoke a remote method and get back an iterator. + + In this case, the remote method will be returning multiple values in + separate messages, so the return values can be processed as the come in via + an iterator. + + :param context: Information that identifies the user that has made this + request. + :param topic: The topic to send the rpc message to. This correlates to the + topic argument of + openstack.common.rpc.common.Connection.create_consumer() + and only applies when the consumer was created with + fanout=False. + :param msg: This is a dict in the form { "method" : "method_to_invoke", + "args" : dict_of_kwargs } + :param timeout: int, number of seconds to use for a response timeout. + If set, this overrides the rpc_response_timeout option. + :param check_for_lock: if True, a warning is emitted if a RPC call is made + with a lock held. + + :returns: An iterator. The iterator will yield a tuple (N, X) where N is + an index that starts at 0 and increases by one for each value + returned and X is the Nth value that was returned by the remote + method. + + :raises: openstack.common.rpc.common.Timeout if a complete response + is not received before the timeout is reached. + """ + if check_for_lock: + _check_for_lock() + return _get_impl().multicall(CONF, context, topic, msg, timeout) + + +def notify(context, topic, msg, envelope=False): + """Send notification event. + + :param context: Information that identifies the user that has made this + request. + :param topic: The topic to send the notification to. + :param msg: This is a dict of content of event. + :param envelope: Set to True to enable message envelope for notifications. + + :returns: None + """ + return _get_impl().notify(cfg.CONF, context, topic, msg, envelope) + + +def cleanup(): + """Clean up resoruces in use by implementation. + + Clean up any resources that have been allocated by the RPC implementation. + This is typically open connections to a messaging service. This function + would get called before an application using this API exits to allow + connections to get torn down cleanly. + + :returns: None + """ + return _get_impl().cleanup() + + +def cast_to_server(context, server_params, topic, msg): + """Invoke a remote method that does not return anything. + + :param context: Information that identifies the user that has made this + request. + :param server_params: Connection information + :param topic: The topic to send the notification to. + :param msg: This is a dict in the form { "method" : "method_to_invoke", + "args" : dict_of_kwargs } + + :returns: None + """ + return _get_impl().cast_to_server(CONF, context, server_params, topic, + msg) + + +def fanout_cast_to_server(context, server_params, topic, msg): + """Broadcast to a remote method invocation with no return. + + :param context: Information that identifies the user that has made this + request. + :param server_params: Connection information + :param topic: The topic to send the notification to. + :param msg: This is a dict in the form { "method" : "method_to_invoke", + "args" : dict_of_kwargs } + + :returns: None + """ + return _get_impl().fanout_cast_to_server(CONF, context, server_params, + topic, msg) + + +def queue_get_for(context, topic, host): + """Get a queue name for a given topic + host. + + This function only works if this naming convention is followed on the + consumer side, as well. For example, in nova, every instance of the + nova-foo service calls create_consumer() for two topics: + + foo + foo. + + Messages sent to the 'foo' topic are distributed to exactly one instance of + the nova-foo service. The services are chosen in a round-robin fashion. + Messages sent to the 'foo.' topic are sent to the nova-foo service on + . + """ + return '%s.%s' % (topic, host) if host else topic + + +_RPCIMPL = None + + +def _get_impl(): + """Delay import of rpc_backend until configuration is loaded.""" + global _RPCIMPL + if _RPCIMPL is None: + try: + _RPCIMPL = importutils.import_module(CONF.rpc_backend) + except ImportError: + # For backwards compatibility with older nova config. + impl = CONF.rpc_backend.replace('nova.rpc', + 'nova.openstack.common.rpc') + _RPCIMPL = importutils.import_module(impl) + return _RPCIMPL diff --git a/staccato/openstack/common/rpc/amqp.py b/staccato/openstack/common/rpc/amqp.py new file mode 100644 index 0000000..b228c48 --- /dev/null +++ b/staccato/openstack/common/rpc/amqp.py @@ -0,0 +1,677 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# Copyright 2011 - 2012, Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +Shared code between AMQP based openstack.common.rpc implementations. + +The code in this module is shared between the rpc implemenations based on AMQP. +Specifically, this includes impl_kombu and impl_qpid. impl_carrot also uses +AMQP, but is deprecated and predates this code. +""" + +import collections +import inspect +import sys +import uuid + +from eventlet import greenpool +from eventlet import pools +from eventlet import queue +from eventlet import semaphore +# TODO(pekowsk): Remove import cfg and below comment in Havana. +# This import should no longer be needed when the amqp_rpc_single_reply_queue +# option is removed. +from oslo.config import cfg + +from staccato.openstack.common import excutils +from staccato.openstack.common.gettextutils import _ +from staccato.openstack.common import local +from staccato.openstack.common import log as logging +from staccato.openstack.common.rpc import common as rpc_common + + +# TODO(pekowski): Remove this option in Havana. +amqp_opts = [ + cfg.BoolOpt('amqp_rpc_single_reply_queue', + default=False, + help='Enable a fast single reply queue if using AMQP based ' + 'RPC like RabbitMQ or Qpid.'), +] + +cfg.CONF.register_opts(amqp_opts) + +UNIQUE_ID = '_unique_id' +LOG = logging.getLogger(__name__) + + +class Pool(pools.Pool): + """Class that implements a Pool of Connections.""" + def __init__(self, conf, connection_cls, *args, **kwargs): + self.connection_cls = connection_cls + self.conf = conf + kwargs.setdefault("max_size", self.conf.rpc_conn_pool_size) + kwargs.setdefault("order_as_stack", True) + super(Pool, self).__init__(*args, **kwargs) + self.reply_proxy = None + + # TODO(comstud): Timeout connections not used in a while + def create(self): + LOG.debug(_('Pool creating new connection')) + return self.connection_cls(self.conf) + + def empty(self): + while self.free_items: + self.get().close() + # Force a new connection pool to be created. + # Note that this was added due to failing unit test cases. The issue + # is the above "while loop" gets all the cached connections from the + # pool and closes them, but never returns them to the pool, a pool + # leak. The unit tests hang waiting for an item to be returned to the + # pool. The unit tests get here via the teatDown() method. In the run + # time code, it gets here via cleanup() and only appears in service.py + # just before doing a sys.exit(), so cleanup() only happens once and + # the leakage is not a problem. + self.connection_cls.pool = None + + +_pool_create_sem = semaphore.Semaphore() + + +def get_connection_pool(conf, connection_cls): + with _pool_create_sem: + # Make sure only one thread tries to create the connection pool. + if not connection_cls.pool: + connection_cls.pool = Pool(conf, connection_cls) + return connection_cls.pool + + +class ConnectionContext(rpc_common.Connection): + """The class that is actually returned to the caller of + create_connection(). This is essentially a wrapper around + Connection that supports 'with'. It can also return a new + Connection, or one from a pool. The function will also catch + when an instance of this class is to be deleted. With that + we can return Connections to the pool on exceptions and so + forth without making the caller be responsible for catching + them. If possible the function makes sure to return a + connection to the pool. + """ + + def __init__(self, conf, connection_pool, pooled=True, server_params=None): + """Create a new connection, or get one from the pool""" + self.connection = None + self.conf = conf + self.connection_pool = connection_pool + if pooled: + self.connection = connection_pool.get() + else: + self.connection = connection_pool.connection_cls( + conf, + server_params=server_params) + self.pooled = pooled + + def __enter__(self): + """When with ConnectionContext() is used, return self""" + return self + + def _done(self): + """If the connection came from a pool, clean it up and put it back. + If it did not come from a pool, close it. + """ + if self.connection: + if self.pooled: + # Reset the connection so it's ready for the next caller + # to grab from the pool + self.connection.reset() + self.connection_pool.put(self.connection) + else: + try: + self.connection.close() + except Exception: + pass + self.connection = None + + def __exit__(self, exc_type, exc_value, tb): + """End of 'with' statement. We're done here.""" + self._done() + + def __del__(self): + """Caller is done with this connection. Make sure we cleaned up.""" + self._done() + + def close(self): + """Caller is done with this connection.""" + self._done() + + def create_consumer(self, topic, proxy, fanout=False): + self.connection.create_consumer(topic, proxy, fanout) + + def create_worker(self, topic, proxy, pool_name): + self.connection.create_worker(topic, proxy, pool_name) + + def join_consumer_pool(self, callback, pool_name, topic, exchange_name): + self.connection.join_consumer_pool(callback, + pool_name, + topic, + exchange_name) + + def consume_in_thread(self): + self.connection.consume_in_thread() + + def __getattr__(self, key): + """Proxy all other calls to the Connection instance""" + if self.connection: + return getattr(self.connection, key) + else: + raise rpc_common.InvalidRPCConnectionReuse() + + +class ReplyProxy(ConnectionContext): + """ Connection class for RPC replies / callbacks """ + def __init__(self, conf, connection_pool): + self._call_waiters = {} + self._num_call_waiters = 0 + self._num_call_waiters_wrn_threshhold = 10 + self._reply_q = 'reply_' + uuid.uuid4().hex + super(ReplyProxy, self).__init__(conf, connection_pool, pooled=False) + self.declare_direct_consumer(self._reply_q, self._process_data) + self.consume_in_thread() + + def _process_data(self, message_data): + msg_id = message_data.pop('_msg_id', None) + waiter = self._call_waiters.get(msg_id) + if not waiter: + LOG.warn(_('no calling threads waiting for msg_id : %s' + ', message : %s') % (msg_id, message_data)) + else: + waiter.put(message_data) + + def add_call_waiter(self, waiter, msg_id): + self._num_call_waiters += 1 + if self._num_call_waiters > self._num_call_waiters_wrn_threshhold: + LOG.warn(_('Number of call waiters is greater than warning ' + 'threshhold: %d. There could be a MulticallProxyWaiter ' + 'leak.') % self._num_call_waiters_wrn_threshhold) + self._num_call_waiters_wrn_threshhold *= 2 + self._call_waiters[msg_id] = waiter + + def del_call_waiter(self, msg_id): + self._num_call_waiters -= 1 + del self._call_waiters[msg_id] + + def get_reply_q(self): + return self._reply_q + + +def msg_reply(conf, msg_id, reply_q, connection_pool, reply=None, + failure=None, ending=False, log_failure=True): + """Sends a reply or an error on the channel signified by msg_id. + + Failure should be a sys.exc_info() tuple. + + """ + with ConnectionContext(conf, connection_pool) as conn: + if failure: + failure = rpc_common.serialize_remote_exception(failure, + log_failure) + + try: + msg = {'result': reply, 'failure': failure} + except TypeError: + msg = {'result': dict((k, repr(v)) + for k, v in reply.__dict__.iteritems()), + 'failure': failure} + if ending: + msg['ending'] = True + _add_unique_id(msg) + # If a reply_q exists, add the msg_id to the reply and pass the + # reply_q to direct_send() to use it as the response queue. + # Otherwise use the msg_id for backward compatibilty. + if reply_q: + msg['_msg_id'] = msg_id + conn.direct_send(reply_q, rpc_common.serialize_msg(msg)) + else: + conn.direct_send(msg_id, rpc_common.serialize_msg(msg)) + + +class RpcContext(rpc_common.CommonRpcContext): + """Context that supports replying to a rpc.call""" + def __init__(self, **kwargs): + self.msg_id = kwargs.pop('msg_id', None) + self.reply_q = kwargs.pop('reply_q', None) + self.conf = kwargs.pop('conf') + super(RpcContext, self).__init__(**kwargs) + + def deepcopy(self): + values = self.to_dict() + values['conf'] = self.conf + values['msg_id'] = self.msg_id + values['reply_q'] = self.reply_q + return self.__class__(**values) + + def reply(self, reply=None, failure=None, ending=False, + connection_pool=None, log_failure=True): + if self.msg_id: + msg_reply(self.conf, self.msg_id, self.reply_q, connection_pool, + reply, failure, ending, log_failure) + if ending: + self.msg_id = None + + +def unpack_context(conf, msg): + """Unpack context from msg.""" + context_dict = {} + for key in list(msg.keys()): + # NOTE(vish): Some versions of python don't like unicode keys + # in kwargs. + key = str(key) + if key.startswith('_context_'): + value = msg.pop(key) + context_dict[key[9:]] = value + context_dict['msg_id'] = msg.pop('_msg_id', None) + context_dict['reply_q'] = msg.pop('_reply_q', None) + context_dict['conf'] = conf + ctx = RpcContext.from_dict(context_dict) + rpc_common._safe_log(LOG.debug, _('unpacked context: %s'), ctx.to_dict()) + return ctx + + +def pack_context(msg, context): + """Pack context into msg. + + Values for message keys need to be less than 255 chars, so we pull + context out into a bunch of separate keys. If we want to support + more arguments in rabbit messages, we may want to do the same + for args at some point. + + """ + context_d = dict([('_context_%s' % key, value) + for (key, value) in context.to_dict().iteritems()]) + msg.update(context_d) + + +class _MsgIdCache(object): + """This class checks any duplicate messages.""" + + # NOTE: This value is considered can be a configuration item, but + # it is not necessary to change its value in most cases, + # so let this value as static for now. + DUP_MSG_CHECK_SIZE = 16 + + def __init__(self, **kwargs): + self.prev_msgids = collections.deque([], + maxlen=self.DUP_MSG_CHECK_SIZE) + + def check_duplicate_message(self, message_data): + """AMQP consumers may read same message twice when exceptions occur + before ack is returned. This method prevents doing it. + """ + if UNIQUE_ID in message_data: + msg_id = message_data[UNIQUE_ID] + if msg_id not in self.prev_msgids: + self.prev_msgids.append(msg_id) + else: + raise rpc_common.DuplicateMessageError(msg_id=msg_id) + + +def _add_unique_id(msg): + """Add unique_id for checking duplicate messages.""" + unique_id = uuid.uuid4().hex + msg.update({UNIQUE_ID: unique_id}) + LOG.debug(_('UNIQUE_ID is %s.') % (unique_id)) + + +class _ThreadPoolWithWait(object): + """Base class for a delayed invocation manager used by + the Connection class to start up green threads + to handle incoming messages. + """ + + def __init__(self, conf, connection_pool): + self.pool = greenpool.GreenPool(conf.rpc_thread_pool_size) + self.connection_pool = connection_pool + self.conf = conf + + def wait(self): + """Wait for all callback threads to exit.""" + self.pool.waitall() + + +class CallbackWrapper(_ThreadPoolWithWait): + """Wraps a straight callback to allow it to be invoked in a green + thread. + """ + + def __init__(self, conf, callback, connection_pool): + """ + :param conf: cfg.CONF instance + :param callback: a callable (probably a function) + :param connection_pool: connection pool as returned by + get_connection_pool() + """ + super(CallbackWrapper, self).__init__( + conf=conf, + connection_pool=connection_pool, + ) + self.callback = callback + + def __call__(self, message_data): + self.pool.spawn_n(self.callback, message_data) + + +class ProxyCallback(_ThreadPoolWithWait): + """Calls methods on a proxy object based on method and args.""" + + def __init__(self, conf, proxy, connection_pool): + super(ProxyCallback, self).__init__( + conf=conf, + connection_pool=connection_pool, + ) + self.proxy = proxy + self.msg_id_cache = _MsgIdCache() + + def __call__(self, message_data): + """Consumer callback to call a method on a proxy object. + + Parses the message for validity and fires off a thread to call the + proxy object method. + + Message data should be a dictionary with two keys: + method: string representing the method to call + args: dictionary of arg: value + + Example: {'method': 'echo', 'args': {'value': 42}} + + """ + # It is important to clear the context here, because at this point + # the previous context is stored in local.store.context + if hasattr(local.store, 'context'): + del local.store.context + rpc_common._safe_log(LOG.debug, _('received %s'), message_data) + self.msg_id_cache.check_duplicate_message(message_data) + ctxt = unpack_context(self.conf, message_data) + method = message_data.get('method') + args = message_data.get('args', {}) + version = message_data.get('version') + namespace = message_data.get('namespace') + if not method: + LOG.warn(_('no method for message: %s') % message_data) + ctxt.reply(_('No method for message: %s') % message_data, + connection_pool=self.connection_pool) + return + self.pool.spawn_n(self._process_data, ctxt, version, method, + namespace, args) + + def _process_data(self, ctxt, version, method, namespace, args): + """Process a message in a new thread. + + If the proxy object we have has a dispatch method + (see rpc.dispatcher.RpcDispatcher), pass it the version, + method, and args and let it dispatch as appropriate. If not, use + the old behavior of magically calling the specified method on the + proxy we have here. + """ + ctxt.update_store() + try: + rval = self.proxy.dispatch(ctxt, version, method, namespace, + **args) + # Check if the result was a generator + if inspect.isgenerator(rval): + for x in rval: + ctxt.reply(x, None, connection_pool=self.connection_pool) + else: + ctxt.reply(rval, None, connection_pool=self.connection_pool) + # This final None tells multicall that it is done. + ctxt.reply(ending=True, connection_pool=self.connection_pool) + except rpc_common.ClientException as e: + LOG.debug(_('Expected exception during message handling (%s)') % + e._exc_info[1]) + ctxt.reply(None, e._exc_info, + connection_pool=self.connection_pool, + log_failure=False) + except Exception: + # sys.exc_info() is deleted by LOG.exception(). + exc_info = sys.exc_info() + LOG.error(_('Exception during message handling'), + exc_info=exc_info) + ctxt.reply(None, exc_info, connection_pool=self.connection_pool) + + +class MulticallProxyWaiter(object): + def __init__(self, conf, msg_id, timeout, connection_pool): + self._msg_id = msg_id + self._timeout = timeout or conf.rpc_response_timeout + self._reply_proxy = connection_pool.reply_proxy + self._done = False + self._got_ending = False + self._conf = conf + self._dataqueue = queue.LightQueue() + # Add this caller to the reply proxy's call_waiters + self._reply_proxy.add_call_waiter(self, self._msg_id) + self.msg_id_cache = _MsgIdCache() + + def put(self, data): + self._dataqueue.put(data) + + def done(self): + if self._done: + return + self._done = True + # Remove this caller from reply proxy's call_waiters + self._reply_proxy.del_call_waiter(self._msg_id) + + def _process_data(self, data): + result = None + self.msg_id_cache.check_duplicate_message(data) + if data['failure']: + failure = data['failure'] + result = rpc_common.deserialize_remote_exception(self._conf, + failure) + elif data.get('ending', False): + self._got_ending = True + else: + result = data['result'] + return result + + def __iter__(self): + """Return a result until we get a reply with an 'ending" flag""" + if self._done: + raise StopIteration + while True: + try: + data = self._dataqueue.get(timeout=self._timeout) + result = self._process_data(data) + except queue.Empty: + self.done() + raise rpc_common.Timeout() + except Exception: + with excutils.save_and_reraise_exception(): + self.done() + if self._got_ending: + self.done() + raise StopIteration + if isinstance(result, Exception): + self.done() + raise result + yield result + + +#TODO(pekowski): Remove MulticallWaiter() in Havana. +class MulticallWaiter(object): + def __init__(self, conf, connection, timeout): + self._connection = connection + self._iterator = connection.iterconsume(timeout=timeout or + conf.rpc_response_timeout) + self._result = None + self._done = False + self._got_ending = False + self._conf = conf + self.msg_id_cache = _MsgIdCache() + + def done(self): + if self._done: + return + self._done = True + self._iterator.close() + self._iterator = None + self._connection.close() + + def __call__(self, data): + """The consume() callback will call this. Store the result.""" + self.msg_id_cache.check_duplicate_message(data) + if data['failure']: + failure = data['failure'] + self._result = rpc_common.deserialize_remote_exception(self._conf, + failure) + + elif data.get('ending', False): + self._got_ending = True + else: + self._result = data['result'] + + def __iter__(self): + """Return a result until we get a 'None' response from consumer""" + if self._done: + raise StopIteration + while True: + try: + self._iterator.next() + except Exception: + with excutils.save_and_reraise_exception(): + self.done() + if self._got_ending: + self.done() + raise StopIteration + result = self._result + if isinstance(result, Exception): + self.done() + raise result + yield result + + +def create_connection(conf, new, connection_pool): + """Create a connection""" + return ConnectionContext(conf, connection_pool, pooled=not new) + + +_reply_proxy_create_sem = semaphore.Semaphore() + + +def multicall(conf, context, topic, msg, timeout, connection_pool): + """Make a call that returns multiple times.""" + # TODO(pekowski): Remove all these comments in Havana. + # For amqp_rpc_single_reply_queue = False, + # Can't use 'with' for multicall, as it returns an iterator + # that will continue to use the connection. When it's done, + # connection.close() will get called which will put it back into + # the pool + # For amqp_rpc_single_reply_queue = True, + # The 'with' statement is mandatory for closing the connection + LOG.debug(_('Making synchronous call on %s ...'), topic) + msg_id = uuid.uuid4().hex + msg.update({'_msg_id': msg_id}) + LOG.debug(_('MSG_ID is %s') % (msg_id)) + _add_unique_id(msg) + pack_context(msg, context) + + # TODO(pekowski): Remove this flag and the code under the if clause + # in Havana. + if not conf.amqp_rpc_single_reply_queue: + conn = ConnectionContext(conf, connection_pool) + wait_msg = MulticallWaiter(conf, conn, timeout) + conn.declare_direct_consumer(msg_id, wait_msg) + conn.topic_send(topic, rpc_common.serialize_msg(msg), timeout) + else: + with _reply_proxy_create_sem: + if not connection_pool.reply_proxy: + connection_pool.reply_proxy = ReplyProxy(conf, connection_pool) + msg.update({'_reply_q': connection_pool.reply_proxy.get_reply_q()}) + wait_msg = MulticallProxyWaiter(conf, msg_id, timeout, connection_pool) + with ConnectionContext(conf, connection_pool) as conn: + conn.topic_send(topic, rpc_common.serialize_msg(msg), timeout) + return wait_msg + + +def call(conf, context, topic, msg, timeout, connection_pool): + """Sends a message on a topic and wait for a response.""" + rv = multicall(conf, context, topic, msg, timeout, connection_pool) + # NOTE(vish): return the last result from the multicall + rv = list(rv) + if not rv: + return + return rv[-1] + + +def cast(conf, context, topic, msg, connection_pool): + """Sends a message on a topic without waiting for a response.""" + LOG.debug(_('Making asynchronous cast on %s...'), topic) + _add_unique_id(msg) + pack_context(msg, context) + with ConnectionContext(conf, connection_pool) as conn: + conn.topic_send(topic, rpc_common.serialize_msg(msg)) + + +def fanout_cast(conf, context, topic, msg, connection_pool): + """Sends a message on a fanout exchange without waiting for a response.""" + LOG.debug(_('Making asynchronous fanout cast...')) + _add_unique_id(msg) + pack_context(msg, context) + with ConnectionContext(conf, connection_pool) as conn: + conn.fanout_send(topic, rpc_common.serialize_msg(msg)) + + +def cast_to_server(conf, context, server_params, topic, msg, connection_pool): + """Sends a message on a topic to a specific server.""" + _add_unique_id(msg) + pack_context(msg, context) + with ConnectionContext(conf, connection_pool, pooled=False, + server_params=server_params) as conn: + conn.topic_send(topic, rpc_common.serialize_msg(msg)) + + +def fanout_cast_to_server(conf, context, server_params, topic, msg, + connection_pool): + """Sends a message on a fanout exchange to a specific server.""" + _add_unique_id(msg) + pack_context(msg, context) + with ConnectionContext(conf, connection_pool, pooled=False, + server_params=server_params) as conn: + conn.fanout_send(topic, rpc_common.serialize_msg(msg)) + + +def notify(conf, context, topic, msg, connection_pool, envelope): + """Sends a notification event on a topic.""" + LOG.debug(_('Sending %(event_type)s on %(topic)s'), + dict(event_type=msg.get('event_type'), + topic=topic)) + _add_unique_id(msg) + pack_context(msg, context) + with ConnectionContext(conf, connection_pool) as conn: + if envelope: + msg = rpc_common.serialize_msg(msg) + conn.notify_send(topic, msg) + + +def cleanup(connection_pool): + if connection_pool: + connection_pool.empty() + + +def get_control_exchange(conf): + return conf.control_exchange diff --git a/staccato/openstack/common/rpc/common.py b/staccato/openstack/common/rpc/common.py new file mode 100644 index 0000000..f0babac --- /dev/null +++ b/staccato/openstack/common/rpc/common.py @@ -0,0 +1,510 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# Copyright 2011 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import copy +import sys +import traceback + +from oslo.config import cfg +import six + +from staccato.openstack.common.gettextutils import _ +from staccato.openstack.common import importutils +from staccato.openstack.common import jsonutils +from staccato.openstack.common import local +from staccato.openstack.common import log as logging + + +CONF = cfg.CONF +LOG = logging.getLogger(__name__) + + +'''RPC Envelope Version. + +This version number applies to the top level structure of messages sent out. +It does *not* apply to the message payload, which must be versioned +independently. For example, when using rpc APIs, a version number is applied +for changes to the API being exposed over rpc. This version number is handled +in the rpc proxy and dispatcher modules. + +This version number applies to the message envelope that is used in the +serialization done inside the rpc layer. See serialize_msg() and +deserialize_msg(). + +The current message format (version 2.0) is very simple. It is: + + { + 'oslo.version': , + 'oslo.message': + } + +Message format version '1.0' is just considered to be the messages we sent +without a message envelope. + +So, the current message envelope just includes the envelope version. It may +eventually contain additional information, such as a signature for the message +payload. + +We will JSON encode the application message payload. The message envelope, +which includes the JSON encoded application message body, will be passed down +to the messaging libraries as a dict. +''' +_RPC_ENVELOPE_VERSION = '2.0' + +_VERSION_KEY = 'oslo.version' +_MESSAGE_KEY = 'oslo.message' + + +class RPCException(Exception): + message = _("An unknown RPC related exception occurred.") + + def __init__(self, message=None, **kwargs): + self.kwargs = kwargs + + if not message: + try: + message = self.message % kwargs + + except Exception: + # kwargs doesn't match a variable in the message + # log the issue and the kwargs + LOG.exception(_('Exception in string format operation')) + for name, value in kwargs.iteritems(): + LOG.error("%s: %s" % (name, value)) + # at least get the core message out if something happened + message = self.message + + super(RPCException, self).__init__(message) + + +class RemoteError(RPCException): + """Signifies that a remote class has raised an exception. + + Contains a string representation of the type of the original exception, + the value of the original exception, and the traceback. These are + sent to the parent as a joined string so printing the exception + contains all of the relevant info. + + """ + message = _("Remote error: %(exc_type)s %(value)s\n%(traceback)s.") + + def __init__(self, exc_type=None, value=None, traceback=None): + self.exc_type = exc_type + self.value = value + self.traceback = traceback + super(RemoteError, self).__init__(exc_type=exc_type, + value=value, + traceback=traceback) + + +class Timeout(RPCException): + """Signifies that a timeout has occurred. + + This exception is raised if the rpc_response_timeout is reached while + waiting for a response from the remote side. + """ + message = _('Timeout while waiting on RPC response - ' + 'topic: "%(topic)s", RPC method: "%(method)s" ' + 'info: "%(info)s"') + + def __init__(self, info=None, topic=None, method=None): + """ + :param info: Extra info to convey to the user + :param topic: The topic that the rpc call was sent to + :param rpc_method_name: The name of the rpc method being + called + """ + self.info = info + self.topic = topic + self.method = method + super(Timeout, self).__init__( + None, + info=info or _(''), + topic=topic or _(''), + method=method or _('')) + + +class DuplicateMessageError(RPCException): + message = _("Found duplicate message(%(msg_id)s). Skipping it.") + + +class InvalidRPCConnectionReuse(RPCException): + message = _("Invalid reuse of an RPC connection.") + + +class UnsupportedRpcVersion(RPCException): + message = _("Specified RPC version, %(version)s, not supported by " + "this endpoint.") + + +class UnsupportedRpcEnvelopeVersion(RPCException): + message = _("Specified RPC envelope version, %(version)s, " + "not supported by this endpoint.") + + +class Connection(object): + """A connection, returned by rpc.create_connection(). + + This class represents a connection to the message bus used for rpc. + An instance of this class should never be created by users of the rpc API. + Use rpc.create_connection() instead. + """ + def close(self): + """Close the connection. + + This method must be called when the connection will no longer be used. + It will ensure that any resources associated with the connection, such + as a network connection, and cleaned up. + """ + raise NotImplementedError() + + def create_consumer(self, topic, proxy, fanout=False): + """Create a consumer on this connection. + + A consumer is associated with a message queue on the backend message + bus. The consumer will read messages from the queue, unpack them, and + dispatch them to the proxy object. The contents of the message pulled + off of the queue will determine which method gets called on the proxy + object. + + :param topic: This is a name associated with what to consume from. + Multiple instances of a service may consume from the same + topic. For example, all instances of nova-compute consume + from a queue called "compute". In that case, the + messages will get distributed amongst the consumers in a + round-robin fashion if fanout=False. If fanout=True, + every consumer associated with this topic will get a + copy of every message. + :param proxy: The object that will handle all incoming messages. + :param fanout: Whether or not this is a fanout topic. See the + documentation for the topic parameter for some + additional comments on this. + """ + raise NotImplementedError() + + def create_worker(self, topic, proxy, pool_name): + """Create a worker on this connection. + + A worker is like a regular consumer of messages directed to a + topic, except that it is part of a set of such consumers (the + "pool") which may run in parallel. Every pool of workers will + receive a given message, but only one worker in the pool will + be asked to process it. Load is distributed across the members + of the pool in round-robin fashion. + + :param topic: This is a name associated with what to consume from. + Multiple instances of a service may consume from the same + topic. + :param proxy: The object that will handle all incoming messages. + :param pool_name: String containing the name of the pool of workers + """ + raise NotImplementedError() + + def join_consumer_pool(self, callback, pool_name, topic, exchange_name): + """Register as a member of a group of consumers for a given topic from + the specified exchange. + + Exactly one member of a given pool will receive each message. + + A message will be delivered to multiple pools, if more than + one is created. + + :param callback: Callable to be invoked for each message. + :type callback: callable accepting one argument + :param pool_name: The name of the consumer pool. + :type pool_name: str + :param topic: The routing topic for desired messages. + :type topic: str + :param exchange_name: The name of the message exchange where + the client should attach. Defaults to + the configured exchange. + :type exchange_name: str + """ + raise NotImplementedError() + + def consume_in_thread(self): + """Spawn a thread to handle incoming messages. + + Spawn a thread that will be responsible for handling all incoming + messages for consumers that were set up on this connection. + + Message dispatching inside of this is expected to be implemented in a + non-blocking manner. An example implementation would be having this + thread pull messages in for all of the consumers, but utilize a thread + pool for dispatching the messages to the proxy objects. + """ + raise NotImplementedError() + + +def _safe_log(log_func, msg, msg_data): + """Sanitizes the msg_data field before logging.""" + SANITIZE = {'set_admin_password': [('args', 'new_pass')], + 'run_instance': [('args', 'admin_password')], + 'route_message': [('args', 'message', 'args', 'method_info', + 'method_kwargs', 'password'), + ('args', 'message', 'args', 'method_info', + 'method_kwargs', 'admin_password')]} + + has_method = 'method' in msg_data and msg_data['method'] in SANITIZE + has_context_token = '_context_auth_token' in msg_data + has_token = 'auth_token' in msg_data + + if not any([has_method, has_context_token, has_token]): + return log_func(msg, msg_data) + + msg_data = copy.deepcopy(msg_data) + + if has_method: + for arg in SANITIZE.get(msg_data['method'], []): + try: + d = msg_data + for elem in arg[:-1]: + d = d[elem] + d[arg[-1]] = '' + except KeyError as e: + LOG.info(_('Failed to sanitize %(item)s. Key error %(err)s'), + {'item': arg, + 'err': e}) + + if has_context_token: + msg_data['_context_auth_token'] = '' + + if has_token: + msg_data['auth_token'] = '' + + return log_func(msg, msg_data) + + +def serialize_remote_exception(failure_info, log_failure=True): + """Prepares exception data to be sent over rpc. + + Failure_info should be a sys.exc_info() tuple. + + """ + tb = traceback.format_exception(*failure_info) + failure = failure_info[1] + if log_failure: + LOG.error(_("Returning exception %s to caller"), + six.text_type(failure)) + LOG.error(tb) + + kwargs = {} + if hasattr(failure, 'kwargs'): + kwargs = failure.kwargs + + data = { + 'class': str(failure.__class__.__name__), + 'module': str(failure.__class__.__module__), + 'message': six.text_type(failure), + 'tb': tb, + 'args': failure.args, + 'kwargs': kwargs + } + + json_data = jsonutils.dumps(data) + + return json_data + + +def deserialize_remote_exception(conf, data): + failure = jsonutils.loads(str(data)) + + trace = failure.get('tb', []) + message = failure.get('message', "") + "\n" + "\n".join(trace) + name = failure.get('class') + module = failure.get('module') + + # NOTE(ameade): We DO NOT want to allow just any module to be imported, in + # order to prevent arbitrary code execution. + if module not in conf.allowed_rpc_exception_modules: + return RemoteError(name, failure.get('message'), trace) + + try: + mod = importutils.import_module(module) + klass = getattr(mod, name) + if not issubclass(klass, Exception): + raise TypeError("Can only deserialize Exceptions") + + failure = klass(*failure.get('args', []), **failure.get('kwargs', {})) + except (AttributeError, TypeError, ImportError): + return RemoteError(name, failure.get('message'), trace) + + ex_type = type(failure) + str_override = lambda self: message + new_ex_type = type(ex_type.__name__ + "_Remote", (ex_type,), + {'__str__': str_override, '__unicode__': str_override}) + try: + # NOTE(ameade): Dynamically create a new exception type and swap it in + # as the new type for the exception. This only works on user defined + # Exceptions and not core python exceptions. This is important because + # we cannot necessarily change an exception message so we must override + # the __str__ method. + failure.__class__ = new_ex_type + except TypeError: + # NOTE(ameade): If a core exception then just add the traceback to the + # first exception argument. + failure.args = (message,) + failure.args[1:] + return failure + + +class CommonRpcContext(object): + def __init__(self, **kwargs): + self.values = kwargs + + def __getattr__(self, key): + try: + return self.values[key] + except KeyError: + raise AttributeError(key) + + def to_dict(self): + return copy.deepcopy(self.values) + + @classmethod + def from_dict(cls, values): + return cls(**values) + + def deepcopy(self): + return self.from_dict(self.to_dict()) + + def update_store(self): + local.store.context = self + + def elevated(self, read_deleted=None, overwrite=False): + """Return a version of this context with admin flag set.""" + # TODO(russellb) This method is a bit of a nova-ism. It makes + # some assumptions about the data in the request context sent + # across rpc, while the rest of this class does not. We could get + # rid of this if we changed the nova code that uses this to + # convert the RpcContext back to its native RequestContext doing + # something like nova.context.RequestContext.from_dict(ctxt.to_dict()) + + context = self.deepcopy() + context.values['is_admin'] = True + + context.values.setdefault('roles', []) + + if 'admin' not in context.values['roles']: + context.values['roles'].append('admin') + + if read_deleted is not None: + context.values['read_deleted'] = read_deleted + + return context + + +class ClientException(Exception): + """This encapsulates some actual exception that is expected to be + hit by an RPC proxy object. Merely instantiating it records the + current exception information, which will be passed back to the + RPC client without exceptional logging.""" + def __init__(self): + self._exc_info = sys.exc_info() + + +def catch_client_exception(exceptions, func, *args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as e: + if type(e) in exceptions: + raise ClientException() + else: + raise + + +def client_exceptions(*exceptions): + """Decorator for manager methods that raise expected exceptions. + Marking a Manager method with this decorator allows the declaration + of expected exceptions that the RPC layer should not consider fatal, + and not log as if they were generated in a real error scenario. Note + that this will cause listed exceptions to be wrapped in a + ClientException, which is used internally by the RPC layer.""" + def outer(func): + def inner(*args, **kwargs): + return catch_client_exception(exceptions, func, *args, **kwargs) + return inner + return outer + + +def version_is_compatible(imp_version, version): + """Determine whether versions are compatible. + + :param imp_version: The version implemented + :param version: The version requested by an incoming message. + """ + version_parts = version.split('.') + imp_version_parts = imp_version.split('.') + if int(version_parts[0]) != int(imp_version_parts[0]): # Major + return False + if int(version_parts[1]) > int(imp_version_parts[1]): # Minor + return False + return True + + +def serialize_msg(raw_msg): + # NOTE(russellb) See the docstring for _RPC_ENVELOPE_VERSION for more + # information about this format. + msg = {_VERSION_KEY: _RPC_ENVELOPE_VERSION, + _MESSAGE_KEY: jsonutils.dumps(raw_msg)} + + return msg + + +def deserialize_msg(msg): + # NOTE(russellb): Hang on to your hats, this road is about to + # get a little bumpy. + # + # Robustness Principle: + # "Be strict in what you send, liberal in what you accept." + # + # At this point we have to do a bit of guessing about what it + # is we just received. Here is the set of possibilities: + # + # 1) We received a dict. This could be 2 things: + # + # a) Inspect it to see if it looks like a standard message envelope. + # If so, great! + # + # b) If it doesn't look like a standard message envelope, it could either + # be a notification, or a message from before we added a message + # envelope (referred to as version 1.0). + # Just return the message as-is. + # + # 2) It's any other non-dict type. Just return it and hope for the best. + # This case covers return values from rpc.call() from before message + # envelopes were used. (messages to call a method were always a dict) + + if not isinstance(msg, dict): + # See #2 above. + return msg + + base_envelope_keys = (_VERSION_KEY, _MESSAGE_KEY) + if not all(map(lambda key: key in msg, base_envelope_keys)): + # See #1.b above. + return msg + + # At this point we think we have the message envelope + # format we were expecting. (#1.a above) + + if not version_is_compatible(_RPC_ENVELOPE_VERSION, msg[_VERSION_KEY]): + raise UnsupportedRpcEnvelopeVersion(version=msg[_VERSION_KEY]) + + raw_msg = jsonutils.loads(msg[_MESSAGE_KEY]) + + return raw_msg diff --git a/staccato/openstack/common/rpc/dispatcher.py b/staccato/openstack/common/rpc/dispatcher.py new file mode 100644 index 0000000..65b9b5b --- /dev/null +++ b/staccato/openstack/common/rpc/dispatcher.py @@ -0,0 +1,153 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2012 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +Code for rpc message dispatching. + +Messages that come in have a version number associated with them. RPC API +version numbers are in the form: + + Major.Minor + +For a given message with version X.Y, the receiver must be marked as able to +handle messages of version A.B, where: + + A = X + + B >= Y + +The Major version number would be incremented for an almost completely new API. +The Minor version number would be incremented for backwards compatible changes +to an existing API. A backwards compatible change could be something like +adding a new method, adding an argument to an existing method (but not +requiring it), or changing the type for an existing argument (but still +handling the old type as well). + +The conversion over to a versioned API must be done on both the client side and +server side of the API at the same time. However, as the code stands today, +there can be both versioned and unversioned APIs implemented in the same code +base. + +EXAMPLES +======== + +Nova was the first project to use versioned rpc APIs. Consider the compute rpc +API as an example. The client side is in nova/compute/rpcapi.py and the server +side is in nova/compute/manager.py. + + +Example 1) Adding a new method. +------------------------------- + +Adding a new method is a backwards compatible change. It should be added to +nova/compute/manager.py, and RPC_API_VERSION should be bumped from X.Y to +X.Y+1. On the client side, the new method in nova/compute/rpcapi.py should +have a specific version specified to indicate the minimum API version that must +be implemented for the method to be supported. For example:: + + def get_host_uptime(self, ctxt, host): + topic = _compute_topic(self.topic, ctxt, host, None) + return self.call(ctxt, self.make_msg('get_host_uptime'), topic, + version='1.1') + +In this case, version '1.1' is the first version that supported the +get_host_uptime() method. + + +Example 2) Adding a new parameter. +---------------------------------- + +Adding a new parameter to an rpc method can be made backwards compatible. The +RPC_API_VERSION on the server side (nova/compute/manager.py) should be bumped. +The implementation of the method must not expect the parameter to be present.:: + + def some_remote_method(self, arg1, arg2, newarg=None): + # The code needs to deal with newarg=None for cases + # where an older client sends a message without it. + pass + +On the client side, the same changes should be made as in example 1. The +minimum version that supports the new parameter should be specified. +""" + +from staccato.openstack.common.rpc import common as rpc_common + + +class RpcDispatcher(object): + """Dispatch rpc messages according to the requested API version. + + This class can be used as the top level 'manager' for a service. It + contains a list of underlying managers that have an API_VERSION attribute. + """ + + def __init__(self, callbacks): + """Initialize the rpc dispatcher. + + :param callbacks: List of proxy objects that are an instance + of a class with rpc methods exposed. Each proxy + object should have an RPC_API_VERSION attribute. + """ + self.callbacks = callbacks + super(RpcDispatcher, self).__init__() + + def dispatch(self, ctxt, version, method, namespace, **kwargs): + """Dispatch a message based on a requested version. + + :param ctxt: The request context + :param version: The requested API version from the incoming message + :param method: The method requested to be called by the incoming + message. + :param namespace: The namespace for the requested method. If None, + the dispatcher will look for a method on a callback + object with no namespace set. + :param kwargs: A dict of keyword arguments to be passed to the method. + + :returns: Whatever is returned by the underlying method that gets + called. + """ + if not version: + version = '1.0' + + had_compatible = False + for proxyobj in self.callbacks: + # Check for namespace compatibility + try: + cb_namespace = proxyobj.RPC_API_NAMESPACE + except AttributeError: + cb_namespace = None + + if namespace != cb_namespace: + continue + + # Check for version compatibility + try: + rpc_api_version = proxyobj.RPC_API_VERSION + except AttributeError: + rpc_api_version = '1.0' + + is_compatible = rpc_common.version_is_compatible(rpc_api_version, + version) + had_compatible = had_compatible or is_compatible + + if not hasattr(proxyobj, method): + continue + if is_compatible: + return getattr(proxyobj, method)(ctxt, **kwargs) + + if had_compatible: + raise AttributeError("No such RPC function '%s'" % method) + else: + raise rpc_common.UnsupportedRpcVersion(version=version) diff --git a/staccato/openstack/common/rpc/impl_fake.py b/staccato/openstack/common/rpc/impl_fake.py new file mode 100644 index 0000000..019d046 --- /dev/null +++ b/staccato/openstack/common/rpc/impl_fake.py @@ -0,0 +1,195 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 OpenStack Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +"""Fake RPC implementation which calls proxy methods directly with no +queues. Casts will block, but this is very useful for tests. +""" + +import inspect +# NOTE(russellb): We specifically want to use json, not our own jsonutils. +# jsonutils has some extra logic to automatically convert objects to primitive +# types so that they can be serialized. We want to catch all cases where +# non-primitive types make it into this code and treat it as an error. +import json +import time + +import eventlet + +from staccato.openstack.common.rpc import common as rpc_common + +CONSUMERS = {} + + +class RpcContext(rpc_common.CommonRpcContext): + def __init__(self, **kwargs): + super(RpcContext, self).__init__(**kwargs) + self._response = [] + self._done = False + + def deepcopy(self): + values = self.to_dict() + new_inst = self.__class__(**values) + new_inst._response = self._response + new_inst._done = self._done + return new_inst + + def reply(self, reply=None, failure=None, ending=False): + if ending: + self._done = True + if not self._done: + self._response.append((reply, failure)) + + +class Consumer(object): + def __init__(self, topic, proxy): + self.topic = topic + self.proxy = proxy + + def call(self, context, version, method, namespace, args, timeout): + done = eventlet.event.Event() + + def _inner(): + ctxt = RpcContext.from_dict(context.to_dict()) + try: + rval = self.proxy.dispatch(context, version, method, + namespace, **args) + res = [] + # Caller might have called ctxt.reply() manually + for (reply, failure) in ctxt._response: + if failure: + raise failure[0], failure[1], failure[2] + res.append(reply) + # if ending not 'sent'...we might have more data to + # return from the function itself + if not ctxt._done: + if inspect.isgenerator(rval): + for val in rval: + res.append(val) + else: + res.append(rval) + done.send(res) + except rpc_common.ClientException as e: + done.send_exception(e._exc_info[1]) + except Exception as e: + done.send_exception(e) + + thread = eventlet.greenthread.spawn(_inner) + + if timeout: + start_time = time.time() + while not done.ready(): + eventlet.greenthread.sleep(1) + cur_time = time.time() + if (cur_time - start_time) > timeout: + thread.kill() + raise rpc_common.Timeout() + + return done.wait() + + +class Connection(object): + """Connection object.""" + + def __init__(self): + self.consumers = [] + + def create_consumer(self, topic, proxy, fanout=False): + consumer = Consumer(topic, proxy) + self.consumers.append(consumer) + if topic not in CONSUMERS: + CONSUMERS[topic] = [] + CONSUMERS[topic].append(consumer) + + def close(self): + for consumer in self.consumers: + CONSUMERS[consumer.topic].remove(consumer) + self.consumers = [] + + def consume_in_thread(self): + pass + + +def create_connection(conf, new=True): + """Create a connection""" + return Connection() + + +def check_serialize(msg): + """Make sure a message intended for rpc can be serialized.""" + json.dumps(msg) + + +def multicall(conf, context, topic, msg, timeout=None): + """Make a call that returns multiple times.""" + + check_serialize(msg) + + method = msg.get('method') + if not method: + return + args = msg.get('args', {}) + version = msg.get('version', None) + namespace = msg.get('namespace', None) + + try: + consumer = CONSUMERS[topic][0] + except (KeyError, IndexError): + return iter([None]) + else: + return consumer.call(context, version, method, namespace, args, + timeout) + + +def call(conf, context, topic, msg, timeout=None): + """Sends a message on a topic and wait for a response.""" + rv = multicall(conf, context, topic, msg, timeout) + # NOTE(vish): return the last result from the multicall + rv = list(rv) + if not rv: + return + return rv[-1] + + +def cast(conf, context, topic, msg): + check_serialize(msg) + try: + call(conf, context, topic, msg) + except Exception: + pass + + +def notify(conf, context, topic, msg, envelope): + check_serialize(msg) + + +def cleanup(): + pass + + +def fanout_cast(conf, context, topic, msg): + """Cast to all consumers of a topic""" + check_serialize(msg) + method = msg.get('method') + if not method: + return + args = msg.get('args', {}) + version = msg.get('version', None) + namespace = msg.get('namespace', None) + + for consumer in CONSUMERS.get(topic, []): + try: + consumer.call(context, version, method, namespace, args, None) + except Exception: + pass diff --git a/staccato/openstack/common/rpc/impl_kombu.py b/staccato/openstack/common/rpc/impl_kombu.py new file mode 100644 index 0000000..7901a0a --- /dev/null +++ b/staccato/openstack/common/rpc/impl_kombu.py @@ -0,0 +1,838 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 OpenStack Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import functools +import itertools +import socket +import ssl +import sys +import time +import uuid + +import eventlet +import greenlet +import kombu +import kombu.connection +import kombu.entity +import kombu.messaging +from oslo.config import cfg + +from staccato.openstack.common.gettextutils import _ +from staccato.openstack.common import network_utils +from staccato.openstack.common.rpc import amqp as rpc_amqp +from staccato.openstack.common.rpc import common as rpc_common + +kombu_opts = [ + cfg.StrOpt('kombu_ssl_version', + default='', + help='SSL version to use (valid only if SSL enabled)'), + cfg.StrOpt('kombu_ssl_keyfile', + default='', + help='SSL key file (valid only if SSL enabled)'), + cfg.StrOpt('kombu_ssl_certfile', + default='', + help='SSL cert file (valid only if SSL enabled)'), + cfg.StrOpt('kombu_ssl_ca_certs', + default='', + help=('SSL certification authority file ' + '(valid only if SSL enabled)')), + cfg.StrOpt('rabbit_host', + default='localhost', + help='The RabbitMQ broker address where a single node is used'), + cfg.IntOpt('rabbit_port', + default=5672, + help='The RabbitMQ broker port where a single node is used'), + cfg.ListOpt('rabbit_hosts', + default=['$rabbit_host:$rabbit_port'], + help='RabbitMQ HA cluster host:port pairs'), + cfg.BoolOpt('rabbit_use_ssl', + default=False, + help='connect over SSL for RabbitMQ'), + cfg.StrOpt('rabbit_userid', + default='guest', + help='the RabbitMQ userid'), + cfg.StrOpt('rabbit_password', + default='guest', + help='the RabbitMQ password', + secret=True), + cfg.StrOpt('rabbit_virtual_host', + default='/', + help='the RabbitMQ virtual host'), + cfg.IntOpt('rabbit_retry_interval', + default=1, + help='how frequently to retry connecting with RabbitMQ'), + cfg.IntOpt('rabbit_retry_backoff', + default=2, + help='how long to backoff for between retries when connecting ' + 'to RabbitMQ'), + cfg.IntOpt('rabbit_max_retries', + default=0, + help='maximum retries with trying to connect to RabbitMQ ' + '(the default of 0 implies an infinite retry count)'), + cfg.BoolOpt('rabbit_durable_queues', + default=False, + help='use durable queues in RabbitMQ'), + cfg.BoolOpt('rabbit_ha_queues', + default=False, + help='use H/A queues in RabbitMQ (x-ha-policy: all).' + 'You need to wipe RabbitMQ database when ' + 'changing this option.'), + +] + +cfg.CONF.register_opts(kombu_opts) + +LOG = rpc_common.LOG + + +def _get_queue_arguments(conf): + """Construct the arguments for declaring a queue. + + If the rabbit_ha_queues option is set, we declare a mirrored queue + as described here: + + http://www.rabbitmq.com/ha.html + + Setting x-ha-policy to all means that the queue will be mirrored + to all nodes in the cluster. + """ + return {'x-ha-policy': 'all'} if conf.rabbit_ha_queues else {} + + +class ConsumerBase(object): + """Consumer base class.""" + + def __init__(self, channel, callback, tag, **kwargs): + """Declare a queue on an amqp channel. + + 'channel' is the amqp channel to use + 'callback' is the callback to call when messages are received + 'tag' is a unique ID for the consumer on the channel + + queue name, exchange name, and other kombu options are + passed in here as a dictionary. + """ + self.callback = callback + self.tag = str(tag) + self.kwargs = kwargs + self.queue = None + self.reconnect(channel) + + def reconnect(self, channel): + """Re-declare the queue after a rabbit reconnect""" + self.channel = channel + self.kwargs['channel'] = channel + self.queue = kombu.entity.Queue(**self.kwargs) + self.queue.declare() + + def consume(self, *args, **kwargs): + """Actually declare the consumer on the amqp channel. This will + start the flow of messages from the queue. Using the + Connection.iterconsume() iterator will process the messages, + calling the appropriate callback. + + If a callback is specified in kwargs, use that. Otherwise, + use the callback passed during __init__() + + If kwargs['nowait'] is True, then this call will block until + a message is read. + + Messages will automatically be acked if the callback doesn't + raise an exception + """ + + options = {'consumer_tag': self.tag} + options['nowait'] = kwargs.get('nowait', False) + callback = kwargs.get('callback', self.callback) + if not callback: + raise ValueError("No callback defined") + + def _callback(raw_message): + message = self.channel.message_to_python(raw_message) + try: + msg = rpc_common.deserialize_msg(message.payload) + callback(msg) + except Exception: + LOG.exception(_("Failed to process message... skipping it.")) + finally: + message.ack() + + self.queue.consume(*args, callback=_callback, **options) + + def cancel(self): + """Cancel the consuming from the queue, if it has started""" + try: + self.queue.cancel(self.tag) + except KeyError as e: + # NOTE(comstud): Kludge to get around a amqplib bug + if str(e) != "u'%s'" % self.tag: + raise + self.queue = None + + +class DirectConsumer(ConsumerBase): + """Queue/consumer class for 'direct'""" + + def __init__(self, conf, channel, msg_id, callback, tag, **kwargs): + """Init a 'direct' queue. + + 'channel' is the amqp channel to use + 'msg_id' is the msg_id to listen on + 'callback' is the callback to call when messages are received + 'tag' is a unique ID for the consumer on the channel + + Other kombu options may be passed + """ + # Default options + options = {'durable': False, + 'queue_arguments': _get_queue_arguments(conf), + 'auto_delete': True, + 'exclusive': False} + options.update(kwargs) + exchange = kombu.entity.Exchange(name=msg_id, + type='direct', + durable=options['durable'], + auto_delete=options['auto_delete']) + super(DirectConsumer, self).__init__(channel, + callback, + tag, + name=msg_id, + exchange=exchange, + routing_key=msg_id, + **options) + + +class TopicConsumer(ConsumerBase): + """Consumer class for 'topic'""" + + def __init__(self, conf, channel, topic, callback, tag, name=None, + exchange_name=None, **kwargs): + """Init a 'topic' queue. + + :param channel: the amqp channel to use + :param topic: the topic to listen on + :paramtype topic: str + :param callback: the callback to call when messages are received + :param tag: a unique ID for the consumer on the channel + :param name: optional queue name, defaults to topic + :paramtype name: str + + Other kombu options may be passed as keyword arguments + """ + # Default options + options = {'durable': conf.rabbit_durable_queues, + 'queue_arguments': _get_queue_arguments(conf), + 'auto_delete': False, + 'exclusive': False} + options.update(kwargs) + exchange_name = exchange_name or rpc_amqp.get_control_exchange(conf) + exchange = kombu.entity.Exchange(name=exchange_name, + type='topic', + durable=options['durable'], + auto_delete=options['auto_delete']) + super(TopicConsumer, self).__init__(channel, + callback, + tag, + name=name or topic, + exchange=exchange, + routing_key=topic, + **options) + + +class FanoutConsumer(ConsumerBase): + """Consumer class for 'fanout'""" + + def __init__(self, conf, channel, topic, callback, tag, **kwargs): + """Init a 'fanout' queue. + + 'channel' is the amqp channel to use + 'topic' is the topic to listen on + 'callback' is the callback to call when messages are received + 'tag' is a unique ID for the consumer on the channel + + Other kombu options may be passed + """ + unique = uuid.uuid4().hex + exchange_name = '%s_fanout' % topic + queue_name = '%s_fanout_%s' % (topic, unique) + + # Default options + options = {'durable': False, + 'queue_arguments': _get_queue_arguments(conf), + 'auto_delete': True, + 'exclusive': False} + options.update(kwargs) + exchange = kombu.entity.Exchange(name=exchange_name, type='fanout', + durable=options['durable'], + auto_delete=options['auto_delete']) + super(FanoutConsumer, self).__init__(channel, callback, tag, + name=queue_name, + exchange=exchange, + routing_key=topic, + **options) + + +class Publisher(object): + """Base Publisher class""" + + def __init__(self, channel, exchange_name, routing_key, **kwargs): + """Init the Publisher class with the exchange_name, routing_key, + and other options + """ + self.exchange_name = exchange_name + self.routing_key = routing_key + self.kwargs = kwargs + self.reconnect(channel) + + def reconnect(self, channel): + """Re-establish the Producer after a rabbit reconnection""" + self.exchange = kombu.entity.Exchange(name=self.exchange_name, + **self.kwargs) + self.producer = kombu.messaging.Producer(exchange=self.exchange, + channel=channel, + routing_key=self.routing_key) + + def send(self, msg, timeout=None): + """Send a message""" + if timeout: + # + # AMQP TTL is in milliseconds when set in the header. + # + self.producer.publish(msg, headers={'ttl': (timeout * 1000)}) + else: + self.producer.publish(msg) + + +class DirectPublisher(Publisher): + """Publisher class for 'direct'""" + def __init__(self, conf, channel, msg_id, **kwargs): + """init a 'direct' publisher. + + Kombu options may be passed as keyword args to override defaults + """ + + options = {'durable': False, + 'auto_delete': True, + 'exclusive': False} + options.update(kwargs) + super(DirectPublisher, self).__init__(channel, msg_id, msg_id, + type='direct', **options) + + +class TopicPublisher(Publisher): + """Publisher class for 'topic'""" + def __init__(self, conf, channel, topic, **kwargs): + """init a 'topic' publisher. + + Kombu options may be passed as keyword args to override defaults + """ + options = {'durable': conf.rabbit_durable_queues, + 'auto_delete': False, + 'exclusive': False} + options.update(kwargs) + exchange_name = rpc_amqp.get_control_exchange(conf) + super(TopicPublisher, self).__init__(channel, + exchange_name, + topic, + type='topic', + **options) + + +class FanoutPublisher(Publisher): + """Publisher class for 'fanout'""" + def __init__(self, conf, channel, topic, **kwargs): + """init a 'fanout' publisher. + + Kombu options may be passed as keyword args to override defaults + """ + options = {'durable': False, + 'auto_delete': True, + 'exclusive': False} + options.update(kwargs) + super(FanoutPublisher, self).__init__(channel, '%s_fanout' % topic, + None, type='fanout', **options) + + +class NotifyPublisher(TopicPublisher): + """Publisher class for 'notify'""" + + def __init__(self, conf, channel, topic, **kwargs): + self.durable = kwargs.pop('durable', conf.rabbit_durable_queues) + self.queue_arguments = _get_queue_arguments(conf) + super(NotifyPublisher, self).__init__(conf, channel, topic, **kwargs) + + def reconnect(self, channel): + super(NotifyPublisher, self).reconnect(channel) + + # NOTE(jerdfelt): Normally the consumer would create the queue, but + # we do this to ensure that messages don't get dropped if the + # consumer is started after we do + queue = kombu.entity.Queue(channel=channel, + exchange=self.exchange, + durable=self.durable, + name=self.routing_key, + routing_key=self.routing_key, + queue_arguments=self.queue_arguments) + queue.declare() + + +class Connection(object): + """Connection object.""" + + pool = None + + def __init__(self, conf, server_params=None): + self.consumers = [] + self.consumer_thread = None + self.proxy_callbacks = [] + self.conf = conf + self.max_retries = self.conf.rabbit_max_retries + # Try forever? + if self.max_retries <= 0: + self.max_retries = None + self.interval_start = self.conf.rabbit_retry_interval + self.interval_stepping = self.conf.rabbit_retry_backoff + # max retry-interval = 30 seconds + self.interval_max = 30 + self.memory_transport = False + + if server_params is None: + server_params = {} + # Keys to translate from server_params to kombu params + server_params_to_kombu_params = {'username': 'userid'} + + ssl_params = self._fetch_ssl_params() + params_list = [] + for adr in self.conf.rabbit_hosts: + hostname, port = network_utils.parse_host_port( + adr, default_port=self.conf.rabbit_port) + + params = { + 'hostname': hostname, + 'port': port, + 'userid': self.conf.rabbit_userid, + 'password': self.conf.rabbit_password, + 'virtual_host': self.conf.rabbit_virtual_host, + } + + for sp_key, value in server_params.iteritems(): + p_key = server_params_to_kombu_params.get(sp_key, sp_key) + params[p_key] = value + + if self.conf.fake_rabbit: + params['transport'] = 'memory' + if self.conf.rabbit_use_ssl: + params['ssl'] = ssl_params + + params_list.append(params) + + self.params_list = params_list + + self.memory_transport = self.conf.fake_rabbit + + self.connection = None + self.reconnect() + + def _fetch_ssl_params(self): + """Handles fetching what ssl params + should be used for the connection (if any)""" + ssl_params = dict() + + # http://docs.python.org/library/ssl.html - ssl.wrap_socket + if self.conf.kombu_ssl_version: + ssl_params['ssl_version'] = self.conf.kombu_ssl_version + if self.conf.kombu_ssl_keyfile: + ssl_params['keyfile'] = self.conf.kombu_ssl_keyfile + if self.conf.kombu_ssl_certfile: + ssl_params['certfile'] = self.conf.kombu_ssl_certfile + if self.conf.kombu_ssl_ca_certs: + ssl_params['ca_certs'] = self.conf.kombu_ssl_ca_certs + # We might want to allow variations in the + # future with this? + ssl_params['cert_reqs'] = ssl.CERT_REQUIRED + + if not ssl_params: + # Just have the default behavior + return True + else: + # Return the extended behavior + return ssl_params + + def _connect(self, params): + """Connect to rabbit. Re-establish any queues that may have + been declared before if we are reconnecting. Exceptions should + be handled by the caller. + """ + if self.connection: + LOG.info(_("Reconnecting to AMQP server on " + "%(hostname)s:%(port)d") % params) + try: + self.connection.release() + except self.connection_errors: + pass + # Setting this in case the next statement fails, though + # it shouldn't be doing any network operations, yet. + self.connection = None + self.connection = kombu.connection.BrokerConnection(**params) + self.connection_errors = self.connection.connection_errors + if self.memory_transport: + # Kludge to speed up tests. + self.connection.transport.polling_interval = 0.0 + self.consumer_num = itertools.count(1) + self.connection.connect() + self.channel = self.connection.channel() + # work around 'memory' transport bug in 1.1.3 + if self.memory_transport: + self.channel._new_queue('ae.undeliver') + for consumer in self.consumers: + consumer.reconnect(self.channel) + LOG.info(_('Connected to AMQP server on %(hostname)s:%(port)d') % + params) + + def reconnect(self): + """Handles reconnecting and re-establishing queues. + Will retry up to self.max_retries number of times. + self.max_retries = 0 means to retry forever. + Sleep between tries, starting at self.interval_start + seconds, backing off self.interval_stepping number of seconds + each attempt. + """ + + attempt = 0 + while True: + params = self.params_list[attempt % len(self.params_list)] + attempt += 1 + try: + self._connect(params) + return + except (IOError, self.connection_errors) as e: + pass + except Exception as e: + # NOTE(comstud): Unfortunately it's possible for amqplib + # to return an error not covered by its transport + # connection_errors in the case of a timeout waiting for + # a protocol response. (See paste link in LP888621) + # So, we check all exceptions for 'timeout' in them + # and try to reconnect in this case. + if 'timeout' not in str(e): + raise + + log_info = {} + log_info['err_str'] = str(e) + log_info['max_retries'] = self.max_retries + log_info.update(params) + + if self.max_retries and attempt == self.max_retries: + LOG.error(_('Unable to connect to AMQP server on ' + '%(hostname)s:%(port)d after %(max_retries)d ' + 'tries: %(err_str)s') % log_info) + # NOTE(comstud): Copied from original code. There's + # really no better recourse because if this was a queue we + # need to consume on, we have no way to consume anymore. + sys.exit(1) + + if attempt == 1: + sleep_time = self.interval_start or 1 + elif attempt > 1: + sleep_time += self.interval_stepping + if self.interval_max: + sleep_time = min(sleep_time, self.interval_max) + + log_info['sleep_time'] = sleep_time + LOG.error(_('AMQP server on %(hostname)s:%(port)d is ' + 'unreachable: %(err_str)s. Trying again in ' + '%(sleep_time)d seconds.') % log_info) + time.sleep(sleep_time) + + def ensure(self, error_callback, method, *args, **kwargs): + while True: + try: + return method(*args, **kwargs) + except (self.connection_errors, socket.timeout, IOError) as e: + if error_callback: + error_callback(e) + except Exception as e: + # NOTE(comstud): Unfortunately it's possible for amqplib + # to return an error not covered by its transport + # connection_errors in the case of a timeout waiting for + # a protocol response. (See paste link in LP888621) + # So, we check all exceptions for 'timeout' in them + # and try to reconnect in this case. + if 'timeout' not in str(e): + raise + if error_callback: + error_callback(e) + self.reconnect() + + def get_channel(self): + """Convenience call for bin/clear_rabbit_queues""" + return self.channel + + def close(self): + """Close/release this connection""" + self.cancel_consumer_thread() + self.wait_on_proxy_callbacks() + self.connection.release() + self.connection = None + + def reset(self): + """Reset a connection so it can be used again""" + self.cancel_consumer_thread() + self.wait_on_proxy_callbacks() + self.channel.close() + self.channel = self.connection.channel() + # work around 'memory' transport bug in 1.1.3 + if self.memory_transport: + self.channel._new_queue('ae.undeliver') + self.consumers = [] + + def declare_consumer(self, consumer_cls, topic, callback): + """Create a Consumer using the class that was passed in and + add it to our list of consumers + """ + + def _connect_error(exc): + log_info = {'topic': topic, 'err_str': str(exc)} + LOG.error(_("Failed to declare consumer for topic '%(topic)s': " + "%(err_str)s") % log_info) + + def _declare_consumer(): + consumer = consumer_cls(self.conf, self.channel, topic, callback, + self.consumer_num.next()) + self.consumers.append(consumer) + return consumer + + return self.ensure(_connect_error, _declare_consumer) + + def iterconsume(self, limit=None, timeout=None): + """Return an iterator that will consume from all queues/consumers""" + + info = {'do_consume': True} + + def _error_callback(exc): + if isinstance(exc, socket.timeout): + LOG.debug(_('Timed out waiting for RPC response: %s') % + str(exc)) + raise rpc_common.Timeout() + else: + LOG.exception(_('Failed to consume message from queue: %s') % + str(exc)) + info['do_consume'] = True + + def _consume(): + if info['do_consume']: + queues_head = self.consumers[:-1] + queues_tail = self.consumers[-1] + for queue in queues_head: + queue.consume(nowait=True) + queues_tail.consume(nowait=False) + info['do_consume'] = False + return self.connection.drain_events(timeout=timeout) + + for iteration in itertools.count(0): + if limit and iteration >= limit: + raise StopIteration + yield self.ensure(_error_callback, _consume) + + def cancel_consumer_thread(self): + """Cancel a consumer thread""" + if self.consumer_thread is not None: + self.consumer_thread.kill() + try: + self.consumer_thread.wait() + except greenlet.GreenletExit: + pass + self.consumer_thread = None + + def wait_on_proxy_callbacks(self): + """Wait for all proxy callback threads to exit.""" + for proxy_cb in self.proxy_callbacks: + proxy_cb.wait() + + def publisher_send(self, cls, topic, msg, timeout=None, **kwargs): + """Send to a publisher based on the publisher class""" + + def _error_callback(exc): + log_info = {'topic': topic, 'err_str': str(exc)} + LOG.exception(_("Failed to publish message to topic " + "'%(topic)s': %(err_str)s") % log_info) + + def _publish(): + publisher = cls(self.conf, self.channel, topic, **kwargs) + publisher.send(msg, timeout) + + self.ensure(_error_callback, _publish) + + def declare_direct_consumer(self, topic, callback): + """Create a 'direct' queue. + In nova's use, this is generally a msg_id queue used for + responses for call/multicall + """ + self.declare_consumer(DirectConsumer, topic, callback) + + def declare_topic_consumer(self, topic, callback=None, queue_name=None, + exchange_name=None): + """Create a 'topic' consumer.""" + self.declare_consumer(functools.partial(TopicConsumer, + name=queue_name, + exchange_name=exchange_name, + ), + topic, callback) + + def declare_fanout_consumer(self, topic, callback): + """Create a 'fanout' consumer""" + self.declare_consumer(FanoutConsumer, topic, callback) + + def direct_send(self, msg_id, msg): + """Send a 'direct' message""" + self.publisher_send(DirectPublisher, msg_id, msg) + + def topic_send(self, topic, msg, timeout=None): + """Send a 'topic' message""" + self.publisher_send(TopicPublisher, topic, msg, timeout) + + def fanout_send(self, topic, msg): + """Send a 'fanout' message""" + self.publisher_send(FanoutPublisher, topic, msg) + + def notify_send(self, topic, msg, **kwargs): + """Send a notify message on a topic""" + self.publisher_send(NotifyPublisher, topic, msg, None, **kwargs) + + def consume(self, limit=None): + """Consume from all queues/consumers""" + it = self.iterconsume(limit=limit) + while True: + try: + it.next() + except StopIteration: + return + + def consume_in_thread(self): + """Consumer from all queues/consumers in a greenthread""" + def _consumer_thread(): + try: + self.consume() + except greenlet.GreenletExit: + return + if self.consumer_thread is None: + self.consumer_thread = eventlet.spawn(_consumer_thread) + return self.consumer_thread + + def create_consumer(self, topic, proxy, fanout=False): + """Create a consumer that calls a method in a proxy object""" + proxy_cb = rpc_amqp.ProxyCallback( + self.conf, proxy, + rpc_amqp.get_connection_pool(self.conf, Connection)) + self.proxy_callbacks.append(proxy_cb) + + if fanout: + self.declare_fanout_consumer(topic, proxy_cb) + else: + self.declare_topic_consumer(topic, proxy_cb) + + def create_worker(self, topic, proxy, pool_name): + """Create a worker that calls a method in a proxy object""" + proxy_cb = rpc_amqp.ProxyCallback( + self.conf, proxy, + rpc_amqp.get_connection_pool(self.conf, Connection)) + self.proxy_callbacks.append(proxy_cb) + self.declare_topic_consumer(topic, proxy_cb, pool_name) + + def join_consumer_pool(self, callback, pool_name, topic, + exchange_name=None): + """Register as a member of a group of consumers for a given topic from + the specified exchange. + + Exactly one member of a given pool will receive each message. + + A message will be delivered to multiple pools, if more than + one is created. + """ + callback_wrapper = rpc_amqp.CallbackWrapper( + conf=self.conf, + callback=callback, + connection_pool=rpc_amqp.get_connection_pool(self.conf, + Connection), + ) + self.proxy_callbacks.append(callback_wrapper) + self.declare_topic_consumer( + queue_name=pool_name, + topic=topic, + exchange_name=exchange_name, + callback=callback_wrapper, + ) + + +def create_connection(conf, new=True): + """Create a connection""" + return rpc_amqp.create_connection( + conf, new, + rpc_amqp.get_connection_pool(conf, Connection)) + + +def multicall(conf, context, topic, msg, timeout=None): + """Make a call that returns multiple times.""" + return rpc_amqp.multicall( + conf, context, topic, msg, timeout, + rpc_amqp.get_connection_pool(conf, Connection)) + + +def call(conf, context, topic, msg, timeout=None): + """Sends a message on a topic and wait for a response.""" + return rpc_amqp.call( + conf, context, topic, msg, timeout, + rpc_amqp.get_connection_pool(conf, Connection)) + + +def cast(conf, context, topic, msg): + """Sends a message on a topic without waiting for a response.""" + return rpc_amqp.cast( + conf, context, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection)) + + +def fanout_cast(conf, context, topic, msg): + """Sends a message on a fanout exchange without waiting for a response.""" + return rpc_amqp.fanout_cast( + conf, context, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection)) + + +def cast_to_server(conf, context, server_params, topic, msg): + """Sends a message on a topic to a specific server.""" + return rpc_amqp.cast_to_server( + conf, context, server_params, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection)) + + +def fanout_cast_to_server(conf, context, server_params, topic, msg): + """Sends a message on a fanout exchange to a specific server.""" + return rpc_amqp.fanout_cast_to_server( + conf, context, server_params, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection)) + + +def notify(conf, context, topic, msg, envelope): + """Sends a notification event on a topic.""" + return rpc_amqp.notify( + conf, context, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection), + envelope) + + +def cleanup(): + return rpc_amqp.cleanup(Connection.pool) diff --git a/staccato/openstack/common/rpc/impl_qpid.py b/staccato/openstack/common/rpc/impl_qpid.py new file mode 100644 index 0000000..8a9a1be --- /dev/null +++ b/staccato/openstack/common/rpc/impl_qpid.py @@ -0,0 +1,650 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 OpenStack Foundation +# Copyright 2011 - 2012, Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import functools +import itertools +import time +import uuid + +import eventlet +import greenlet +from oslo.config import cfg + +from staccato.openstack.common.gettextutils import _ +from staccato.openstack.common import importutils +from staccato.openstack.common import jsonutils +from staccato.openstack.common import log as logging +from staccato.openstack.common.rpc import amqp as rpc_amqp +from staccato.openstack.common.rpc import common as rpc_common + +qpid_messaging = importutils.try_import("qpid.messaging") +qpid_exceptions = importutils.try_import("qpid.messaging.exceptions") + +LOG = logging.getLogger(__name__) + +qpid_opts = [ + cfg.StrOpt('qpid_hostname', + default='localhost', + help='Qpid broker hostname'), + cfg.IntOpt('qpid_port', + default=5672, + help='Qpid broker port'), + cfg.ListOpt('qpid_hosts', + default=['$qpid_hostname:$qpid_port'], + help='Qpid HA cluster host:port pairs'), + cfg.StrOpt('qpid_username', + default='', + help='Username for qpid connection'), + cfg.StrOpt('qpid_password', + default='', + help='Password for qpid connection', + secret=True), + cfg.StrOpt('qpid_sasl_mechanisms', + default='', + help='Space separated list of SASL mechanisms to use for auth'), + cfg.IntOpt('qpid_heartbeat', + default=60, + help='Seconds between connection keepalive heartbeats'), + cfg.StrOpt('qpid_protocol', + default='tcp', + help="Transport to use, either 'tcp' or 'ssl'"), + cfg.BoolOpt('qpid_tcp_nodelay', + default=True, + help='Disable Nagle algorithm'), +] + +cfg.CONF.register_opts(qpid_opts) + + +class ConsumerBase(object): + """Consumer base class.""" + + def __init__(self, session, callback, node_name, node_opts, + link_name, link_opts): + """Declare a queue on an amqp session. + + 'session' is the amqp session to use + 'callback' is the callback to call when messages are received + 'node_name' is the first part of the Qpid address string, before ';' + 'node_opts' will be applied to the "x-declare" section of "node" + in the address string. + 'link_name' goes into the "name" field of the "link" in the address + string + 'link_opts' will be applied to the "x-declare" section of "link" + in the address string. + """ + self.callback = callback + self.receiver = None + self.session = None + + addr_opts = { + "create": "always", + "node": { + "type": "topic", + "x-declare": { + "durable": True, + "auto-delete": True, + }, + }, + "link": { + "name": link_name, + "durable": True, + "x-declare": { + "durable": False, + "auto-delete": True, + "exclusive": False, + }, + }, + } + addr_opts["node"]["x-declare"].update(node_opts) + addr_opts["link"]["x-declare"].update(link_opts) + + self.address = "%s ; %s" % (node_name, jsonutils.dumps(addr_opts)) + + self.reconnect(session) + + def reconnect(self, session): + """Re-declare the receiver after a qpid reconnect""" + self.session = session + self.receiver = session.receiver(self.address) + self.receiver.capacity = 1 + + def consume(self): + """Fetch the message and pass it to the callback object""" + message = self.receiver.fetch() + try: + msg = rpc_common.deserialize_msg(message.content) + self.callback(msg) + except Exception: + LOG.exception(_("Failed to process message... skipping it.")) + finally: + self.session.acknowledge(message) + + def get_receiver(self): + return self.receiver + + +class DirectConsumer(ConsumerBase): + """Queue/consumer class for 'direct'""" + + def __init__(self, conf, session, msg_id, callback): + """Init a 'direct' queue. + + 'session' is the amqp session to use + 'msg_id' is the msg_id to listen on + 'callback' is the callback to call when messages are received + """ + + super(DirectConsumer, self).__init__(session, callback, + "%s/%s" % (msg_id, msg_id), + {"type": "direct"}, + msg_id, + {"exclusive": True}) + + +class TopicConsumer(ConsumerBase): + """Consumer class for 'topic'""" + + def __init__(self, conf, session, topic, callback, name=None, + exchange_name=None): + """Init a 'topic' queue. + + :param session: the amqp session to use + :param topic: is the topic to listen on + :paramtype topic: str + :param callback: the callback to call when messages are received + :param name: optional queue name, defaults to topic + """ + + exchange_name = exchange_name or rpc_amqp.get_control_exchange(conf) + super(TopicConsumer, self).__init__(session, callback, + "%s/%s" % (exchange_name, topic), + {}, name or topic, {}) + + +class FanoutConsumer(ConsumerBase): + """Consumer class for 'fanout'""" + + def __init__(self, conf, session, topic, callback): + """Init a 'fanout' queue. + + 'session' is the amqp session to use + 'topic' is the topic to listen on + 'callback' is the callback to call when messages are received + """ + + super(FanoutConsumer, self).__init__( + session, callback, + "%s_fanout" % topic, + {"durable": False, "type": "fanout"}, + "%s_fanout_%s" % (topic, uuid.uuid4().hex), + {"exclusive": True}) + + +class Publisher(object): + """Base Publisher class""" + + def __init__(self, session, node_name, node_opts=None): + """Init the Publisher class with the exchange_name, routing_key, + and other options + """ + self.sender = None + self.session = session + + addr_opts = { + "create": "always", + "node": { + "type": "topic", + "x-declare": { + "durable": False, + # auto-delete isn't implemented for exchanges in qpid, + # but put in here anyway + "auto-delete": True, + }, + }, + } + if node_opts: + addr_opts["node"]["x-declare"].update(node_opts) + + self.address = "%s ; %s" % (node_name, jsonutils.dumps(addr_opts)) + + self.reconnect(session) + + def reconnect(self, session): + """Re-establish the Sender after a reconnection""" + self.sender = session.sender(self.address) + + def send(self, msg): + """Send a message""" + self.sender.send(msg) + + +class DirectPublisher(Publisher): + """Publisher class for 'direct'""" + def __init__(self, conf, session, msg_id): + """Init a 'direct' publisher.""" + super(DirectPublisher, self).__init__(session, msg_id, + {"type": "Direct"}) + + +class TopicPublisher(Publisher): + """Publisher class for 'topic'""" + def __init__(self, conf, session, topic): + """init a 'topic' publisher. + """ + exchange_name = rpc_amqp.get_control_exchange(conf) + super(TopicPublisher, self).__init__(session, + "%s/%s" % (exchange_name, topic)) + + +class FanoutPublisher(Publisher): + """Publisher class for 'fanout'""" + def __init__(self, conf, session, topic): + """init a 'fanout' publisher. + """ + super(FanoutPublisher, self).__init__( + session, + "%s_fanout" % topic, {"type": "fanout"}) + + +class NotifyPublisher(Publisher): + """Publisher class for notifications""" + def __init__(self, conf, session, topic): + """init a 'topic' publisher. + """ + exchange_name = rpc_amqp.get_control_exchange(conf) + super(NotifyPublisher, self).__init__(session, + "%s/%s" % (exchange_name, topic), + {"durable": True}) + + +class Connection(object): + """Connection object.""" + + pool = None + + def __init__(self, conf, server_params=None): + if not qpid_messaging: + raise ImportError("Failed to import qpid.messaging") + + self.session = None + self.consumers = {} + self.consumer_thread = None + self.proxy_callbacks = [] + self.conf = conf + + if server_params and 'hostname' in server_params: + # NOTE(russellb) This enables support for cast_to_server. + server_params['qpid_hosts'] = [ + '%s:%d' % (server_params['hostname'], + server_params.get('port', 5672)) + ] + + params = { + 'qpid_hosts': self.conf.qpid_hosts, + 'username': self.conf.qpid_username, + 'password': self.conf.qpid_password, + } + params.update(server_params or {}) + + self.brokers = params['qpid_hosts'] + self.username = params['username'] + self.password = params['password'] + self.connection_create(self.brokers[0]) + self.reconnect() + + def connection_create(self, broker): + # Create the connection - this does not open the connection + self.connection = qpid_messaging.Connection(broker) + + # Check if flags are set and if so set them for the connection + # before we call open + self.connection.username = self.username + self.connection.password = self.password + + self.connection.sasl_mechanisms = self.conf.qpid_sasl_mechanisms + # Reconnection is done by self.reconnect() + self.connection.reconnect = False + self.connection.heartbeat = self.conf.qpid_heartbeat + self.connection.transport = self.conf.qpid_protocol + self.connection.tcp_nodelay = self.conf.qpid_tcp_nodelay + + def _register_consumer(self, consumer): + self.consumers[str(consumer.get_receiver())] = consumer + + def _lookup_consumer(self, receiver): + return self.consumers[str(receiver)] + + def reconnect(self): + """Handles reconnecting and re-establishing sessions and queues""" + attempt = 0 + delay = 1 + while True: + # Close the session if necessary + if self.connection.opened(): + try: + self.connection.close() + except qpid_exceptions.ConnectionError: + pass + + broker = self.brokers[attempt % len(self.brokers)] + attempt += 1 + + try: + self.connection_create(broker) + self.connection.open() + except qpid_exceptions.ConnectionError as e: + msg_dict = dict(e=e, delay=delay) + msg = _("Unable to connect to AMQP server: %(e)s. " + "Sleeping %(delay)s seconds") % msg_dict + LOG.error(msg) + time.sleep(delay) + delay = min(2 * delay, 60) + else: + LOG.info(_('Connected to AMQP server on %s'), broker) + break + + self.session = self.connection.session() + + if self.consumers: + consumers = self.consumers + self.consumers = {} + + for consumer in consumers.itervalues(): + consumer.reconnect(self.session) + self._register_consumer(consumer) + + LOG.debug(_("Re-established AMQP queues")) + + def ensure(self, error_callback, method, *args, **kwargs): + while True: + try: + return method(*args, **kwargs) + except (qpid_exceptions.Empty, + qpid_exceptions.ConnectionError), e: + if error_callback: + error_callback(e) + self.reconnect() + + def close(self): + """Close/release this connection""" + self.cancel_consumer_thread() + self.wait_on_proxy_callbacks() + self.connection.close() + self.connection = None + + def reset(self): + """Reset a connection so it can be used again""" + self.cancel_consumer_thread() + self.wait_on_proxy_callbacks() + self.session.close() + self.session = self.connection.session() + self.consumers = {} + + def declare_consumer(self, consumer_cls, topic, callback): + """Create a Consumer using the class that was passed in and + add it to our list of consumers + """ + def _connect_error(exc): + log_info = {'topic': topic, 'err_str': str(exc)} + LOG.error(_("Failed to declare consumer for topic '%(topic)s': " + "%(err_str)s") % log_info) + + def _declare_consumer(): + consumer = consumer_cls(self.conf, self.session, topic, callback) + self._register_consumer(consumer) + return consumer + + return self.ensure(_connect_error, _declare_consumer) + + def iterconsume(self, limit=None, timeout=None): + """Return an iterator that will consume from all queues/consumers""" + + def _error_callback(exc): + if isinstance(exc, qpid_exceptions.Empty): + LOG.debug(_('Timed out waiting for RPC response: %s') % + str(exc)) + raise rpc_common.Timeout() + else: + LOG.exception(_('Failed to consume message from queue: %s') % + str(exc)) + + def _consume(): + nxt_receiver = self.session.next_receiver(timeout=timeout) + try: + self._lookup_consumer(nxt_receiver).consume() + except Exception: + LOG.exception(_("Error processing message. Skipping it.")) + + for iteration in itertools.count(0): + if limit and iteration >= limit: + raise StopIteration + yield self.ensure(_error_callback, _consume) + + def cancel_consumer_thread(self): + """Cancel a consumer thread""" + if self.consumer_thread is not None: + self.consumer_thread.kill() + try: + self.consumer_thread.wait() + except greenlet.GreenletExit: + pass + self.consumer_thread = None + + def wait_on_proxy_callbacks(self): + """Wait for all proxy callback threads to exit.""" + for proxy_cb in self.proxy_callbacks: + proxy_cb.wait() + + def publisher_send(self, cls, topic, msg): + """Send to a publisher based on the publisher class""" + + def _connect_error(exc): + log_info = {'topic': topic, 'err_str': str(exc)} + LOG.exception(_("Failed to publish message to topic " + "'%(topic)s': %(err_str)s") % log_info) + + def _publisher_send(): + publisher = cls(self.conf, self.session, topic) + publisher.send(msg) + + return self.ensure(_connect_error, _publisher_send) + + def declare_direct_consumer(self, topic, callback): + """Create a 'direct' queue. + In nova's use, this is generally a msg_id queue used for + responses for call/multicall + """ + self.declare_consumer(DirectConsumer, topic, callback) + + def declare_topic_consumer(self, topic, callback=None, queue_name=None, + exchange_name=None): + """Create a 'topic' consumer.""" + self.declare_consumer(functools.partial(TopicConsumer, + name=queue_name, + exchange_name=exchange_name, + ), + topic, callback) + + def declare_fanout_consumer(self, topic, callback): + """Create a 'fanout' consumer""" + self.declare_consumer(FanoutConsumer, topic, callback) + + def direct_send(self, msg_id, msg): + """Send a 'direct' message""" + self.publisher_send(DirectPublisher, msg_id, msg) + + def topic_send(self, topic, msg, timeout=None): + """Send a 'topic' message""" + # + # We want to create a message with attributes, e.g. a TTL. We + # don't really need to keep 'msg' in its JSON format any longer + # so let's create an actual qpid message here and get some + # value-add on the go. + # + # WARNING: Request timeout happens to be in the same units as + # qpid's TTL (seconds). If this changes in the future, then this + # will need to be altered accordingly. + # + qpid_message = qpid_messaging.Message(content=msg, ttl=timeout) + self.publisher_send(TopicPublisher, topic, qpid_message) + + def fanout_send(self, topic, msg): + """Send a 'fanout' message""" + self.publisher_send(FanoutPublisher, topic, msg) + + def notify_send(self, topic, msg, **kwargs): + """Send a notify message on a topic""" + self.publisher_send(NotifyPublisher, topic, msg) + + def consume(self, limit=None): + """Consume from all queues/consumers""" + it = self.iterconsume(limit=limit) + while True: + try: + it.next() + except StopIteration: + return + + def consume_in_thread(self): + """Consumer from all queues/consumers in a greenthread""" + def _consumer_thread(): + try: + self.consume() + except greenlet.GreenletExit: + return + if self.consumer_thread is None: + self.consumer_thread = eventlet.spawn(_consumer_thread) + return self.consumer_thread + + def create_consumer(self, topic, proxy, fanout=False): + """Create a consumer that calls a method in a proxy object""" + proxy_cb = rpc_amqp.ProxyCallback( + self.conf, proxy, + rpc_amqp.get_connection_pool(self.conf, Connection)) + self.proxy_callbacks.append(proxy_cb) + + if fanout: + consumer = FanoutConsumer(self.conf, self.session, topic, proxy_cb) + else: + consumer = TopicConsumer(self.conf, self.session, topic, proxy_cb) + + self._register_consumer(consumer) + + return consumer + + def create_worker(self, topic, proxy, pool_name): + """Create a worker that calls a method in a proxy object""" + proxy_cb = rpc_amqp.ProxyCallback( + self.conf, proxy, + rpc_amqp.get_connection_pool(self.conf, Connection)) + self.proxy_callbacks.append(proxy_cb) + + consumer = TopicConsumer(self.conf, self.session, topic, proxy_cb, + name=pool_name) + + self._register_consumer(consumer) + + return consumer + + def join_consumer_pool(self, callback, pool_name, topic, + exchange_name=None): + """Register as a member of a group of consumers for a given topic from + the specified exchange. + + Exactly one member of a given pool will receive each message. + + A message will be delivered to multiple pools, if more than + one is created. + """ + callback_wrapper = rpc_amqp.CallbackWrapper( + conf=self.conf, + callback=callback, + connection_pool=rpc_amqp.get_connection_pool(self.conf, + Connection), + ) + self.proxy_callbacks.append(callback_wrapper) + + consumer = TopicConsumer(conf=self.conf, + session=self.session, + topic=topic, + callback=callback_wrapper, + name=pool_name, + exchange_name=exchange_name) + + self._register_consumer(consumer) + return consumer + + +def create_connection(conf, new=True): + """Create a connection""" + return rpc_amqp.create_connection( + conf, new, + rpc_amqp.get_connection_pool(conf, Connection)) + + +def multicall(conf, context, topic, msg, timeout=None): + """Make a call that returns multiple times.""" + return rpc_amqp.multicall( + conf, context, topic, msg, timeout, + rpc_amqp.get_connection_pool(conf, Connection)) + + +def call(conf, context, topic, msg, timeout=None): + """Sends a message on a topic and wait for a response.""" + return rpc_amqp.call( + conf, context, topic, msg, timeout, + rpc_amqp.get_connection_pool(conf, Connection)) + + +def cast(conf, context, topic, msg): + """Sends a message on a topic without waiting for a response.""" + return rpc_amqp.cast( + conf, context, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection)) + + +def fanout_cast(conf, context, topic, msg): + """Sends a message on a fanout exchange without waiting for a response.""" + return rpc_amqp.fanout_cast( + conf, context, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection)) + + +def cast_to_server(conf, context, server_params, topic, msg): + """Sends a message on a topic to a specific server.""" + return rpc_amqp.cast_to_server( + conf, context, server_params, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection)) + + +def fanout_cast_to_server(conf, context, server_params, topic, msg): + """Sends a message on a fanout exchange to a specific server.""" + return rpc_amqp.fanout_cast_to_server( + conf, context, server_params, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection)) + + +def notify(conf, context, topic, msg, envelope): + """Sends a notification event on a topic.""" + return rpc_amqp.notify(conf, context, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection), + envelope) + + +def cleanup(): + return rpc_amqp.cleanup(Connection.pool) diff --git a/staccato/openstack/common/rpc/impl_zmq.py b/staccato/openstack/common/rpc/impl_zmq.py new file mode 100644 index 0000000..a21b3df --- /dev/null +++ b/staccato/openstack/common/rpc/impl_zmq.py @@ -0,0 +1,851 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 Cloudscaling Group, Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import os +import pprint +import re +import socket +import sys +import types +import uuid + +import eventlet +import greenlet +from oslo.config import cfg + +from staccato.openstack.common import excutils +from staccato.openstack.common.gettextutils import _ +from staccato.openstack.common import importutils +from staccato.openstack.common import jsonutils +from staccato.openstack.common import processutils as utils +from staccato.openstack.common.rpc import common as rpc_common + +zmq = importutils.try_import('eventlet.green.zmq') + +# for convenience, are not modified. +pformat = pprint.pformat +Timeout = eventlet.timeout.Timeout +LOG = rpc_common.LOG +RemoteError = rpc_common.RemoteError +RPCException = rpc_common.RPCException + +zmq_opts = [ + cfg.StrOpt('rpc_zmq_bind_address', default='*', + help='ZeroMQ bind address. Should be a wildcard (*), ' + 'an ethernet interface, or IP. ' + 'The "host" option should point or resolve to this ' + 'address.'), + + # The module.Class to use for matchmaking. + cfg.StrOpt( + 'rpc_zmq_matchmaker', + default=('staccato.openstack.common.rpc.' + 'matchmaker.MatchMakerLocalhost'), + help='MatchMaker driver', + ), + + # The following port is unassigned by IANA as of 2012-05-21 + cfg.IntOpt('rpc_zmq_port', default=9501, + help='ZeroMQ receiver listening port'), + + cfg.IntOpt('rpc_zmq_contexts', default=1, + help='Number of ZeroMQ contexts, defaults to 1'), + + cfg.IntOpt('rpc_zmq_topic_backlog', default=None, + help='Maximum number of ingress messages to locally buffer ' + 'per topic. Default is unlimited.'), + + cfg.StrOpt('rpc_zmq_ipc_dir', default='/var/run/openstack', + help='Directory for holding IPC sockets'), + + cfg.StrOpt('rpc_zmq_host', default=socket.gethostname(), + help='Name of this node. Must be a valid hostname, FQDN, or ' + 'IP address. Must match "host" option, if running Nova.') +] + + +CONF = cfg.CONF +CONF.register_opts(zmq_opts) + +ZMQ_CTX = None # ZeroMQ Context, must be global. +matchmaker = None # memoized matchmaker object + + +def _serialize(data): + """ + Serialization wrapper + We prefer using JSON, but it cannot encode all types. + Error if a developer passes us bad data. + """ + try: + return jsonutils.dumps(data, ensure_ascii=True) + except TypeError: + with excutils.save_and_reraise_exception(): + LOG.error(_("JSON serialization failed.")) + + +def _deserialize(data): + """ + Deserialization wrapper + """ + LOG.debug(_("Deserializing: %s"), data) + return jsonutils.loads(data) + + +class ZmqSocket(object): + """ + A tiny wrapper around ZeroMQ to simplify the send/recv protocol + and connection management. + + Can be used as a Context (supports the 'with' statement). + """ + + def __init__(self, addr, zmq_type, bind=True, subscribe=None): + self.sock = _get_ctxt().socket(zmq_type) + self.addr = addr + self.type = zmq_type + self.subscriptions = [] + + # Support failures on sending/receiving on wrong socket type. + self.can_recv = zmq_type in (zmq.PULL, zmq.SUB) + self.can_send = zmq_type in (zmq.PUSH, zmq.PUB) + self.can_sub = zmq_type in (zmq.SUB, ) + + # Support list, str, & None for subscribe arg (cast to list) + do_sub = { + list: subscribe, + str: [subscribe], + type(None): [] + }[type(subscribe)] + + for f in do_sub: + self.subscribe(f) + + str_data = {'addr': addr, 'type': self.socket_s(), + 'subscribe': subscribe, 'bind': bind} + + LOG.debug(_("Connecting to %(addr)s with %(type)s"), str_data) + LOG.debug(_("-> Subscribed to %(subscribe)s"), str_data) + LOG.debug(_("-> bind: %(bind)s"), str_data) + + try: + if bind: + self.sock.bind(addr) + else: + self.sock.connect(addr) + except Exception: + raise RPCException(_("Could not open socket.")) + + def socket_s(self): + """Get socket type as string.""" + t_enum = ('PUSH', 'PULL', 'PUB', 'SUB', 'REP', 'REQ', 'ROUTER', + 'DEALER') + return dict(map(lambda t: (getattr(zmq, t), t), t_enum))[self.type] + + def subscribe(self, msg_filter): + """Subscribe.""" + if not self.can_sub: + raise RPCException("Cannot subscribe on this socket.") + LOG.debug(_("Subscribing to %s"), msg_filter) + + try: + self.sock.setsockopt(zmq.SUBSCRIBE, msg_filter) + except Exception: + return + + self.subscriptions.append(msg_filter) + + def unsubscribe(self, msg_filter): + """Unsubscribe.""" + if msg_filter not in self.subscriptions: + return + self.sock.setsockopt(zmq.UNSUBSCRIBE, msg_filter) + self.subscriptions.remove(msg_filter) + + def close(self): + if self.sock is None or self.sock.closed: + return + + # We must unsubscribe, or we'll leak descriptors. + if len(self.subscriptions) > 0: + for f in self.subscriptions: + try: + self.sock.setsockopt(zmq.UNSUBSCRIBE, f) + except Exception: + pass + self.subscriptions = [] + + try: + # Default is to linger + self.sock.close() + except Exception: + # While this is a bad thing to happen, + # it would be much worse if some of the code calling this + # were to fail. For now, lets log, and later evaluate + # if we can safely raise here. + LOG.error("ZeroMQ socket could not be closed.") + self.sock = None + + def recv(self): + if not self.can_recv: + raise RPCException(_("You cannot recv on this socket.")) + return self.sock.recv_multipart() + + def send(self, data): + if not self.can_send: + raise RPCException(_("You cannot send on this socket.")) + self.sock.send_multipart(data) + + +class ZmqClient(object): + """Client for ZMQ sockets.""" + + def __init__(self, addr, socket_type=None, bind=False): + if socket_type is None: + socket_type = zmq.PUSH + self.outq = ZmqSocket(addr, socket_type, bind=bind) + + def cast(self, msg_id, topic, data, envelope=False): + msg_id = msg_id or 0 + + if not envelope: + self.outq.send(map(bytes, + (msg_id, topic, 'cast', _serialize(data)))) + return + + rpc_envelope = rpc_common.serialize_msg(data[1], envelope) + zmq_msg = reduce(lambda x, y: x + y, rpc_envelope.items()) + self.outq.send(map(bytes, + (msg_id, topic, 'impl_zmq_v2', data[0]) + zmq_msg)) + + def close(self): + self.outq.close() + + +class RpcContext(rpc_common.CommonRpcContext): + """Context that supports replying to a rpc.call.""" + def __init__(self, **kwargs): + self.replies = [] + super(RpcContext, self).__init__(**kwargs) + + def deepcopy(self): + values = self.to_dict() + values['replies'] = self.replies + return self.__class__(**values) + + def reply(self, reply=None, failure=None, ending=False): + if ending: + return + self.replies.append(reply) + + @classmethod + def marshal(self, ctx): + ctx_data = ctx.to_dict() + return _serialize(ctx_data) + + @classmethod + def unmarshal(self, data): + return RpcContext.from_dict(_deserialize(data)) + + +class InternalContext(object): + """Used by ConsumerBase as a private context for - methods.""" + + def __init__(self, proxy): + self.proxy = proxy + self.msg_waiter = None + + def _get_response(self, ctx, proxy, topic, data): + """Process a curried message and cast the result to topic.""" + LOG.debug(_("Running func with context: %s"), ctx.to_dict()) + data.setdefault('version', None) + data.setdefault('args', {}) + + try: + result = proxy.dispatch( + ctx, data['version'], data['method'], + data.get('namespace'), **data['args']) + return ConsumerBase.normalize_reply(result, ctx.replies) + except greenlet.GreenletExit: + # ignore these since they are just from shutdowns + pass + except rpc_common.ClientException as e: + LOG.debug(_("Expected exception during message handling (%s)") % + e._exc_info[1]) + return {'exc': + rpc_common.serialize_remote_exception(e._exc_info, + log_failure=False)} + except Exception: + LOG.error(_("Exception during message handling")) + return {'exc': + rpc_common.serialize_remote_exception(sys.exc_info())} + + def reply(self, ctx, proxy, + msg_id=None, context=None, topic=None, msg=None): + """Reply to a casted call.""" + # NOTE(ewindisch): context kwarg exists for Grizzly compat. + # this may be able to be removed earlier than + # 'I' if ConsumerBase.process were refactored. + if type(msg) is list: + payload = msg[-1] + else: + payload = msg + + response = ConsumerBase.normalize_reply( + self._get_response(ctx, proxy, topic, payload), + ctx.replies) + + LOG.debug(_("Sending reply")) + _multi_send(_cast, ctx, topic, { + 'method': '-process_reply', + 'args': { + 'msg_id': msg_id, # Include for Folsom compat. + 'response': response + } + }, _msg_id=msg_id) + + +class ConsumerBase(object): + """Base Consumer.""" + + def __init__(self): + self.private_ctx = InternalContext(None) + + @classmethod + def normalize_reply(self, result, replies): + #TODO(ewindisch): re-evaluate and document this method. + if isinstance(result, types.GeneratorType): + return list(result) + elif replies: + return replies + else: + return [result] + + def process(self, proxy, ctx, data): + data.setdefault('version', None) + data.setdefault('args', {}) + + # Method starting with - are + # processed internally. (non-valid method name) + method = data.get('method') + if not method: + LOG.error(_("RPC message did not include method.")) + return + + # Internal method + # uses internal context for safety. + if method == '-reply': + self.private_ctx.reply(ctx, proxy, **data['args']) + return + + proxy.dispatch(ctx, data['version'], + data['method'], data.get('namespace'), **data['args']) + + +class ZmqBaseReactor(ConsumerBase): + """ + A consumer class implementing a + centralized casting broker (PULL-PUSH) + for RoundRobin requests. + """ + + def __init__(self, conf): + super(ZmqBaseReactor, self).__init__() + + self.mapping = {} + self.proxies = {} + self.threads = [] + self.sockets = [] + self.subscribe = {} + + self.pool = eventlet.greenpool.GreenPool(conf.rpc_thread_pool_size) + + def register(self, proxy, in_addr, zmq_type_in, out_addr=None, + zmq_type_out=None, in_bind=True, out_bind=True, + subscribe=None): + + LOG.info(_("Registering reactor")) + + if zmq_type_in not in (zmq.PULL, zmq.SUB): + raise RPCException("Bad input socktype") + + # Items push in. + inq = ZmqSocket(in_addr, zmq_type_in, bind=in_bind, + subscribe=subscribe) + + self.proxies[inq] = proxy + self.sockets.append(inq) + + LOG.info(_("In reactor registered")) + + if not out_addr: + return + + if zmq_type_out not in (zmq.PUSH, zmq.PUB): + raise RPCException("Bad output socktype") + + # Items push out. + outq = ZmqSocket(out_addr, zmq_type_out, bind=out_bind) + + self.mapping[inq] = outq + self.mapping[outq] = inq + self.sockets.append(outq) + + LOG.info(_("Out reactor registered")) + + def consume_in_thread(self): + def _consume(sock): + LOG.info(_("Consuming socket")) + while True: + self.consume(sock) + + for k in self.proxies.keys(): + self.threads.append( + self.pool.spawn(_consume, k) + ) + + def wait(self): + for t in self.threads: + t.wait() + + def close(self): + for s in self.sockets: + s.close() + + for t in self.threads: + t.kill() + + +class ZmqProxy(ZmqBaseReactor): + """ + A consumer class implementing a + topic-based proxy, forwarding to + IPC sockets. + """ + + def __init__(self, conf): + super(ZmqProxy, self).__init__(conf) + pathsep = set((os.path.sep or '', os.path.altsep or '', '/', '\\')) + self.badchars = re.compile(r'[%s]' % re.escape(''.join(pathsep))) + + self.topic_proxy = {} + + def consume(self, sock): + ipc_dir = CONF.rpc_zmq_ipc_dir + + #TODO(ewindisch): use zero-copy (i.e. references, not copying) + data = sock.recv() + topic = data[1] + + LOG.debug(_("CONSUMER GOT %s"), ' '.join(map(pformat, data))) + + if topic.startswith('fanout~'): + sock_type = zmq.PUB + topic = topic.split('.', 1)[0] + elif topic.startswith('zmq_replies'): + sock_type = zmq.PUB + else: + sock_type = zmq.PUSH + + if topic not in self.topic_proxy: + def publisher(waiter): + LOG.info(_("Creating proxy for topic: %s"), topic) + + try: + # The topic is received over the network, + # don't trust this input. + if self.badchars.search(topic) is not None: + emsg = _("Topic contained dangerous characters.") + LOG.warn(emsg) + raise RPCException(emsg) + + out_sock = ZmqSocket("ipc://%s/zmq_topic_%s" % + (ipc_dir, topic), + sock_type, bind=True) + except RPCException: + waiter.send_exception(*sys.exc_info()) + return + + self.topic_proxy[topic] = eventlet.queue.LightQueue( + CONF.rpc_zmq_topic_backlog) + self.sockets.append(out_sock) + + # It takes some time for a pub socket to open, + # before we can have any faith in doing a send() to it. + if sock_type == zmq.PUB: + eventlet.sleep(.5) + + waiter.send(True) + + while(True): + data = self.topic_proxy[topic].get() + out_sock.send(data) + LOG.debug(_("ROUTER RELAY-OUT SUCCEEDED %(data)s") % + {'data': data}) + + wait_sock_creation = eventlet.event.Event() + eventlet.spawn(publisher, wait_sock_creation) + + try: + wait_sock_creation.wait() + except RPCException: + LOG.error(_("Topic socket file creation failed.")) + return + + try: + self.topic_proxy[topic].put_nowait(data) + LOG.debug(_("ROUTER RELAY-OUT QUEUED %(data)s") % + {'data': data}) + except eventlet.queue.Full: + LOG.error(_("Local per-topic backlog buffer full for topic " + "%(topic)s. Dropping message.") % {'topic': topic}) + + def consume_in_thread(self): + """Runs the ZmqProxy service""" + ipc_dir = CONF.rpc_zmq_ipc_dir + consume_in = "tcp://%s:%s" % \ + (CONF.rpc_zmq_bind_address, + CONF.rpc_zmq_port) + consumption_proxy = InternalContext(None) + + if not os.path.isdir(ipc_dir): + try: + utils.execute('mkdir', '-p', ipc_dir, run_as_root=True) + utils.execute('chown', "%s:%s" % (os.getuid(), os.getgid()), + ipc_dir, run_as_root=True) + utils.execute('chmod', '750', ipc_dir, run_as_root=True) + except utils.ProcessExecutionError: + with excutils.save_and_reraise_exception(): + LOG.error(_("Could not create IPC directory %s") % + (ipc_dir, )) + + try: + self.register(consumption_proxy, + consume_in, + zmq.PULL, + out_bind=True) + except zmq.ZMQError: + with excutils.save_and_reraise_exception(): + LOG.error(_("Could not create ZeroMQ receiver daemon. " + "Socket may already be in use.")) + + super(ZmqProxy, self).consume_in_thread() + + +def unflatten_envelope(packenv): + """Unflattens the RPC envelope. + Takes a list and returns a dictionary. + i.e. [1,2,3,4] => {1: 2, 3: 4} + """ + i = iter(packenv) + h = {} + try: + while True: + k = i.next() + h[k] = i.next() + except StopIteration: + return h + + +class ZmqReactor(ZmqBaseReactor): + """ + A consumer class implementing a + consumer for messages. Can also be + used as a 1:1 proxy + """ + + def __init__(self, conf): + super(ZmqReactor, self).__init__(conf) + + def consume(self, sock): + #TODO(ewindisch): use zero-copy (i.e. references, not copying) + data = sock.recv() + LOG.debug(_("CONSUMER RECEIVED DATA: %s"), data) + if sock in self.mapping: + LOG.debug(_("ROUTER RELAY-OUT %(data)s") % { + 'data': data}) + self.mapping[sock].send(data) + return + + proxy = self.proxies[sock] + + if data[2] == 'cast': # Legacy protocol + packenv = data[3] + + ctx, msg = _deserialize(packenv) + request = rpc_common.deserialize_msg(msg) + ctx = RpcContext.unmarshal(ctx) + elif data[2] == 'impl_zmq_v2': + packenv = data[4:] + + msg = unflatten_envelope(packenv) + request = rpc_common.deserialize_msg(msg) + + # Unmarshal only after verifying the message. + ctx = RpcContext.unmarshal(data[3]) + else: + LOG.error(_("ZMQ Envelope version unsupported or unknown.")) + return + + self.pool.spawn_n(self.process, proxy, ctx, request) + + +class Connection(rpc_common.Connection): + """Manages connections and threads.""" + + def __init__(self, conf): + self.topics = [] + self.reactor = ZmqReactor(conf) + + def create_consumer(self, topic, proxy, fanout=False): + # Register with matchmaker. + _get_matchmaker().register(topic, CONF.rpc_zmq_host) + + # Subscription scenarios + if fanout: + sock_type = zmq.SUB + subscribe = ('', fanout)[type(fanout) == str] + topic = 'fanout~' + topic.split('.', 1)[0] + else: + sock_type = zmq.PULL + subscribe = None + topic = '.'.join((topic.split('.', 1)[0], CONF.rpc_zmq_host)) + + if topic in self.topics: + LOG.info(_("Skipping topic registration. Already registered.")) + return + + # Receive messages from (local) proxy + inaddr = "ipc://%s/zmq_topic_%s" % \ + (CONF.rpc_zmq_ipc_dir, topic) + + LOG.debug(_("Consumer is a zmq.%s"), + ['PULL', 'SUB'][sock_type == zmq.SUB]) + + self.reactor.register(proxy, inaddr, sock_type, + subscribe=subscribe, in_bind=False) + self.topics.append(topic) + + def close(self): + _get_matchmaker().stop_heartbeat() + for topic in self.topics: + _get_matchmaker().unregister(topic, CONF.rpc_zmq_host) + + self.reactor.close() + self.topics = [] + + def wait(self): + self.reactor.wait() + + def consume_in_thread(self): + _get_matchmaker().start_heartbeat() + self.reactor.consume_in_thread() + + +def _cast(addr, context, topic, msg, timeout=None, envelope=False, + _msg_id=None): + timeout_cast = timeout or CONF.rpc_cast_timeout + payload = [RpcContext.marshal(context), msg] + + with Timeout(timeout_cast, exception=rpc_common.Timeout): + try: + conn = ZmqClient(addr) + + # assumes cast can't return an exception + conn.cast(_msg_id, topic, payload, envelope) + except zmq.ZMQError: + raise RPCException("Cast failed. ZMQ Socket Exception") + finally: + if 'conn' in vars(): + conn.close() + + +def _call(addr, context, topic, msg, timeout=None, + envelope=False): + # timeout_response is how long we wait for a response + timeout = timeout or CONF.rpc_response_timeout + + # The msg_id is used to track replies. + msg_id = uuid.uuid4().hex + + # Replies always come into the reply service. + reply_topic = "zmq_replies.%s" % CONF.rpc_zmq_host + + LOG.debug(_("Creating payload")) + # Curry the original request into a reply method. + mcontext = RpcContext.marshal(context) + payload = { + 'method': '-reply', + 'args': { + 'msg_id': msg_id, + 'topic': reply_topic, + # TODO(ewindisch): safe to remove mcontext in I. + 'msg': [mcontext, msg] + } + } + + LOG.debug(_("Creating queue socket for reply waiter")) + + # Messages arriving async. + # TODO(ewindisch): have reply consumer with dynamic subscription mgmt + with Timeout(timeout, exception=rpc_common.Timeout): + try: + msg_waiter = ZmqSocket( + "ipc://%s/zmq_topic_zmq_replies.%s" % + (CONF.rpc_zmq_ipc_dir, + CONF.rpc_zmq_host), + zmq.SUB, subscribe=msg_id, bind=False + ) + + LOG.debug(_("Sending cast")) + _cast(addr, context, topic, payload, envelope) + + LOG.debug(_("Cast sent; Waiting reply")) + # Blocks until receives reply + msg = msg_waiter.recv() + LOG.debug(_("Received message: %s"), msg) + LOG.debug(_("Unpacking response")) + + if msg[2] == 'cast': # Legacy version + raw_msg = _deserialize(msg[-1])[-1] + elif msg[2] == 'impl_zmq_v2': + rpc_envelope = unflatten_envelope(msg[4:]) + raw_msg = rpc_common.deserialize_msg(rpc_envelope) + else: + raise rpc_common.UnsupportedRpcEnvelopeVersion( + _("Unsupported or unknown ZMQ envelope returned.")) + + responses = raw_msg['args']['response'] + # ZMQError trumps the Timeout error. + except zmq.ZMQError: + raise RPCException("ZMQ Socket Error") + except (IndexError, KeyError): + raise RPCException(_("RPC Message Invalid.")) + finally: + if 'msg_waiter' in vars(): + msg_waiter.close() + + # It seems we don't need to do all of the following, + # but perhaps it would be useful for multicall? + # One effect of this is that we're checking all + # responses for Exceptions. + for resp in responses: + if isinstance(resp, types.DictType) and 'exc' in resp: + raise rpc_common.deserialize_remote_exception(CONF, resp['exc']) + + return responses[-1] + + +def _multi_send(method, context, topic, msg, timeout=None, + envelope=False, _msg_id=None): + """ + Wraps the sending of messages, + dispatches to the matchmaker and sends + message to all relevant hosts. + """ + conf = CONF + LOG.debug(_("%(msg)s") % {'msg': ' '.join(map(pformat, (topic, msg)))}) + + queues = _get_matchmaker().queues(topic) + LOG.debug(_("Sending message(s) to: %s"), queues) + + # Don't stack if we have no matchmaker results + if len(queues) == 0: + LOG.warn(_("No matchmaker results. Not casting.")) + # While not strictly a timeout, callers know how to handle + # this exception and a timeout isn't too big a lie. + raise rpc_common.Timeout(_("No match from matchmaker.")) + + # This supports brokerless fanout (addresses > 1) + for queue in queues: + (_topic, ip_addr) = queue + _addr = "tcp://%s:%s" % (ip_addr, conf.rpc_zmq_port) + + if method.__name__ == '_cast': + eventlet.spawn_n(method, _addr, context, + _topic, msg, timeout, envelope, + _msg_id) + return + return method(_addr, context, _topic, msg, timeout, + envelope) + + +def create_connection(conf, new=True): + return Connection(conf) + + +def multicall(conf, *args, **kwargs): + """Multiple calls.""" + return _multi_send(_call, *args, **kwargs) + + +def call(conf, *args, **kwargs): + """Send a message, expect a response.""" + data = _multi_send(_call, *args, **kwargs) + return data[-1] + + +def cast(conf, *args, **kwargs): + """Send a message expecting no reply.""" + _multi_send(_cast, *args, **kwargs) + + +def fanout_cast(conf, context, topic, msg, **kwargs): + """Send a message to all listening and expect no reply.""" + # NOTE(ewindisch): fanout~ is used because it avoid splitting on . + # and acts as a non-subtle hint to the matchmaker and ZmqProxy. + _multi_send(_cast, context, 'fanout~' + str(topic), msg, **kwargs) + + +def notify(conf, context, topic, msg, envelope): + """ + Send notification event. + Notifications are sent to topic-priority. + This differs from the AMQP drivers which send to topic.priority. + """ + # NOTE(ewindisch): dot-priority in rpc notifier does not + # work with our assumptions. + topic = topic.replace('.', '-') + cast(conf, context, topic, msg, envelope=envelope) + + +def cleanup(): + """Clean up resources in use by implementation.""" + global ZMQ_CTX + if ZMQ_CTX: + ZMQ_CTX.term() + ZMQ_CTX = None + + global matchmaker + matchmaker = None + + +def _get_ctxt(): + if not zmq: + raise ImportError("Failed to import eventlet.green.zmq") + + global ZMQ_CTX + if not ZMQ_CTX: + ZMQ_CTX = zmq.Context(CONF.rpc_zmq_contexts) + return ZMQ_CTX + + +def _get_matchmaker(*args, **kwargs): + global matchmaker + if not matchmaker: + matchmaker = importutils.import_object( + CONF.rpc_zmq_matchmaker, *args, **kwargs) + return matchmaker diff --git a/staccato/openstack/common/rpc/matchmaker.py b/staccato/openstack/common/rpc/matchmaker.py new file mode 100644 index 0000000..e409a64 --- /dev/null +++ b/staccato/openstack/common/rpc/matchmaker.py @@ -0,0 +1,425 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 Cloudscaling Group, Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +""" +The MatchMaker classes should except a Topic or Fanout exchange key and +return keys for direct exchanges, per (approximate) AMQP parlance. +""" + +import contextlib +import itertools +import json + +import eventlet +from oslo.config import cfg + +from staccato.openstack.common.gettextutils import _ +from staccato.openstack.common import log as logging + + +matchmaker_opts = [ + # Matchmaker ring file + cfg.StrOpt('matchmaker_ringfile', + default='/etc/nova/matchmaker_ring.json', + help='Matchmaker ring file (JSON)'), + cfg.IntOpt('matchmaker_heartbeat_freq', + default=300, + help='Heartbeat frequency'), + cfg.IntOpt('matchmaker_heartbeat_ttl', + default=600, + help='Heartbeat time-to-live.'), +] + +CONF = cfg.CONF +CONF.register_opts(matchmaker_opts) +LOG = logging.getLogger(__name__) +contextmanager = contextlib.contextmanager + + +class MatchMakerException(Exception): + """Signified a match could not be found.""" + message = _("Match not found by MatchMaker.") + + +class Exchange(object): + """ + Implements lookups. + Subclass this to support hashtables, dns, etc. + """ + def __init__(self): + pass + + def run(self, key): + raise NotImplementedError() + + +class Binding(object): + """ + A binding on which to perform a lookup. + """ + def __init__(self): + pass + + def test(self, key): + raise NotImplementedError() + + +class MatchMakerBase(object): + """ + Match Maker Base Class. + Build off HeartbeatMatchMakerBase if building a + heartbeat-capable MatchMaker. + """ + def __init__(self): + # Array of tuples. Index [2] toggles negation, [3] is last-if-true + self.bindings = [] + + self.no_heartbeat_msg = _('Matchmaker does not implement ' + 'registration or heartbeat.') + + def register(self, key, host): + """ + Register a host on a backend. + Heartbeats, if applicable, may keepalive registration. + """ + pass + + def ack_alive(self, key, host): + """ + Acknowledge that a key.host is alive. + Used internally for updating heartbeats, + but may also be used publically to acknowledge + a system is alive (i.e. rpc message successfully + sent to host) + """ + pass + + def is_alive(self, topic, host): + """ + Checks if a host is alive. + """ + pass + + def expire(self, topic, host): + """ + Explicitly expire a host's registration. + """ + pass + + def send_heartbeats(self): + """ + Send all heartbeats. + Use start_heartbeat to spawn a heartbeat greenthread, + which loops this method. + """ + pass + + def unregister(self, key, host): + """ + Unregister a topic. + """ + pass + + def start_heartbeat(self): + """ + Spawn heartbeat greenthread. + """ + pass + + def stop_heartbeat(self): + """ + Destroys the heartbeat greenthread. + """ + pass + + def add_binding(self, binding, rule, last=True): + self.bindings.append((binding, rule, False, last)) + + #NOTE(ewindisch): kept the following method in case we implement the + # underlying support. + #def add_negate_binding(self, binding, rule, last=True): + # self.bindings.append((binding, rule, True, last)) + + def queues(self, key): + workers = [] + + # bit is for negate bindings - if we choose to implement it. + # last stops processing rules if this matches. + for (binding, exchange, bit, last) in self.bindings: + if binding.test(key): + workers.extend(exchange.run(key)) + + # Support last. + if last: + return workers + return workers + + +class HeartbeatMatchMakerBase(MatchMakerBase): + """ + Base for a heart-beat capable MatchMaker. + Provides common methods for registering, + unregistering, and maintaining heartbeats. + """ + def __init__(self): + self.hosts = set() + self._heart = None + self.host_topic = {} + + super(HeartbeatMatchMakerBase, self).__init__() + + def send_heartbeats(self): + """ + Send all heartbeats. + Use start_heartbeat to spawn a heartbeat greenthread, + which loops this method. + """ + for key, host in self.host_topic: + self.ack_alive(key, host) + + def ack_alive(self, key, host): + """ + Acknowledge that a host.topic is alive. + Used internally for updating heartbeats, + but may also be used publically to acknowledge + a system is alive (i.e. rpc message successfully + sent to host) + """ + raise NotImplementedError("Must implement ack_alive") + + def backend_register(self, key, host): + """ + Implements registration logic. + Called by register(self,key,host) + """ + raise NotImplementedError("Must implement backend_register") + + def backend_unregister(self, key, key_host): + """ + Implements de-registration logic. + Called by unregister(self,key,host) + """ + raise NotImplementedError("Must implement backend_unregister") + + def register(self, key, host): + """ + Register a host on a backend. + Heartbeats, if applicable, may keepalive registration. + """ + self.hosts.add(host) + self.host_topic[(key, host)] = host + key_host = '.'.join((key, host)) + + self.backend_register(key, key_host) + + self.ack_alive(key, host) + + def unregister(self, key, host): + """ + Unregister a topic. + """ + if (key, host) in self.host_topic: + del self.host_topic[(key, host)] + + self.hosts.discard(host) + self.backend_unregister(key, '.'.join((key, host))) + + LOG.info(_("Matchmaker unregistered: %s, %s" % (key, host))) + + def start_heartbeat(self): + """ + Implementation of MatchMakerBase.start_heartbeat + Launches greenthread looping send_heartbeats(), + yielding for CONF.matchmaker_heartbeat_freq seconds + between iterations. + """ + if len(self.hosts) == 0: + raise MatchMakerException( + _("Register before starting heartbeat.")) + + def do_heartbeat(): + while True: + self.send_heartbeats() + eventlet.sleep(CONF.matchmaker_heartbeat_freq) + + self._heart = eventlet.spawn(do_heartbeat) + + def stop_heartbeat(self): + """ + Destroys the heartbeat greenthread. + """ + if self._heart: + self._heart.kill() + + +class DirectBinding(Binding): + """ + Specifies a host in the key via a '.' character + Although dots are used in the key, the behavior here is + that it maps directly to a host, thus direct. + """ + def test(self, key): + if '.' in key: + return True + return False + + +class TopicBinding(Binding): + """ + Where a 'bare' key without dots. + AMQP generally considers topic exchanges to be those *with* dots, + but we deviate here in terminology as the behavior here matches + that of a topic exchange (whereas where there are dots, behavior + matches that of a direct exchange. + """ + def test(self, key): + if '.' not in key: + return True + return False + + +class FanoutBinding(Binding): + """Match on fanout keys, where key starts with 'fanout.' string.""" + def test(self, key): + if key.startswith('fanout~'): + return True + return False + + +class StubExchange(Exchange): + """Exchange that does nothing.""" + def run(self, key): + return [(key, None)] + + +class RingExchange(Exchange): + """ + Match Maker where hosts are loaded from a static file containing + a hashmap (JSON formatted). + + __init__ takes optional ring dictionary argument, otherwise + loads the ringfile from CONF.mathcmaker_ringfile. + """ + def __init__(self, ring=None): + super(RingExchange, self).__init__() + + if ring: + self.ring = ring + else: + fh = open(CONF.matchmaker_ringfile, 'r') + self.ring = json.load(fh) + fh.close() + + self.ring0 = {} + for k in self.ring.keys(): + self.ring0[k] = itertools.cycle(self.ring[k]) + + def _ring_has(self, key): + if key in self.ring0: + return True + return False + + +class RoundRobinRingExchange(RingExchange): + """A Topic Exchange based on a hashmap.""" + def __init__(self, ring=None): + super(RoundRobinRingExchange, self).__init__(ring) + + def run(self, key): + if not self._ring_has(key): + LOG.warn( + _("No key defining hosts for topic '%s', " + "see ringfile") % (key, ) + ) + return [] + host = next(self.ring0[key]) + return [(key + '.' + host, host)] + + +class FanoutRingExchange(RingExchange): + """Fanout Exchange based on a hashmap.""" + def __init__(self, ring=None): + super(FanoutRingExchange, self).__init__(ring) + + def run(self, key): + # Assume starts with "fanout~", strip it for lookup. + nkey = key.split('fanout~')[1:][0] + if not self._ring_has(nkey): + LOG.warn( + _("No key defining hosts for topic '%s', " + "see ringfile") % (nkey, ) + ) + return [] + return map(lambda x: (key + '.' + x, x), self.ring[nkey]) + + +class LocalhostExchange(Exchange): + """Exchange where all direct topics are local.""" + def __init__(self, host='localhost'): + self.host = host + super(Exchange, self).__init__() + + def run(self, key): + return [('.'.join((key.split('.')[0], self.host)), self.host)] + + +class DirectExchange(Exchange): + """ + Exchange where all topic keys are split, sending to second half. + i.e. "compute.host" sends a message to "compute.host" running on "host" + """ + def __init__(self): + super(Exchange, self).__init__() + + def run(self, key): + e = key.split('.', 1)[1] + return [(key, e)] + + +class MatchMakerRing(MatchMakerBase): + """ + Match Maker where hosts are loaded from a static hashmap. + """ + def __init__(self, ring=None): + super(MatchMakerRing, self).__init__() + self.add_binding(FanoutBinding(), FanoutRingExchange(ring)) + self.add_binding(DirectBinding(), DirectExchange()) + self.add_binding(TopicBinding(), RoundRobinRingExchange(ring)) + + +class MatchMakerLocalhost(MatchMakerBase): + """ + Match Maker where all bare topics resolve to localhost. + Useful for testing. + """ + def __init__(self, host='localhost'): + super(MatchMakerLocalhost, self).__init__() + self.add_binding(FanoutBinding(), LocalhostExchange(host)) + self.add_binding(DirectBinding(), DirectExchange()) + self.add_binding(TopicBinding(), LocalhostExchange(host)) + + +class MatchMakerStub(MatchMakerBase): + """ + Match Maker where topics are untouched. + Useful for testing, or for AMQP/brokered queues. + Will not work where knowledge of hosts is known (i.e. zeromq) + """ + def __init__(self): + super(MatchMakerLocalhost, self).__init__() + + self.add_binding(FanoutBinding(), StubExchange()) + self.add_binding(DirectBinding(), StubExchange()) + self.add_binding(TopicBinding(), StubExchange()) diff --git a/staccato/openstack/common/rpc/matchmaker_redis.py b/staccato/openstack/common/rpc/matchmaker_redis.py new file mode 100644 index 0000000..a355fe8 --- /dev/null +++ b/staccato/openstack/common/rpc/matchmaker_redis.py @@ -0,0 +1,149 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2013 Cloudscaling Group, Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +""" +The MatchMaker classes should accept a Topic or Fanout exchange key and +return keys for direct exchanges, per (approximate) AMQP parlance. +""" + +from oslo.config import cfg + +from staccato.openstack.common import importutils +from staccato.openstack.common import log as logging +from staccato.openstack.common.rpc import matchmaker as mm_common + +redis = importutils.try_import('redis') + + +matchmaker_redis_opts = [ + cfg.StrOpt('host', + default='127.0.0.1', + help='Host to locate redis'), + cfg.IntOpt('port', + default=6379, + help='Use this port to connect to redis host.'), + cfg.StrOpt('password', + default=None, + help='Password for Redis server. (optional)'), +] + +CONF = cfg.CONF +opt_group = cfg.OptGroup(name='matchmaker_redis', + title='Options for Redis-based MatchMaker') +CONF.register_group(opt_group) +CONF.register_opts(matchmaker_redis_opts, opt_group) +LOG = logging.getLogger(__name__) + + +class RedisExchange(mm_common.Exchange): + def __init__(self, matchmaker): + self.matchmaker = matchmaker + self.redis = matchmaker.redis + super(RedisExchange, self).__init__() + + +class RedisTopicExchange(RedisExchange): + """ + Exchange where all topic keys are split, sending to second half. + i.e. "compute.host" sends a message to "compute" running on "host" + """ + def run(self, topic): + while True: + member_name = self.redis.srandmember(topic) + + if not member_name: + # If this happens, there are no + # longer any members. + break + + if not self.matchmaker.is_alive(topic, member_name): + continue + + host = member_name.split('.', 1)[1] + return [(member_name, host)] + return [] + + +class RedisFanoutExchange(RedisExchange): + """ + Return a list of all hosts. + """ + def run(self, topic): + topic = topic.split('~', 1)[1] + hosts = self.redis.smembers(topic) + good_hosts = filter( + lambda host: self.matchmaker.is_alive(topic, host), hosts) + + return [(x, x.split('.', 1)[1]) for x in good_hosts] + + +class MatchMakerRedis(mm_common.HeartbeatMatchMakerBase): + """ + MatchMaker registering and looking-up hosts with a Redis server. + """ + def __init__(self): + super(MatchMakerRedis, self).__init__() + + if not redis: + raise ImportError("Failed to import module redis.") + + self.redis = redis.StrictRedis( + host=CONF.matchmaker_redis.host, + port=CONF.matchmaker_redis.port, + password=CONF.matchmaker_redis.password) + + self.add_binding(mm_common.FanoutBinding(), RedisFanoutExchange(self)) + self.add_binding(mm_common.DirectBinding(), mm_common.DirectExchange()) + self.add_binding(mm_common.TopicBinding(), RedisTopicExchange(self)) + + def ack_alive(self, key, host): + topic = "%s.%s" % (key, host) + if not self.redis.expire(topic, CONF.matchmaker_heartbeat_ttl): + # If we could not update the expiration, the key + # might have been pruned. Re-register, creating a new + # key in Redis. + self.register(self.topic_host[host], host) + + def is_alive(self, topic, host): + if self.redis.ttl(host) == -1: + self.expire(topic, host) + return False + return True + + def expire(self, topic, host): + with self.redis.pipeline() as pipe: + pipe.multi() + pipe.delete(host) + pipe.srem(topic, host) + pipe.execute() + + def backend_register(self, key, key_host): + with self.redis.pipeline() as pipe: + pipe.multi() + pipe.sadd(key, key_host) + + # No value is needed, we just + # care if it exists. Sets aren't viable + # because only keys can expire. + pipe.set(key_host, '') + + pipe.execute() + + def backend_unregister(self, key, key_host): + with self.redis.pipeline() as pipe: + pipe.multi() + pipe.srem(key, key_host) + pipe.delete(key_host) + pipe.execute() diff --git a/staccato/openstack/common/rpc/proxy.py b/staccato/openstack/common/rpc/proxy.py new file mode 100644 index 0000000..828d79e --- /dev/null +++ b/staccato/openstack/common/rpc/proxy.py @@ -0,0 +1,179 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2012 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +A helper class for proxy objects to remote APIs. + +For more information about rpc API version numbers, see: + rpc/dispatcher.py +""" + + +from staccato.openstack.common import rpc + + +class RpcProxy(object): + """A helper class for rpc clients. + + This class is a wrapper around the RPC client API. It allows you to + specify the topic and API version in a single place. This is intended to + be used as a base class for a class that implements the client side of an + rpc API. + """ + + def __init__(self, topic, default_version): + """Initialize an RpcProxy. + + :param topic: The topic to use for all messages. + :param default_version: The default API version to request in all + outgoing messages. This can be overridden on a per-message + basis. + """ + self.topic = topic + self.default_version = default_version + super(RpcProxy, self).__init__() + + def _set_version(self, msg, vers): + """Helper method to set the version in a message. + + :param msg: The message having a version added to it. + :param vers: The version number to add to the message. + """ + msg['version'] = vers if vers else self.default_version + + def _get_topic(self, topic): + """Return the topic to use for a message.""" + return topic if topic else self.topic + + @staticmethod + def make_namespaced_msg(method, namespace, **kwargs): + return {'method': method, 'namespace': namespace, 'args': kwargs} + + @staticmethod + def make_msg(method, **kwargs): + return RpcProxy.make_namespaced_msg(method, None, **kwargs) + + def call(self, context, msg, topic=None, version=None, timeout=None): + """rpc.call() a remote method. + + :param context: The request context + :param msg: The message to send, including the method and args. + :param topic: Override the topic for this message. + :param version: (Optional) Override the requested API version in this + message. + :param timeout: (Optional) A timeout to use when waiting for the + response. If no timeout is specified, a default timeout will be + used that is usually sufficient. + + :returns: The return value from the remote method. + """ + self._set_version(msg, version) + real_topic = self._get_topic(topic) + try: + return rpc.call(context, real_topic, msg, timeout) + except rpc.common.Timeout as exc: + raise rpc.common.Timeout( + exc.info, real_topic, msg.get('method')) + + def multicall(self, context, msg, topic=None, version=None, timeout=None): + """rpc.multicall() a remote method. + + :param context: The request context + :param msg: The message to send, including the method and args. + :param topic: Override the topic for this message. + :param version: (Optional) Override the requested API version in this + message. + :param timeout: (Optional) A timeout to use when waiting for the + response. If no timeout is specified, a default timeout will be + used that is usually sufficient. + + :returns: An iterator that lets you process each of the returned values + from the remote method as they arrive. + """ + self._set_version(msg, version) + real_topic = self._get_topic(topic) + try: + return rpc.multicall(context, real_topic, msg, timeout) + except rpc.common.Timeout as exc: + raise rpc.common.Timeout( + exc.info, real_topic, msg.get('method')) + + def cast(self, context, msg, topic=None, version=None): + """rpc.cast() a remote method. + + :param context: The request context + :param msg: The message to send, including the method and args. + :param topic: Override the topic for this message. + :param version: (Optional) Override the requested API version in this + message. + + :returns: None. rpc.cast() does not wait on any return value from the + remote method. + """ + self._set_version(msg, version) + rpc.cast(context, self._get_topic(topic), msg) + + def fanout_cast(self, context, msg, topic=None, version=None): + """rpc.fanout_cast() a remote method. + + :param context: The request context + :param msg: The message to send, including the method and args. + :param topic: Override the topic for this message. + :param version: (Optional) Override the requested API version in this + message. + + :returns: None. rpc.fanout_cast() does not wait on any return value + from the remote method. + """ + self._set_version(msg, version) + rpc.fanout_cast(context, self._get_topic(topic), msg) + + def cast_to_server(self, context, server_params, msg, topic=None, + version=None): + """rpc.cast_to_server() a remote method. + + :param context: The request context + :param server_params: Server parameters. See rpc.cast_to_server() for + details. + :param msg: The message to send, including the method and args. + :param topic: Override the topic for this message. + :param version: (Optional) Override the requested API version in this + message. + + :returns: None. rpc.cast_to_server() does not wait on any + return values. + """ + self._set_version(msg, version) + rpc.cast_to_server(context, server_params, self._get_topic(topic), msg) + + def fanout_cast_to_server(self, context, server_params, msg, topic=None, + version=None): + """rpc.fanout_cast_to_server() a remote method. + + :param context: The request context + :param server_params: Server parameters. See rpc.cast_to_server() for + details. + :param msg: The message to send, including the method and args. + :param topic: Override the topic for this message. + :param version: (Optional) Override the requested API version in this + message. + + :returns: None. rpc.fanout_cast_to_server() does not wait on any + return values. + """ + self._set_version(msg, version) + rpc.fanout_cast_to_server(context, server_params, + self._get_topic(topic), msg) diff --git a/staccato/openstack/common/rpc/service.py b/staccato/openstack/common/rpc/service.py new file mode 100644 index 0000000..07837bd --- /dev/null +++ b/staccato/openstack/common/rpc/service.py @@ -0,0 +1,75 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# Copyright 2011 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from staccato.openstack.common.gettextutils import _ +from staccato.openstack.common import log as logging +from staccato.openstack.common import rpc +from staccato.openstack.common.rpc import dispatcher as rpc_dispatcher +from staccato.openstack.common import service + + +LOG = logging.getLogger(__name__) + + +class Service(service.Service): + """Service object for binaries running on hosts. + + A service enables rpc by listening to queues based on topic and host.""" + def __init__(self, host, topic, manager=None): + super(Service, self).__init__() + self.host = host + self.topic = topic + if manager is None: + self.manager = self + else: + self.manager = manager + + def start(self): + super(Service, self).start() + + self.conn = rpc.create_connection(new=True) + LOG.debug(_("Creating Consumer connection for Service %s") % + self.topic) + + dispatcher = rpc_dispatcher.RpcDispatcher([self.manager]) + + # Share this same connection for these Consumers + self.conn.create_consumer(self.topic, dispatcher, fanout=False) + + node_topic = '%s.%s' % (self.topic, self.host) + self.conn.create_consumer(node_topic, dispatcher, fanout=False) + + self.conn.create_consumer(self.topic, dispatcher, fanout=True) + + # Hook to allow the manager to do other initializations after + # the rpc connection is created. + if callable(getattr(self.manager, 'initialize_service_hook', None)): + self.manager.initialize_service_hook(self) + + # Consume from all consumers in a thread + self.conn.consume_in_thread() + + def stop(self): + # Try to shut the connection down, but if we get any sort of + # errors, go ahead and ignore them.. as we're shutting down anyway + try: + self.conn.close() + except Exception: + pass + super(Service, self).stop() diff --git a/staccato/openstack/common/rpc/zmq_receiver.py b/staccato/openstack/common/rpc/zmq_receiver.py new file mode 100755 index 0000000..30d2ed6 --- /dev/null +++ b/staccato/openstack/common/rpc/zmq_receiver.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 OpenStack Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import eventlet +eventlet.monkey_patch() + +import contextlib +import sys + +from oslo.config import cfg + +from staccato.openstack.common import log as logging +from staccato.openstack.common import rpc +from staccato.openstack.common.rpc import impl_zmq + +CONF = cfg.CONF +CONF.register_opts(rpc.rpc_opts) +CONF.register_opts(impl_zmq.zmq_opts) + + +def main(): + CONF(sys.argv[1:], project='oslo') + logging.setup("oslo") + + with contextlib.closing(impl_zmq.ZmqProxy(CONF)) as reactor: + reactor.consume_in_thread() + reactor.wait() diff --git a/staccato/openstack/common/service.py b/staccato/openstack/common/service.py new file mode 100644 index 0000000..69629d4 --- /dev/null +++ b/staccato/openstack/common/service.py @@ -0,0 +1,332 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# Copyright 2011 Justin Santa Barbara +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Generic Node base class for all workers that run on hosts.""" + +import errno +import os +import random +import signal +import sys +import time + +import eventlet +import logging as std_logging +from oslo.config import cfg + +from staccato.openstack.common import eventlet_backdoor +from staccato.openstack.common.gettextutils import _ +from staccato.openstack.common import importutils +from staccato.openstack.common import log as logging +from staccato.openstack.common import threadgroup + + +rpc = importutils.try_import('staccato.openstack.common.rpc') +CONF = cfg.CONF +LOG = logging.getLogger(__name__) + + +class Launcher(object): + """Launch one or more services and wait for them to complete.""" + + def __init__(self): + """Initialize the service launcher. + + :returns: None + + """ + self._services = threadgroup.ThreadGroup() + eventlet_backdoor.initialize_if_enabled() + + @staticmethod + def run_service(service): + """Start and wait for a service to finish. + + :param service: service to run and wait for. + :returns: None + + """ + service.start() + service.wait() + + def launch_service(self, service): + """Load and start the given service. + + :param service: The service you would like to start. + :returns: None + + """ + self._services.add_thread(self.run_service, service) + + def stop(self): + """Stop all services which are currently running. + + :returns: None + + """ + self._services.stop() + + def wait(self): + """Waits until all services have been stopped, and then returns. + + :returns: None + + """ + self._services.wait() + + +class SignalExit(SystemExit): + def __init__(self, signo, exccode=1): + super(SignalExit, self).__init__(exccode) + self.signo = signo + + +class ServiceLauncher(Launcher): + def _handle_signal(self, signo, frame): + # Allow the process to be killed again and die from natural causes + signal.signal(signal.SIGTERM, signal.SIG_DFL) + signal.signal(signal.SIGINT, signal.SIG_DFL) + + raise SignalExit(signo) + + def wait(self): + signal.signal(signal.SIGTERM, self._handle_signal) + signal.signal(signal.SIGINT, self._handle_signal) + + LOG.debug(_('Full set of CONF:')) + CONF.log_opt_values(LOG, std_logging.DEBUG) + + status = None + try: + super(ServiceLauncher, self).wait() + except SignalExit as exc: + signame = {signal.SIGTERM: 'SIGTERM', + signal.SIGINT: 'SIGINT'}[exc.signo] + LOG.info(_('Caught %s, exiting'), signame) + status = exc.code + except SystemExit as exc: + status = exc.code + finally: + if rpc: + rpc.cleanup() + self.stop() + return status + + +class ServiceWrapper(object): + def __init__(self, service, workers): + self.service = service + self.workers = workers + self.children = set() + self.forktimes = [] + + +class ProcessLauncher(object): + def __init__(self): + self.children = {} + self.sigcaught = None + self.running = True + rfd, self.writepipe = os.pipe() + self.readpipe = eventlet.greenio.GreenPipe(rfd, 'r') + + signal.signal(signal.SIGTERM, self._handle_signal) + signal.signal(signal.SIGINT, self._handle_signal) + + def _handle_signal(self, signo, frame): + self.sigcaught = signo + self.running = False + + # Allow the process to be killed again and die from natural causes + signal.signal(signal.SIGTERM, signal.SIG_DFL) + signal.signal(signal.SIGINT, signal.SIG_DFL) + + def _pipe_watcher(self): + # This will block until the write end is closed when the parent + # dies unexpectedly + self.readpipe.read() + + LOG.info(_('Parent process has died unexpectedly, exiting')) + + sys.exit(1) + + def _child_process(self, service): + # Setup child signal handlers differently + def _sigterm(*args): + signal.signal(signal.SIGTERM, signal.SIG_DFL) + raise SignalExit(signal.SIGTERM) + + signal.signal(signal.SIGTERM, _sigterm) + # Block SIGINT and let the parent send us a SIGTERM + signal.signal(signal.SIGINT, signal.SIG_IGN) + + # Reopen the eventlet hub to make sure we don't share an epoll + # fd with parent and/or siblings, which would be bad + eventlet.hubs.use_hub() + + # Close write to ensure only parent has it open + os.close(self.writepipe) + # Create greenthread to watch for parent to close pipe + eventlet.spawn_n(self._pipe_watcher) + + # Reseed random number generator + random.seed() + + launcher = Launcher() + launcher.run_service(service) + + def _start_child(self, wrap): + if len(wrap.forktimes) > wrap.workers: + # Limit ourselves to one process a second (over the period of + # number of workers * 1 second). This will allow workers to + # start up quickly but ensure we don't fork off children that + # die instantly too quickly. + if time.time() - wrap.forktimes[0] < wrap.workers: + LOG.info(_('Forking too fast, sleeping')) + time.sleep(1) + + wrap.forktimes.pop(0) + + wrap.forktimes.append(time.time()) + + pid = os.fork() + if pid == 0: + # NOTE(johannes): All exceptions are caught to ensure this + # doesn't fallback into the loop spawning children. It would + # be bad for a child to spawn more children. + status = 0 + try: + self._child_process(wrap.service) + except SignalExit as exc: + signame = {signal.SIGTERM: 'SIGTERM', + signal.SIGINT: 'SIGINT'}[exc.signo] + LOG.info(_('Caught %s, exiting'), signame) + status = exc.code + except SystemExit as exc: + status = exc.code + except BaseException: + LOG.exception(_('Unhandled exception')) + status = 2 + finally: + wrap.service.stop() + + os._exit(status) + + LOG.info(_('Started child %d'), pid) + + wrap.children.add(pid) + self.children[pid] = wrap + + return pid + + def launch_service(self, service, workers=1): + wrap = ServiceWrapper(service, workers) + + LOG.info(_('Starting %d workers'), wrap.workers) + while self.running and len(wrap.children) < wrap.workers: + self._start_child(wrap) + + def _wait_child(self): + try: + # Don't block if no child processes have exited + pid, status = os.waitpid(0, os.WNOHANG) + if not pid: + return None + except OSError as exc: + if exc.errno not in (errno.EINTR, errno.ECHILD): + raise + return None + + if os.WIFSIGNALED(status): + sig = os.WTERMSIG(status) + LOG.info(_('Child %(pid)d killed by signal %(sig)d'), + dict(pid=pid, sig=sig)) + else: + code = os.WEXITSTATUS(status) + LOG.info(_('Child %(pid)s exited with status %(code)d'), + dict(pid=pid, code=code)) + + if pid not in self.children: + LOG.warning(_('pid %d not in child list'), pid) + return None + + wrap = self.children.pop(pid) + wrap.children.remove(pid) + return wrap + + def wait(self): + """Loop waiting on children to die and respawning as necessary""" + + LOG.debug(_('Full set of CONF:')) + CONF.log_opt_values(LOG, std_logging.DEBUG) + + while self.running: + wrap = self._wait_child() + if not wrap: + # Yield to other threads if no children have exited + # Sleep for a short time to avoid excessive CPU usage + # (see bug #1095346) + eventlet.greenthread.sleep(.01) + continue + + while self.running and len(wrap.children) < wrap.workers: + self._start_child(wrap) + + if self.sigcaught: + signame = {signal.SIGTERM: 'SIGTERM', + signal.SIGINT: 'SIGINT'}[self.sigcaught] + LOG.info(_('Caught %s, stopping children'), signame) + + for pid in self.children: + try: + os.kill(pid, signal.SIGTERM) + except OSError as exc: + if exc.errno != errno.ESRCH: + raise + + # Wait for children to die + if self.children: + LOG.info(_('Waiting on %d children to exit'), len(self.children)) + while self.children: + self._wait_child() + + +class Service(object): + """Service object for binaries running on hosts.""" + + def __init__(self, threads=1000): + self.tg = threadgroup.ThreadGroup(threads) + + def start(self): + pass + + def stop(self): + self.tg.stop() + + def wait(self): + self.tg.wait() + + +def launch(service, workers=None): + if workers: + launcher = ProcessLauncher() + launcher.launch_service(service, workers=workers) + else: + launcher = ServiceLauncher() + launcher.launch_service(service) + return launcher diff --git a/staccato/openstack/common/setup.py b/staccato/openstack/common/setup.py new file mode 100644 index 0000000..03b0675 --- /dev/null +++ b/staccato/openstack/common/setup.py @@ -0,0 +1,369 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 OpenStack Foundation. +# Copyright 2012-2013 Hewlett-Packard Development Company, L.P. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +Utilities with minimum-depends for use in setup.py +""" + +from __future__ import print_function + +import email +import os +import re +import subprocess +import sys + +from setuptools.command import sdist + + +def parse_mailmap(mailmap='.mailmap'): + mapping = {} + if os.path.exists(mailmap): + with open(mailmap, 'r') as fp: + for l in fp: + try: + canonical_email, alias = re.match( + r'[^#]*?(<.+>).*(<.+>).*', l).groups() + except AttributeError: + continue + mapping[alias] = canonical_email + return mapping + + +def _parse_git_mailmap(git_dir, mailmap='.mailmap'): + mailmap = os.path.join(os.path.dirname(git_dir), mailmap) + return parse_mailmap(mailmap) + + +def canonicalize_emails(changelog, mapping): + """Takes in a string and an email alias mapping and replaces all + instances of the aliases in the string with their real email. + """ + for alias, email_address in mapping.iteritems(): + changelog = changelog.replace(alias, email_address) + return changelog + + +# Get requirements from the first file that exists +def get_reqs_from_files(requirements_files): + for requirements_file in requirements_files: + if os.path.exists(requirements_file): + with open(requirements_file, 'r') as fil: + return fil.read().split('\n') + return [] + + +def parse_requirements(requirements_files=['requirements.txt', + 'tools/pip-requires']): + requirements = [] + for line in get_reqs_from_files(requirements_files): + # For the requirements list, we need to inject only the portion + # after egg= so that distutils knows the package it's looking for + # such as: + # -e git://github.com/openstack/nova/master#egg=nova + if re.match(r'\s*-e\s+', line): + requirements.append(re.sub(r'\s*-e\s+.*#egg=(.*)$', r'\1', + line)) + # such as: + # http://github.com/openstack/nova/zipball/master#egg=nova + elif re.match(r'\s*https?:', line): + requirements.append(re.sub(r'\s*https?:.*#egg=(.*)$', r'\1', + line)) + # -f lines are for index locations, and don't get used here + elif re.match(r'\s*-f\s+', line): + pass + # argparse is part of the standard library starting with 2.7 + # adding it to the requirements list screws distro installs + elif line == 'argparse' and sys.version_info >= (2, 7): + pass + else: + requirements.append(line) + + return requirements + + +def parse_dependency_links(requirements_files=['requirements.txt', + 'tools/pip-requires']): + dependency_links = [] + # dependency_links inject alternate locations to find packages listed + # in requirements + for line in get_reqs_from_files(requirements_files): + # skip comments and blank lines + if re.match(r'(\s*#)|(\s*$)', line): + continue + # lines with -e or -f need the whole line, minus the flag + if re.match(r'\s*-[ef]\s+', line): + dependency_links.append(re.sub(r'\s*-[ef]\s+', '', line)) + # lines that are only urls can go in unmolested + elif re.match(r'\s*https?:', line): + dependency_links.append(line) + return dependency_links + + +def _run_shell_command(cmd, throw_on_error=False): + if os.name == 'nt': + output = subprocess.Popen(["cmd.exe", "/C", cmd], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + else: + output = subprocess.Popen(["/bin/sh", "-c", cmd], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + out = output.communicate() + if output.returncode and throw_on_error: + raise Exception("%s returned %d" % cmd, output.returncode) + if len(out) == 0: + return None + if len(out[0].strip()) == 0: + return None + return out[0].strip() + + +def _get_git_directory(): + parent_dir = os.path.dirname(__file__) + while True: + git_dir = os.path.join(parent_dir, '.git') + if os.path.exists(git_dir): + return git_dir + parent_dir, child = os.path.split(parent_dir) + if not child: # reached to root dir + return None + + +def write_git_changelog(): + """Write a changelog based on the git changelog.""" + new_changelog = 'ChangeLog' + git_dir = _get_git_directory() + if not os.getenv('SKIP_WRITE_GIT_CHANGELOG'): + if git_dir: + git_log_cmd = 'git --git-dir=%s log' % git_dir + changelog = _run_shell_command(git_log_cmd) + mailmap = _parse_git_mailmap(git_dir) + with open(new_changelog, "w") as changelog_file: + changelog_file.write(canonicalize_emails(changelog, mailmap)) + else: + open(new_changelog, 'w').close() + + +def generate_authors(): + """Create AUTHORS file using git commits.""" + jenkins_email = 'jenkins@review.(openstack|stackforge).org' + old_authors = 'AUTHORS.in' + new_authors = 'AUTHORS' + git_dir = _get_git_directory() + if not os.getenv('SKIP_GENERATE_AUTHORS'): + if git_dir: + # don't include jenkins email address in AUTHORS file + git_log_cmd = ("git --git-dir=" + git_dir + + " log --format='%aN <%aE>' | sort -u | " + "egrep -v '" + jenkins_email + "'") + changelog = _run_shell_command(git_log_cmd) + signed_cmd = ("git --git-dir=" + git_dir + + " log | grep -i Co-authored-by: | sort -u") + signed_entries = _run_shell_command(signed_cmd) + if signed_entries: + new_entries = "\n".join( + [signed.split(":", 1)[1].strip() + for signed in signed_entries.split("\n") if signed]) + changelog = "\n".join((changelog, new_entries)) + mailmap = _parse_git_mailmap(git_dir) + with open(new_authors, 'w') as new_authors_fh: + new_authors_fh.write(canonicalize_emails(changelog, mailmap)) + if os.path.exists(old_authors): + with open(old_authors, "r") as old_authors_fh: + new_authors_fh.write('\n' + old_authors_fh.read()) + else: + open(new_authors, 'w').close() + + +_rst_template = """%(heading)s +%(underline)s + +.. automodule:: %(module)s + :members: + :undoc-members: + :show-inheritance: +""" + + +def get_cmdclass(): + """Return dict of commands to run from setup.py.""" + + cmdclass = dict() + + def _find_modules(arg, dirname, files): + for filename in files: + if filename.endswith('.py') and filename != '__init__.py': + arg["%s.%s" % (dirname.replace('/', '.'), + filename[:-3])] = True + + class LocalSDist(sdist.sdist): + """Builds the ChangeLog and Authors files from VC first.""" + + def run(self): + write_git_changelog() + generate_authors() + # sdist.sdist is an old style class, can't use super() + sdist.sdist.run(self) + + cmdclass['sdist'] = LocalSDist + + # If Sphinx is installed on the box running setup.py, + # enable setup.py to build the documentation, otherwise, + # just ignore it + try: + from sphinx.setup_command import BuildDoc + + class LocalBuildDoc(BuildDoc): + + builders = ['html', 'man'] + + def generate_autoindex(self): + print("**Autodocumenting from %s" % os.path.abspath(os.curdir)) + modules = {} + option_dict = self.distribution.get_option_dict('build_sphinx') + source_dir = os.path.join(option_dict['source_dir'][1], 'api') + if not os.path.exists(source_dir): + os.makedirs(source_dir) + for pkg in self.distribution.packages: + if '.' not in pkg: + os.path.walk(pkg, _find_modules, modules) + module_list = modules.keys() + module_list.sort() + autoindex_filename = os.path.join(source_dir, 'autoindex.rst') + with open(autoindex_filename, 'w') as autoindex: + autoindex.write(""".. toctree:: + :maxdepth: 1 + +""") + for module in module_list: + output_filename = os.path.join(source_dir, + "%s.rst" % module) + heading = "The :mod:`%s` Module" % module + underline = "=" * len(heading) + values = dict(module=module, heading=heading, + underline=underline) + + print("Generating %s" % output_filename) + with open(output_filename, 'w') as output_file: + output_file.write(_rst_template % values) + autoindex.write(" %s.rst\n" % module) + + def run(self): + if not os.getenv('SPHINX_DEBUG'): + self.generate_autoindex() + + for builder in self.builders: + self.builder = builder + self.finalize_options() + self.project = self.distribution.get_name() + self.version = self.distribution.get_version() + self.release = self.distribution.get_version() + BuildDoc.run(self) + + class LocalBuildLatex(LocalBuildDoc): + builders = ['latex'] + + cmdclass['build_sphinx'] = LocalBuildDoc + cmdclass['build_sphinx_latex'] = LocalBuildLatex + except ImportError: + pass + + return cmdclass + + +def _get_revno(git_dir): + """Return the number of commits since the most recent tag. + + We use git-describe to find this out, but if there are no + tags then we fall back to counting commits since the beginning + of time. + """ + describe = _run_shell_command( + "git --git-dir=%s describe --always" % git_dir) + if "-" in describe: + return describe.rsplit("-", 2)[-2] + + # no tags found + revlist = _run_shell_command( + "git --git-dir=%s rev-list --abbrev-commit HEAD" % git_dir) + return len(revlist.splitlines()) + + +def _get_version_from_git(pre_version): + """Return a version which is equal to the tag that's on the current + revision if there is one, or tag plus number of additional revisions + if the current revision has no tag.""" + + git_dir = _get_git_directory() + if git_dir: + if pre_version: + try: + return _run_shell_command( + "git --git-dir=" + git_dir + " describe --exact-match", + throw_on_error=True).replace('-', '.') + except Exception: + sha = _run_shell_command( + "git --git-dir=" + git_dir + " log -n1 --pretty=format:%h") + return "%s.a%s.g%s" % (pre_version, _get_revno(git_dir), sha) + else: + return _run_shell_command( + "git --git-dir=" + git_dir + " describe --always").replace( + '-', '.') + return None + + +def _get_version_from_pkg_info(package_name): + """Get the version from PKG-INFO file if we can.""" + try: + pkg_info_file = open('PKG-INFO', 'r') + except (IOError, OSError): + return None + try: + pkg_info = email.message_from_file(pkg_info_file) + except email.MessageError: + return None + # Check to make sure we're in our own dir + if pkg_info.get('Name', None) != package_name: + return None + return pkg_info.get('Version', None) + + +def get_version(package_name, pre_version=None): + """Get the version of the project. First, try getting it from PKG-INFO, if + it exists. If it does, that means we're in a distribution tarball or that + install has happened. Otherwise, if there is no PKG-INFO file, pull the + version from git. + + We do not support setup.py version sanity in git archive tarballs, nor do + we support packagers directly sucking our git repo into theirs. We expect + that a source tarball be made from our git repo - or that if someone wants + to make a source tarball from a fork of our repo with additional tags in it + that they understand and desire the results of doing that. + """ + version = os.environ.get("OSLO_PACKAGE_VERSION", None) + if version: + return version + version = _get_version_from_pkg_info(package_name) + if version: + return version + version = _get_version_from_git(pre_version) + if version: + return version + raise Exception("Versioning for this project requires either an sdist" + " tarball, or access to an upstream git repository.") diff --git a/staccato/openstack/common/threadgroup.py b/staccato/openstack/common/threadgroup.py new file mode 100644 index 0000000..28de004 --- /dev/null +++ b/staccato/openstack/common/threadgroup.py @@ -0,0 +1,114 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2012 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from eventlet import greenlet +from eventlet import greenpool +from eventlet import greenthread + +from staccato.openstack.common import log as logging +from staccato.openstack.common import loopingcall + + +LOG = logging.getLogger(__name__) + + +def _thread_done(gt, *args, **kwargs): + """ Callback function to be passed to GreenThread.link() when we spawn() + Calls the :class:`ThreadGroup` to notify if. + + """ + kwargs['group'].thread_done(kwargs['thread']) + + +class Thread(object): + """ Wrapper around a greenthread, that holds a reference to the + :class:`ThreadGroup`. The Thread will notify the :class:`ThreadGroup` when + it has done so it can be removed from the threads list. + """ + def __init__(self, thread, group): + self.thread = thread + self.thread.link(_thread_done, group=group, thread=self) + + def stop(self): + self.thread.kill() + + def wait(self): + return self.thread.wait() + + +class ThreadGroup(object): + """ The point of the ThreadGroup classis to: + + * keep track of timers and greenthreads (making it easier to stop them + when need be). + * provide an easy API to add timers. + """ + def __init__(self, thread_pool_size=10): + self.pool = greenpool.GreenPool(thread_pool_size) + self.threads = [] + self.timers = [] + + def add_timer(self, interval, callback, initial_delay=None, + *args, **kwargs): + pulse = loopingcall.FixedIntervalLoopingCall(callback, *args, **kwargs) + pulse.start(interval=interval, + initial_delay=initial_delay) + self.timers.append(pulse) + + def add_thread(self, callback, *args, **kwargs): + gt = self.pool.spawn(callback, *args, **kwargs) + th = Thread(gt, self) + self.threads.append(th) + + def thread_done(self, thread): + self.threads.remove(thread) + + def stop(self): + current = greenthread.getcurrent() + for x in self.threads: + if x is current: + # don't kill the current thread. + continue + try: + x.stop() + except Exception as ex: + LOG.exception(ex) + + for x in self.timers: + try: + x.stop() + except Exception as ex: + LOG.exception(ex) + self.timers = [] + + def wait(self): + for x in self.timers: + try: + x.wait() + except greenlet.GreenletExit: + pass + except Exception as ex: + LOG.exception(ex) + current = greenthread.getcurrent() + for x in self.threads: + if x is current: + continue + try: + x.wait() + except greenlet.GreenletExit: + pass + except Exception as ex: + LOG.exception(ex) diff --git a/staccato/openstack/common/timeutils.py b/staccato/openstack/common/timeutils.py new file mode 100644 index 0000000..6094365 --- /dev/null +++ b/staccato/openstack/common/timeutils.py @@ -0,0 +1,186 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 OpenStack Foundation. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +Time related utilities and helper functions. +""" + +import calendar +import datetime + +import iso8601 + + +# ISO 8601 extended time format with microseconds +_ISO8601_TIME_FORMAT_SUBSECOND = '%Y-%m-%dT%H:%M:%S.%f' +_ISO8601_TIME_FORMAT = '%Y-%m-%dT%H:%M:%S' +PERFECT_TIME_FORMAT = _ISO8601_TIME_FORMAT_SUBSECOND + + +def isotime(at=None, subsecond=False): + """Stringify time in ISO 8601 format""" + if not at: + at = utcnow() + st = at.strftime(_ISO8601_TIME_FORMAT + if not subsecond + else _ISO8601_TIME_FORMAT_SUBSECOND) + tz = at.tzinfo.tzname(None) if at.tzinfo else 'UTC' + st += ('Z' if tz == 'UTC' else tz) + return st + + +def parse_isotime(timestr): + """Parse time from ISO 8601 format""" + try: + return iso8601.parse_date(timestr) + except iso8601.ParseError as e: + raise ValueError(e.message) + except TypeError as e: + raise ValueError(e.message) + + +def strtime(at=None, fmt=PERFECT_TIME_FORMAT): + """Returns formatted utcnow.""" + if not at: + at = utcnow() + return at.strftime(fmt) + + +def parse_strtime(timestr, fmt=PERFECT_TIME_FORMAT): + """Turn a formatted time back into a datetime.""" + return datetime.datetime.strptime(timestr, fmt) + + +def normalize_time(timestamp): + """Normalize time in arbitrary timezone to UTC naive object""" + offset = timestamp.utcoffset() + if offset is None: + return timestamp + return timestamp.replace(tzinfo=None) - offset + + +def is_older_than(before, seconds): + """Return True if before is older than seconds.""" + if isinstance(before, basestring): + before = parse_strtime(before).replace(tzinfo=None) + return utcnow() - before > datetime.timedelta(seconds=seconds) + + +def is_newer_than(after, seconds): + """Return True if after is newer than seconds.""" + if isinstance(after, basestring): + after = parse_strtime(after).replace(tzinfo=None) + return after - utcnow() > datetime.timedelta(seconds=seconds) + + +def utcnow_ts(): + """Timestamp version of our utcnow function.""" + return calendar.timegm(utcnow().timetuple()) + + +def utcnow(): + """Overridable version of utils.utcnow.""" + if utcnow.override_time: + try: + return utcnow.override_time.pop(0) + except AttributeError: + return utcnow.override_time + return datetime.datetime.utcnow() + + +def iso8601_from_timestamp(timestamp): + """Returns a iso8601 formated date from timestamp""" + return isotime(datetime.datetime.utcfromtimestamp(timestamp)) + + +utcnow.override_time = None + + +def set_time_override(override_time=datetime.datetime.utcnow()): + """ + Override utils.utcnow to return a constant time or a list thereof, + one at a time. + """ + utcnow.override_time = override_time + + +def advance_time_delta(timedelta): + """Advance overridden time using a datetime.timedelta.""" + assert(not utcnow.override_time is None) + try: + for dt in utcnow.override_time: + dt += timedelta + except TypeError: + utcnow.override_time += timedelta + + +def advance_time_seconds(seconds): + """Advance overridden time by seconds.""" + advance_time_delta(datetime.timedelta(0, seconds)) + + +def clear_time_override(): + """Remove the overridden time.""" + utcnow.override_time = None + + +def marshall_now(now=None): + """Make an rpc-safe datetime with microseconds. + + Note: tzinfo is stripped, but not required for relative times.""" + if not now: + now = utcnow() + return dict(day=now.day, month=now.month, year=now.year, hour=now.hour, + minute=now.minute, second=now.second, + microsecond=now.microsecond) + + +def unmarshall_time(tyme): + """Unmarshall a datetime dict.""" + return datetime.datetime(day=tyme['day'], + month=tyme['month'], + year=tyme['year'], + hour=tyme['hour'], + minute=tyme['minute'], + second=tyme['second'], + microsecond=tyme['microsecond']) + + +def delta_seconds(before, after): + """ + Compute the difference in seconds between two date, time, or + datetime objects (as a float, to microsecond resolution). + """ + delta = after - before + try: + return delta.total_seconds() + except AttributeError: + return ((delta.days * 24 * 3600) + delta.seconds + + float(delta.microseconds) / (10 ** 6)) + + +def is_soon(dt, window): + """ + Determines if time is going to happen in the next window seconds. + + :params dt: the time + :params window: minimum seconds to remain to consider the time not soon + + :return: True if expiration is within the given duration + """ + soon = (utcnow() + datetime.timedelta(seconds=window)) + return normalize_time(dt) <= soon diff --git a/staccato/openstack/common/uuidutils.py b/staccato/openstack/common/uuidutils.py new file mode 100644 index 0000000..7608acb --- /dev/null +++ b/staccato/openstack/common/uuidutils.py @@ -0,0 +1,39 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright (c) 2012 Intel Corporation. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +UUID related utilities and helper functions. +""" + +import uuid + + +def generate_uuid(): + return str(uuid.uuid4()) + + +def is_uuid_like(val): + """Returns validation of a value as a UUID. + + For our purposes, a UUID is a canonical form string: + aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa + + """ + try: + return str(uuid.UUID(val)) == val + except (TypeError, ValueError, AttributeError): + return False diff --git a/staccato/openstack/common/version.py b/staccato/openstack/common/version.py new file mode 100644 index 0000000..e382497 --- /dev/null +++ b/staccato/openstack/common/version.py @@ -0,0 +1,94 @@ + +# Copyright 2012 OpenStack Foundation +# Copyright 2012-2013 Hewlett-Packard Development Company, L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +Utilities for consuming the version from pkg_resources. +""" + +import pkg_resources + + +class VersionInfo(object): + + def __init__(self, package): + """Object that understands versioning for a package + :param package: name of the python package, such as glance, or + python-glanceclient + """ + self.package = package + self.release = None + self.version = None + self._cached_version = None + + def __str__(self): + """Make the VersionInfo object behave like a string.""" + return self.version_string() + + def __repr__(self): + """Include the name.""" + return "VersionInfo(%s:%s)" % (self.package, self.version_string()) + + def _get_version_from_pkg_resources(self): + """Get the version of the package from the pkg_resources record + associated with the package.""" + try: + requirement = pkg_resources.Requirement.parse(self.package) + provider = pkg_resources.get_provider(requirement) + return provider.version + except pkg_resources.DistributionNotFound: + # The most likely cause for this is running tests in a tree + # produced from a tarball where the package itself has not been + # installed into anything. Revert to setup-time logic. + from staccato.openstack.common import setup + return setup.get_version(self.package) + + def release_string(self): + """Return the full version of the package including suffixes indicating + VCS status. + """ + if self.release is None: + self.release = self._get_version_from_pkg_resources() + + return self.release + + def version_string(self): + """Return the short version minus any alpha/beta tags.""" + if self.version is None: + parts = [] + for part in self.release_string().split('.'): + if part[0].isdigit(): + parts.append(part) + else: + break + self.version = ".".join(parts) + + return self.version + + # Compatibility functions + canonical_version_string = version_string + version_string_with_vcs = release_string + + def cached_version_string(self, prefix=""): + """Generate an object which will expand in a string context to + the results of version_string(). We do this so that don't + call into pkg_resources every time we start up a program when + passing version information into the CONF constructor, but + rather only do the calculation when and if a version is requested + """ + if not self._cached_version: + self._cached_version = "%s%s" % (prefix, + self.version_string()) + return self._cached_version diff --git a/staccato/protocols/__init__.py b/staccato/protocols/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/staccato/protocols/file/__init__.py b/staccato/protocols/file/__init__.py new file mode 100644 index 0000000..99a789b --- /dev/null +++ b/staccato/protocols/file/__init__.py @@ -0,0 +1,119 @@ +import staccato.protocols.interface as base +from staccato.common import exceptions + + +class FileProtocol(base.BaseProtocolInterface): + + def __init__(self, service_config): + self.conf = service_config + + def _validate_url(self, url_parts): + pass + + def new_write(self, dsturl_parts, dst_opts): + return {} + + def new_read(self, srcurl_parts, src_opts): + return + + def get_reader(self, url_parts, writer, monitor, start=0, + end=None, **kwvals): + self._validate_url(url_parts) + + return FileReadConnection(url_parts.path, + writer, + monitor, + start=start, + end=end, + buflen=65536, + **kwvals) + + def get_writer(self, url_parts, checkpointer, **kwvals): + self._validate_url(url_parts) + + return FileWriteConnection(url_parts.path, checkpointer=checkpointer, + **kwvals) + + +class FileReadConnection(base.BaseReadConnection): + + def __init__(self, + path, + writer, + monitor, + start=0, + end=None, + buflen=65536, + **kwvals): + + try: + self.fptr = open(path, 'rb') + except IOError, ioe: + raise exceptions.StaccatoProtocolConnectionException( + ioe.message) + self.pos = start + self.eof = False + self.writer = writer + self.path = path + self.buflen = buflen + self.end = end + self.monitor = monitor + + def _read(self, buflen): + current_pos = self.fptr.tell() + if current_pos != self.pos: + self.fptr.seek(self.pos) + + if self.end and self.pos + buflen > self.end: + buflen = self.end - self.pos + buf = self.fptr.read(buflen) + if not buf: + return True, 0 + self.writer.write(buf, self.pos) + + self.pos = self.pos + len(buf) + if self.end and self.pos >= self.end: + return True, len(buf) + if len(buf) < buflen: + return True, len(buf) + return False, len(buf) + + def process(self): + if isinstance(self.writer, FileWriteConnection): + # TODO here we can do a system copy optimization + pass + + try: + while not self.monitor.is_done() and not self.eof: + self.eof, read_len = self._read(self.buflen) + finally: + self.fptr.close() + + +class FileWriteConnection(base.BaseWriteConnection): + + def __init__(self, path, checkpointer=None, **kwvals): + self.count = 0 + self.persist = checkpointer + try: + self.fptr = open(path, 'wb') + except IOError, ioe: + raise exceptions.StaccatoProtocolConnectionException( + ioe.message) + + def write(self, buffer, offset): + self.fptr.seek(offset, 0) + rc = self.fptr.write(buffer) + if self.persist: + self.persist.update(offset, offset + len(buffer)) + self.count = self.count + 1 + if self.count > 10: + self.count = 0 + self.fptr.flush() + if self.persist: + self.persist.sync({}) + + return rc + + def close(self): + return self.fptr.close() diff --git a/staccato/protocols/interface.py b/staccato/protocols/interface.py new file mode 100644 index 0000000..0f4f835 --- /dev/null +++ b/staccato/protocols/interface.py @@ -0,0 +1,39 @@ +from staccato.common import utils + + +class BaseProtocolInterface(object): + + @utils.not_implemented_decorator + def get_reader(self, url_parts, writer, monitor, start=0, end=None, + **kwvals): + pass + + @utils.not_implemented_decorator + def get_writer(self, url_parts, checkpointer, **kwvals): + pass + + @utils.not_implemented_decorator + def new_write(self, dsturl_parts, dst_opts): + pass + + @utils.not_implemented_decorator + def new_read(self, srcurl_parts, src_opts): + pass + + +class BaseReadConnection(object): + + @utils.not_implemented_decorator + def process(self): + pass + + +class BaseWriteConnection(object): + + @utils.not_implemented_decorator + def write(self, buffer, offset): + pass + + @utils.not_implemented_decorator + def close(self): + pass diff --git a/staccato/scheduler/__init__.py b/staccato/scheduler/__init__.py new file mode 100644 index 0000000..7fcc62c --- /dev/null +++ b/staccato/scheduler/__init__.py @@ -0,0 +1 @@ +__author__ = 'jbresnah' diff --git a/staccato/scheduler/simple_thread.py b/staccato/scheduler/simple_thread.py new file mode 100644 index 0000000..d12bf3a --- /dev/null +++ b/staccato/scheduler/simple_thread.py @@ -0,0 +1,24 @@ + + +class SimpleCountSchedler(object): + + def __init__(self, db_obj, max_at_once=4): + self.max_at_once = max_at_once + self.db_obj = db_obj + self.running = 0 + + def _new_transfer(self, request): + self.running += 1 + # todo start the transfer + + def _transfer_complete(self): + self.running -= 1 + + def _check_for_transfers(self): + avail = self.max_at_once - self.running + xfer_request_ready = self.db_obj.get_all_ready(limit=avail) + for request in xfer_request_ready: + self._new_transfer(request) + + def poll(self): + self._check_for_transfers() diff --git a/staccato/tests/__init__.py b/staccato/tests/__init__.py new file mode 100644 index 0000000..7fcc62c --- /dev/null +++ b/staccato/tests/__init__.py @@ -0,0 +1 @@ +__author__ = 'jbresnah' diff --git a/staccato/tests/functional/__init__.py b/staccato/tests/functional/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/staccato/tests/functional/test_db.py b/staccato/tests/functional/test_db.py new file mode 100644 index 0000000..8b97992 --- /dev/null +++ b/staccato/tests/functional/test_db.py @@ -0,0 +1,57 @@ +import os +from staccato.tests import utils +from staccato.common import config + + +class TestDB(utils.TempFileCleanupBaseTest): + + def setUp(self): + super(TestDB, self).setUp() + + self.tmp_db = self.get_tempfile() + self.db_url = 'sqlite:///%s' % (self.tmp_db) + + conf_d = {'sql_connection': self.db_url, + 'protocol_policy': ''} + + self.conf_file = self.make_confile(conf_d) + self.conf = config.get_config_object( + args=[], + default_config_files=[self.conf_file]) + self.db = self.make_db(self.conf) + + def test_db_creation(self): + self.assertTrue(os.path.exists(self.tmp_db)) + + def test_db_new_xfer(self): + src = "src://url" + dst = "dst://url" + sm = "src.module" + dm = "dst.module" + xfer = self.db.get_new_xfer(src, dst, sm, dm) + self.assertEqual(src, xfer.srcurl) + self.assertEqual(dst, xfer.dsturl) + self.assertEqual(sm, xfer.src_module_name) + self.assertEqual(dm, xfer.dst_module_name) + + def test_db_xfer_lookup(self): + src = "src://url" + dst = "dst://url" + sm = "src.module" + dm = "dst.module" + + xfer1 = self.db.get_new_xfer(src, dst, sm, dm) + xfer2 = self.db.lookup_xfer_request_by_id(xfer1.id) + self.assertEqual(xfer1.id, xfer2.id) + + def test_db_xfer_update(self): + src = "src://url" + dst = "dst://url" + sm = "src.module" + dm = "dst.module" + + xfer1 = self.db.get_new_xfer(src, dst, sm, dm) + xfer1.next_ndx = 10 + self.db.save_db_obj(xfer1) + xfer2 = self.db.lookup_xfer_request_by_id(xfer1.id) + self.assertEqual(xfer2.next_ndx, 10) diff --git a/staccato/tests/functional/test_xfer.py b/staccato/tests/functional/test_xfer.py new file mode 100644 index 0000000..8c23662 --- /dev/null +++ b/staccato/tests/functional/test_xfer.py @@ -0,0 +1,68 @@ +import filecmp +import time + +import staccato.xfer.interface as xfer_iface +import staccato.xfer.constants as xfer_consts +import staccato.db as db +import staccato.xfer.constants as constants +from staccato.common import config +from staccato.tests import utils + + +class FakeStateMachine(object): + + def event_occurred(self, *args, **kwvals): + pass + + +class TestXfer(utils.TempFileCleanupBaseTest): + + def setUp(self): + super(TestXfer, self).setUp() + self.conf = config.get_config_object(args=[]) + self.tmp_db = self.get_tempfile() + self.db_url = 'sqlite:///%s' % (self.tmp_db) + + conf_d = {'sql_connection': self.db_url, + 'protocol_policy': ''} + + self.conf_file = self.make_confile(conf_d, utils.FILE_ONLY_PROTOCOL) + self.conf = config.get_config_object( + args=[], + default_config_files=[self.conf_file]) + + def test_file_xfer_basic(self): + dst_file = self.get_tempfile() + src_file = "/bin/bash" + src_url = "file://%s" % src_file + dst_url = "file://%s" % dst_file + + xfer = xfer_iface.xfer_new(self.conf, src_url, dst_url, + {}, {}, 0, None) + xfer_iface.xfer_start(self.conf, xfer.id) + + db_obj = db.StaccatoDB(self.conf) + while not xfer_consts.is_state_done_running(xfer.state): + time.sleep(0.01) + xfer = db_obj.lookup_xfer_request_by_id(xfer.id) + + self.assertTrue(filecmp.cmp(dst_file, src_file)) + self.assertTrue(xfer.state, constants.States.STATE_COMPLETE) + + def test_file_xfer_cancel(self): + dst_file = self.get_tempfile() + src_file = "/dev/zero" + src_url = "file://%s" % src_file + dst_url = "file://%s" % dst_file + + xfer = xfer_iface.xfer_new(self.conf, src_url, dst_url, + {}, {}, 0, None) + xfer_iface.xfer_start(self.conf, xfer.id) + xfer_iface.xfer_cancel(self.conf, xfer.id) + + db_obj = db.StaccatoDB(self.conf) + while not xfer_consts.is_state_done_running(xfer.state): + time.sleep(0.01) + xfer = db_obj.lookup_xfer_request_by_id(xfer.id) + + self.assertTrue(xfer.state, constants.States.STATE_CANCELED) diff --git a/staccato/tests/unit/test_config.py b/staccato/tests/unit/test_config.py new file mode 100644 index 0000000..bedbacf --- /dev/null +++ b/staccato/tests/unit/test_config.py @@ -0,0 +1,10 @@ +import testtools + +from staccato.common import config + + +class TestConfig(testtools.TestCase): + + def test_db_connection_default(self): + conf = config.get_config_object(args=[]) + self.assertEquals(conf.sql_connection, 'sqlite:///staccato.sqlite') diff --git a/staccato/tests/unit/test_protocol_loading.py b/staccato/tests/unit/test_protocol_loading.py new file mode 100644 index 0000000..7056851 --- /dev/null +++ b/staccato/tests/unit/test_protocol_loading.py @@ -0,0 +1,102 @@ +import urlparse +import testtools +from staccato.common import utils, exceptions +import staccato.protocols.file as file_protocol + + +class TestProtocolLoading(testtools.TestCase): + + def setUp(self): + super(TestProtocolLoading, self).setUp() + + def test_basic_load(self): + proto_name = "staccato.protocols.file.FileProtocol" + inst = utils.load_protocol_module(proto_name, {}) + self.assertTrue(isinstance(inst, file_protocol.FileProtocol)) + + def test_failed_load(self): + self.assertRaises(exceptions.StaccatoParameterError, + utils.load_protocol_module, + "notAModule", {}) + + def test_find_module_default(self): + url = "file://host.com:90//path/to/file" + url_parts = urlparse.urlparse(url) + module_path = "just.some.thing" + + lookup_dict = { + 'file': [{'module': module_path}] + } + res = utils.find_protocol_module_name(lookup_dict, url_parts) + self.assertEqual(res, module_path) + + def test_find_module_wildcards(self): + url = "file://host.com:90//path/to/file" + url_parts = urlparse.urlparse(url) + module_path = "just.some.thing" + + lookup_dict = { + 'file': [{'module': module_path, + 'netloc': '.*', + 'path': '.*'}] + } + res = utils.find_protocol_module_name(lookup_dict, url_parts) + self.assertEqual(res, module_path) + + def test_find_module_multiple_wildcards(self): + url = "file://host.com:90//path/to/file" + url_parts = urlparse.urlparse(url) + bad_module_path = "just.some.bad.thing" + good_module_path = "just.some.bad.thing" + + lookup_dict = { + 'file': [{'module': bad_module_path, + 'netloc': '.*', + 'path': '/sorry/.*'}, + {'module': good_module_path}] + } + res = utils.find_protocol_module_name(lookup_dict, url_parts) + self.assertEqual(res, good_module_path) + + def test_find_module_wildcards_middle(self): + url = "file://host.com:90//path/to/file" + url_parts = urlparse.urlparse(url) + module_path = "just.some.thing" + + lookup_dict = { + 'file': [{'module': module_path, + 'netloc': '.*host.com.*', + 'path': '.*'}] + } + res = utils.find_protocol_module_name(lookup_dict, url_parts) + self.assertEqual(res, module_path) + + def test_find_module_not_found(self): + url = "file://host.com:90//path/to/file" + url_parts = urlparse.urlparse(url) + module_path = "just.some.thing" + + lookup_dict = { + 'file': [{'module': module_path, + 'netloc': '.*', + 'path': '/secure/path/only.*'}] + } + self.assertRaises(exceptions.StaccatoParameterError, + utils.find_protocol_module_name, + lookup_dict, + url_parts) + + def test_find_no_url_scheme(self): + url = "file://host.com:90//path/to/file" + url_parts = urlparse.urlparse(url) + module_path = "just.some.thing" + + lookup_dict = { + 'junk': [{'module': module_path, + 'netloc': '.*', + 'path': '/secure/path/only.*'}] + } + self.assertRaises(exceptions.StaccatoParameterError, + utils.find_protocol_module_name, + lookup_dict, + url_parts) diff --git a/staccato/tests/unit/test_utils.py b/staccato/tests/unit/test_utils.py new file mode 100644 index 0000000..3653aa4 --- /dev/null +++ b/staccato/tests/unit/test_utils.py @@ -0,0 +1,68 @@ +import testtools + +import staccato.xfer.events as xfers + + +class FakeXferRequest(object): + next_ndx = 0 + + +class FakeDB(object): + + def save_db_obj(self, obj): + pass + + +class TestXferCheckpointerSingleSync(testtools.TestCase): + + def setUp(self): + super(TestXferCheckpointerSingleSync, self).setUp() + + self.fake_xfer = FakeXferRequest() + self.checker = xfers.XferCheckpointer(self.fake_xfer, {}, + FakeDB(), db_refresh_rate=0) + + def _run_blocks(self, blocks): + for start, end in blocks: + self.checker.update(start, end) + self.checker.sync({}) + + def test_non_continuous_zero(self): + blocks = [(10, 20), (30, 40)] + self._run_blocks(blocks) + self.assertEqual(self.fake_xfer.next_ndx, 0) + + def test_join_simple(self): + blocks = [(0, 20), (20, 40)] + self._run_blocks(blocks) + self.assertEqual(self.fake_xfer.next_ndx, 40) + + def test_non_continuous(self): + blocks = [(0, 5), (10, 20)] + self._run_blocks(blocks) + self.assertEqual(self.fake_xfer.next_ndx, 5) + + def test_join_single(self): + blocks = [(0, 20)] + self._run_blocks(blocks) + self.assertEqual(self.fake_xfer.next_ndx, 20) + + def test_join_overlap(self): + blocks = [(0, 20), (10, 30)] + self._run_blocks(blocks) + self.assertEqual(self.fake_xfer.next_ndx, 30) + + def test_join_included(self): + blocks = [(0, 20), (10, 15)] + self._run_blocks(blocks) + self.assertEqual(self.fake_xfer.next_ndx, 20) + + def test_join_large_later(self): + blocks = [(10, 20), (30, 40), (0, 100)] + self._run_blocks(blocks) + self.assertEqual(self.fake_xfer.next_ndx, 100) + + def test_join_out_of_order(self): + blocks = [(30, 40), (0, 10), (20, 30), (10, 25)] + self._run_blocks(blocks) + self.assertEqual(self.fake_xfer.next_ndx, 40) diff --git a/staccato/tests/utils.py b/staccato/tests/utils.py new file mode 100644 index 0000000..90b9bc2 --- /dev/null +++ b/staccato/tests/utils.py @@ -0,0 +1,57 @@ +import tempfile +import testtools +import staccato.db as db +import os +import json + +TEST_CONF = """ +[DEFAULT] + +sql_connection = %(sql_connection)s +db_auto_create = True +log_level = DEBUG +protocol_policy = %(protocol_policy)s +""" + +FILE_ONLY_PROTOCOL = { + "file": [{"module": "staccato.protocols.file.FileProtocol"}] +} + + +class TempFileCleanupBaseTest(testtools.TestCase): + + def setUp(self): + super(TempFileCleanupBaseTest, self).setUp() + self.files_to_delete = [] + + def make_db(self, conf): + return db.StaccatoDB(conf) + + def make_confile(self, d, protocol_policy=None): + conf_file = self.get_tempfile() + + if protocol_policy is not None: + protocol_policy_file = self.get_tempfile() + f = open(protocol_policy_file, 'w') + json.dump(protocol_policy, f) + f.close() + d.update({'protocol_policy': protocol_policy_file}) + + out_conf = TEST_CONF % d + fout = open(conf_file, 'w') + fout.write(out_conf) + fout.close() + + return conf_file + + def tearDown(self): + super(TempFileCleanupBaseTest, self).tearDown() + for f in self.files_to_delete: + try: + pass + os.remove(f) + except: + pass + + def get_tempfile(self): + return tempfile.mkstemp()[1] diff --git a/staccato/version.py b/staccato/version.py new file mode 100644 index 0000000..47628a8 --- /dev/null +++ b/staccato/version.py @@ -0,0 +1,20 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2012 OpenStack Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + + +from staccato.openstack.common import version as common_version + +version_info = common_version.VersionInfo('staccato') diff --git a/staccato/xfer/__init__.py b/staccato/xfer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/staccato/xfer/constants.py b/staccato/xfer/constants.py new file mode 100644 index 0000000..3a778d2 --- /dev/null +++ b/staccato/xfer/constants.py @@ -0,0 +1,28 @@ +class Events: + EVENT_NEW = "EVENT_NEW" + EVENT_START = "EVENT_START" + EVENT_STARTED = "EVENT_STARTED" + EVENT_ERROR = "EVENT_ERROR" + EVENT_COMPLETE = "EVENT_COMPLETE" + EVENT_CANCEL = "EVENT_CANCEL" + EVENT_DELETE = "EVENT_DELETE" + + +class States: + STATE_NEW = "STATE_NEW" + STATE_STARTING = "STATE_STARTING" + STATE_RUNNING = "STATE_RUNNING" + STATE_CANCELING = "STATE_CANCELING" + STATE_CANCELED = "STATE_CANCELED" + STATE_ERRORING = "STATE_ERRORING" + STATE_ERROR = "STATE_ERROR" + STATE_COMPLETE = "STATE_COMPLETE" + STATE_DELETED = "STATE_DELETED" + + +def is_state_done_running(state): + done_states = [States.STATE_CANCELED, + States.STATE_ERROR, + States.STATE_COMPLETE, + States.STATE_DELETED] + return state in done_states diff --git a/staccato/xfer/events.py b/staccato/xfer/events.py new file mode 100644 index 0000000..5fc2a6a --- /dev/null +++ b/staccato/xfer/events.py @@ -0,0 +1,143 @@ +""" +This file describes events that can happen on a request structure +""" +from staccato.common import state_machine +from staccato.xfer import constants +from staccato.xfer import executor + + +class XferStateMachine(state_machine.StateMachine): + + def _state_changed(self, current_state, event, new_state, **kwvals): + xfer_request = kwvals['xfer_request'] + db = kwvals['db'] + xfer_request.state = new_state + db.save_db_obj(xfer_request) + + def _get_current_state(self, **kwvals): + xfer_request = kwvals['xfer_request'] + db = kwvals['db'] + xfer_request = db.lookup_xfer_request_by_id(xfer_request.id) + return xfer_request.state + +g_my_states = XferStateMachine() + + +def state_noop_handler( + current_state, + event, + new_state, + conf, + db, + xfer_request, + **kwvals): + """ + This handler just allows for the DB change. + """ + + +def state_starting_handler( + current_state, + event, + new_state, + conf, + db, + xfer_request, + **kwvals): + g_my_states.event_occurred(constants.Events.EVENT_STARTED, + conf=conf, + xfer_request=xfer_request, + db=db) + + +def state_running_handler( + current_state, + event, + new_state, + conf, + db, + xfer_request, + **kwvals): + executor.SimpleThreadExecutor(xfer_request.id, conf, g_my_states) + + +def state_delete_handler( + current_state, + event, + new_state, + conf, + db, + xfer_request, + **kwvals): + db.delete_db_obj(xfer_request) + + +g_my_states.set_state_func(constants.States.STATE_NEW, + state_noop_handler) +g_my_states.set_state_func(constants.States.STATE_STARTING, + state_starting_handler) +g_my_states.set_state_func(constants.States.STATE_RUNNING, + state_running_handler) +g_my_states.set_state_func(constants.States.STATE_CANCELING, + state_noop_handler) +g_my_states.set_state_func(constants.States.STATE_CANCELED, + state_noop_handler) +g_my_states.set_state_func(constants.States.STATE_ERRORING, + state_noop_handler) +g_my_states.set_state_func(constants.States.STATE_ERROR, + state_noop_handler) +g_my_states.set_state_func(constants.States.STATE_COMPLETE, + state_noop_handler) +g_my_states.set_state_func(constants.States.STATE_DELETED, + state_delete_handler) + +# setup the state machine +g_my_states.set_mapping(constants.States.STATE_NEW, + constants.Events.EVENT_START, + constants.States.STATE_STARTING) +g_my_states.set_mapping(constants.States.STATE_NEW, + constants.Events.EVENT_CANCEL, + constants.States.STATE_CANCELED) +g_my_states.set_mapping(constants.States.STATE_NEW, + constants.Events.EVENT_START, + constants.States.STATE_STARTING) + +g_my_states.set_mapping(constants.States.STATE_CANCELED, + constants.Events.EVENT_DELETE, + constants.States.STATE_DELETED) + +g_my_states.set_mapping(constants.States.STATE_STARTING, + constants.Events.EVENT_STARTED, + constants.States.STATE_RUNNING) +g_my_states.set_mapping(constants.States.STATE_STARTING, + constants.Events.EVENT_CANCEL, + constants.States.STATE_CANCELING) +g_my_states.set_mapping(constants.States.STATE_STARTING, + constants.Events.EVENT_ERROR, + constants.States.STATE_ERRORING) + +g_my_states.set_mapping(constants.States.STATE_RUNNING, + constants.Events.EVENT_COMPLETE, + constants.States.STATE_COMPLETE) +g_my_states.set_mapping(constants.States.STATE_RUNNING, + constants.Events.EVENT_CANCEL, + constants.States.STATE_CANCELING) +g_my_states.set_mapping(constants.States.STATE_RUNNING, + constants.Events.EVENT_ERROR, + constants.States.STATE_ERRORING) + +g_my_states.set_mapping(constants.States.STATE_ERRORING, + constants.Events.EVENT_COMPLETE, + constants.States.STATE_ERROR) + +g_my_states.set_mapping(constants.States.STATE_COMPLETE, + constants.Events.EVENT_DELETE, + constants.States.STATE_DELETED) + +g_my_states.set_mapping(constants.States.STATE_ERROR, + constants.Events.EVENT_START, + constants.States.STATE_STARTING) + + +def _print_state_machine(): + g_my_states.mapping_to_digraph() diff --git a/staccato/xfer/executor.py b/staccato/xfer/executor.py new file mode 100644 index 0000000..0677e28 --- /dev/null +++ b/staccato/xfer/executor.py @@ -0,0 +1,62 @@ +import threading +import urlparse + +from staccato import db +from staccato.common import utils +from staccato.xfer import constants +import staccato.xfer.utils as xfer_utils + + +def do_transfer(CONF, xfer_id, state_machine): + """ + This function does a transfer. It will create its own DB. This should be + run in its own thread. + """ + db_con = db.StaccatoDB(CONF) + try: + request = db_con.lookup_xfer_request_by_id(xfer_id) + + checkpointer = xfer_utils.XferCheckpointer(request, {}, db_con) + monitor = xfer_utils.XferReadMonitor(db_con, request.id) + + src_module = utils.load_protocol_module(request.src_module_name, CONF) + dst_module = utils.load_protocol_module(request.dst_module_name, CONF) + + dsturl_parts = urlparse.urlparse(request.dsturl) + writer = dst_module.get_writer(dsturl_parts, + checkpointer=checkpointer) + + # it is up to the reader/writer to put on the bw limits + srcurl_parts = urlparse.urlparse(request.srcurl) + reader = src_module.get_reader(srcurl_parts, + writer, + monitor, + request.next_ndx, + request.end_ndx) + + reader.process() + except Exception, ex: + state_machine.event_occurred(constants.Events.EVENT_ERROR, + exception=ex, + conf=CONF, + xfer_request=request, + db=db_con) + raise + finally: + state_machine.event_occurred(constants.Events.EVENT_COMPLETE, + conf=CONF, + xfer_request=request, + db=db_con) + + +class SimpleThreadExecutor(threading.Thread): + + def __init__(self, xfer_id, conf, state_machine): + super(SimpleThreadExecutor, self).__init__() + self.conf = conf + self.xfer_id = xfer_id + self.state_machine = state_machine + self.start() + + def run(self): + do_transfer(self.conf, self.xfer_id, self.state_machine) diff --git a/staccato/xfer/interface.py b/staccato/xfer/interface.py new file mode 100644 index 0000000..320b8f0 --- /dev/null +++ b/staccato/xfer/interface.py @@ -0,0 +1,62 @@ +import urlparse + +from staccato.common import utils, config +from staccato import db +from staccato.xfer.constants import Events +from staccato.xfer import events + + +def xfer_new(CONF, srcurl, dsturl, src_opts, dst_opts, start_ndx=0, + end_ndx=None): + srcurl_parts = urlparse.urlparse(srcurl) + dsturl_parts = urlparse.urlparse(dsturl) + + plugin_policy = config.get_protocol_policy(CONF) + src_module_name = utils.find_protocol_module_name(plugin_policy, + srcurl_parts) + dst_module_name = utils.find_protocol_module_name(plugin_policy, + dsturl_parts) + + src_module = utils.load_protocol_module(src_module_name, CONF) + dst_module = utils.load_protocol_module(dst_module_name, CONF) + + write_info = dst_module.new_write(dsturl_parts, dst_opts) + read_info = src_module.new_write(srcurl_parts, src_opts) + + db_con = db.StaccatoDB(CONF) + xfer = db_con.get_new_xfer(srcurl, + dsturl, + src_module_name, + dst_module_name, + start_ndx=start_ndx, + end_ndx=end_ndx, + read_info=read_info, + write_info=write_info) + return xfer + + +def xfer_start(conf, xfer_id): + db_con = db.StaccatoDB(conf) + request = db_con.lookup_xfer_request_by_id(xfer_id) + events.g_my_states.event_occurred(Events.EVENT_START, + conf=conf, + xfer_request=request, + db=db_con) + + +def xfer_cancel(conf, xfer_id): + db_con = db.StaccatoDB(conf) + request = db_con.lookup_xfer_request_by_id(xfer_id) + events.g_my_states.event_occurred(Events.EVENT_CANCEL, + conf=conf, + xfer_request=request, + db=db_con) + + +def xfer_delete(conf, xfer_id): + db_con = db.StaccatoDB(conf) + request = db_con.lookup_xfer_request_by_id(xfer_id) + events.g_my_states.event_occurred(Events.EVENT_DELETE, + conf=conf, + xfer_request=request, + db=db_con) diff --git a/staccato/xfer/utils.py b/staccato/xfer/utils.py new file mode 100644 index 0000000..91ef615 --- /dev/null +++ b/staccato/xfer/utils.py @@ -0,0 +1,127 @@ +import datetime + +import staccato.xfer.constants as constants +import staccato.common.exceptions as exceptions + + +def _merge_one(blocks): + + if not blocks: + return blocks.copy() + merge = True + while merge: + new = {} + merge = False + keys = sorted(blocks.keys()) + ndx = 0 + current_key = keys[ndx] + new[current_key] = blocks[current_key] + ndx = ndx + 1 + + while ndx < len(keys): + next_key = keys[ndx] + start_i = current_key + start_j = next_key + + end_i = blocks[start_i] + end_j = blocks[start_j] + + if end_i >= start_j: + merge = True + new[start_i] = max(end_i, end_j) + #ndx = ndx + 1 + else: + new[start_j] = end_j + current_key = next_key + + ndx = ndx + 1 + blocks = new + return new + + +class XferDBUpdater(object): + + def __init__(self, db_refresh_rate=5): + self.db_refresh_rate = db_refresh_rate + self._set_time() + + def _set_time(self): + self.next_time = datetime.datetime.now() +\ + datetime.timedelta(seconds=self.db_refresh_rate) + + def _check_db_ready(self): + n = datetime.datetime.now() + if n > self.next_time: + self._set_time() + self._do_db_operation() + + +class XferReadMonitor(XferDBUpdater): + + def __init__(self, db, xfer_id, db_refresh_rate=5): + super(XferReadMonitor, self).__init__(db_refresh_rate=db_refresh_rate) + self.db = db + self.done = True # TODO base this on xfer_request + self.xfer_id = xfer_id + self._do_db_operation() + + def _do_db_operation(self): + self.request = self.db.lookup_xfer_request_by_id(self.xfer_id) + + def is_done(self): + self._check_db_ready() + return constants.is_state_done_running(self.request.state) + + +class XferCheckpointer(XferDBUpdater): + """ + This class is used by protocol plugins to keep track of the progress of + a transfer. With each write the plugin can call update() and the blocks + will be tracked. When the protocol plugin has safely synced some data + to disk it can call sync(). Each call to sync may cause a write to the + database. + + This class will help write side connections keep track fo their workload + """ + def __init__(self, xfer_request, protocol_doc, db, db_refresh_rate=5): + """ + :param xfer_id: The transfer ID to be tracked. + :protocol doc: protocol specific information for tracking. This + should be a dict + """ + super(XferCheckpointer, self).__init__(db_refresh_rate=db_refresh_rate) + self.blocks = {} + self.db = db + self.protocol_doc = protocol_doc + self.xfer_request = xfer_request + self.update(0, 0) + + def update(self, block_start, block_end): + """ + :param block_start: the start of the block. + :param block_end: the end of the block. + """ + if block_end < block_start: + raise exceptions.StaccatoParameterError() + + if block_start in self.blocks: + self.blocks[block_start] = max(self.blocks[block_start], block_end) + else: + self.blocks[block_start] = block_end + + self.blocks = _merge_one(self.blocks) + + def _do_db_operation(self): + keys = sorted(self.blocks.keys()) + self.xfer_request.next_ndx = self.blocks[keys[0]] + self.db.save_db_obj(self.xfer_request) + + def sync(self, protocol_doc): + """ + :param protocol_doc: A update to the protocol specific information + sent in by the protocol module. This will be + merged with the last dict sent in. + """ + # take the first from the list and only sync that far. + self.protocol_doc.update(protocol_doc) + self._check_db_ready() diff --git a/tools/install_venv_common.py b/tools/install_venv_common.py new file mode 100644 index 0000000..914fcf1 --- /dev/null +++ b/tools/install_venv_common.py @@ -0,0 +1,221 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2013 OpenStack Foundation +# Copyright 2013 IBM Corp. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Provides methods needed by installation script for OpenStack development +virtual environments. + +Since this script is used to bootstrap a virtualenv from the system's Python +environment, it should be kept strictly compatible with Python 2.6. + +Synced in from openstack-common +""" + +from __future__ import print_function + +import optparse +import os +import subprocess +import sys + + +class InstallVenv(object): + + def __init__(self, root, venv, pip_requires, test_requires, py_version, + project): + self.root = root + self.venv = venv + self.pip_requires = pip_requires + self.test_requires = test_requires + self.py_version = py_version + self.project = project + + def die(self, message, *args): + print(message % args, file=sys.stderr) + sys.exit(1) + + def check_python_version(self): + if sys.version_info < (2, 6): + self.die("Need Python Version >= 2.6") + + def run_command_with_code(self, cmd, redirect_output=True, + check_exit_code=True): + """Runs a command in an out-of-process shell. + + Returns the output of that command. Working directory is self.root. + """ + if redirect_output: + stdout = subprocess.PIPE + else: + stdout = None + + proc = subprocess.Popen(cmd, cwd=self.root, stdout=stdout) + output = proc.communicate()[0] + if check_exit_code and proc.returncode != 0: + self.die('Command "%s" failed.\n%s', ' '.join(cmd), output) + return (output, proc.returncode) + + def run_command(self, cmd, redirect_output=True, check_exit_code=True): + return self.run_command_with_code(cmd, redirect_output, + check_exit_code)[0] + + def get_distro(self): + if (os.path.exists('/etc/fedora-release') or + os.path.exists('/etc/redhat-release')): + return Fedora(self.root, self.venv, self.pip_requires, + self.test_requires, self.py_version, self.project) + else: + return Distro(self.root, self.venv, self.pip_requires, + self.test_requires, self.py_version, self.project) + + def check_dependencies(self): + self.get_distro().install_virtualenv() + + def create_virtualenv(self, no_site_packages=True): + """Creates the virtual environment and installs PIP. + + Creates the virtual environment and installs PIP only into the + virtual environment. + """ + if not os.path.isdir(self.venv): + print('Creating venv...', end=' ') + if no_site_packages: + self.run_command(['virtualenv', '-q', '--no-site-packages', + self.venv]) + else: + self.run_command(['virtualenv', '-q', self.venv]) + print('done.') + print('Installing pip in venv...', end=' ') + if not self.run_command(['tools/with_venv.sh', 'easy_install', + 'pip>1.0']).strip(): + self.die("Failed to install pip.") + print('done.') + else: + print("venv already exists...") + pass + + def pip_install(self, *args): + self.run_command(['tools/with_venv.sh', + 'pip', 'install', '--upgrade'] + list(args), + redirect_output=False) + + def install_dependencies(self): + print('Installing dependencies with pip (this can take a while)...') + + # First things first, make sure our venv has the latest pip and + # distribute. + # NOTE: we keep pip at version 1.1 since the most recent version causes + # the .venv creation to fail. See: + # https://bugs.launchpad.net/nova/+bug/1047120 + self.pip_install('pip==1.1') + self.pip_install('distribute') + + # Install greenlet by hand - just listing it in the requires file does + # not + # get it installed in the right order + self.pip_install('greenlet') + + self.pip_install('-r', self.pip_requires) + self.pip_install('-r', self.test_requires) + + def post_process(self): + self.get_distro().post_process() + + def parse_args(self, argv): + """Parses command-line arguments.""" + parser = optparse.OptionParser() + parser.add_option('-n', '--no-site-packages', + action='store_true', + help="Do not inherit packages from global Python " + "install") + return parser.parse_args(argv[1:])[0] + + +class Distro(InstallVenv): + + def check_cmd(self, cmd): + return bool(self.run_command(['which', cmd], + check_exit_code=False).strip()) + + def install_virtualenv(self): + if self.check_cmd('virtualenv'): + return + + if self.check_cmd('easy_install'): + print('Installing virtualenv via easy_install...', end=' ') + if self.run_command(['easy_install', 'virtualenv']): + print('Succeeded') + return + else: + print('Failed') + + self.die('ERROR: virtualenv not found.\n\n%s development' + ' requires virtualenv, please install it using your' + ' favorite package management tool' % self.project) + + def post_process(self): + """Any distribution-specific post-processing gets done here. + + In particular, this is useful for applying patches to code inside + the venv. + """ + pass + + +class Fedora(Distro): + """This covers all Fedora-based distributions. + + Includes: Fedora, RHEL, CentOS, Scientific Linux + """ + + def check_pkg(self, pkg): + return self.run_command_with_code(['rpm', '-q', pkg], + check_exit_code=False)[1] == 0 + + def apply_patch(self, originalfile, patchfile): + self.run_command(['patch', '-N', originalfile, patchfile], + check_exit_code=False) + + def install_virtualenv(self): + if self.check_cmd('virtualenv'): + return + + if not self.check_pkg('python-virtualenv'): + self.die("Please install 'python-virtualenv'.") + + super(Fedora, self).install_virtualenv() + + def post_process(self): + """Workaround for a bug in eventlet. + + This currently affects RHEL6.1, but the fix can safely be + applied to all RHEL and Fedora distributions. + + This can be removed when the fix is applied upstream. + + Nova: https://bugs.launchpad.net/nova/+bug/884915 + Upstream: https://bitbucket.org/which_linden/eventlet/issue/89 + """ + + # Install "patch" program if it's not there + if not self.check_pkg('patch'): + self.die("Please install 'patch'.") + + # Apply the eventlet patch + self.apply_patch(os.path.join(self.venv, 'lib', self.py_version, + 'site-packages', + 'eventlet/green/subprocess.py'), + 'contrib/redhat-eventlet.patch')