diff --git a/.gitignore b/.gitignore index 0ea37f497..df2fc2749 100644 --- a/.gitignore +++ b/.gitignore @@ -28,6 +28,7 @@ pip-log.txt .tox nosetests.xml coverage.xml +flake8.log # Translations *.mo diff --git a/barbican/api/middleware/context.py b/barbican/api/middleware/context.py index 5e7127931..1ea0b457c 100644 --- a/barbican/api/middleware/context.py +++ b/barbican/api/middleware/context.py @@ -16,44 +16,49 @@ # under the License. import json - -from oslo.config import cfg import webob.exc -from barbican.api import policy -from barbican.api.middleware import Middleware +from oslo.config import cfg + +from barbican.api import middleware as mw +from barbican.common import utils import barbican.context -import barbican.openstack.common.log as logging -from barbican.openstack.common.gettextutils import _ +from barbican.openstack.common import gettextutils as u +from barbican.openstack.common import policy +LOG = utils.getLogger(__name__) +# TODO(jwood) Need to figure out why config is ignored in this module. context_opts = [ cfg.BoolOpt('owner_is_tenant', default=True, - help=_('When true, this option sets the owner of an image ' - 'to be the tenant. Otherwise, the owner of the ' - ' image will be the authenticated user issuing the ' - 'request.')), + help=u._('When true, this option sets the owner of an image ' + 'to be the tenant. Otherwise, the owner of the ' + ' image will be the authenticated user issuing the ' + 'request.')), cfg.StrOpt('admin_role', default='admin', - help=_('Role used to identify an authenticated user as ' - 'administrator.')), + help=u._('Role used to identify an authenticated user as ' + 'administrator.')), cfg.BoolOpt('allow_anonymous_access', default=False, - help=_('Allow unauthenticated users to access the API with ' - 'read-only privileges. This only applies when using ' - 'ContextMiddleware.')), + help=u._('Allow unauthenticated users to access the API with ' + 'read-only privileges. This only applies when using ' + 'ContextMiddleware.')), ] + CONF = cfg.CONF CONF.register_opts(context_opts) -LOG = logging.getLogger(__name__) + +# TODO(jwood): I'd like to get the utils.getLogger(...) working instead: +# LOG = logging.getLogger(__name__) -class BaseContextMiddleware(Middleware): +class BaseContextMiddleware(mw.Middleware): def process_response(self, resp): try: request_id = resp.request.context.request_id except AttributeError: - LOG.warn(_('Unable to retrieve request id from context')) + LOG.warn(u._('Unable to retrieve request id from context')) else: resp.headers['x-openstack-request-id'] = 'req-%s' % request_id return resp @@ -87,6 +92,9 @@ class ContextMiddleware(BaseContextMiddleware): else: raise webob.exc.HTTPUnauthorized() + # Ensure that down wind mw.Middleware/app can see this context. + req.environ['barbican.context'] = req.context + def _get_anonymous_context(self): kwargs = { 'user': None, @@ -115,7 +123,7 @@ class ContextMiddleware(BaseContextMiddleware): service_catalog = json.loads(catalog_header) except ValueError: raise webob.exc.HTTPInternalServerError( - _('Invalid service catalog json.')) + u._('Invalid service catalog json.')) kwargs = { 'user': req.headers.get('X-User-Id'), diff --git a/barbican/api/policy.py b/barbican/api/policy.py deleted file mode 100644 index fb076b024..000000000 --- a/barbican/api/policy.py +++ /dev/null @@ -1,156 +0,0 @@ -# vim: tabstop=4 shiftwidth=4 softtabstop=4 - -# Copyright (c) 2013 OpenStack, LLC. -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -"""Policy Engine For Barbican""" - -import json -import os.path - -from oslo.config import cfg - -from barbican.common import exception -import barbican.openstack.common.log as logging -from barbican.openstack.common import policy -from barbican.openstack.common.gettextutils import _ - -LOG = logging.getLogger(__name__) - -policy_opts = [ - cfg.StrOpt('policy_file', default='policy.json', - help=_('The location of the policy file.')), - cfg.StrOpt('policy_default_rule', default='default', - help=_('The default policy to use.')), -] - -CONF = cfg.CONF -CONF.register_opts(policy_opts) - - -DEFAULT_RULES = { - 'context_is_admin': policy.RoleCheck('role', 'admin'), - 'default': policy.TrueCheck(), - 'manage_re_key': policy.RoleCheck('role', 'admin'), -} - - -class Enforcer(object): - """Responsible for loading and enforcing rules""" - - def __init__(self): - self.default_rule = CONF.policy_default_rule - self.policy_path = self._find_policy_file() - self.policy_file_mtime = None - self.policy_file_contents = None - - def set_rules(self, rules): - """Create a new Rules object based on the provided dict of rules""" - rules_obj = policy.Rules(rules, self.default_rule) - policy.set_rules(rules_obj) - - def load_rules(self): - """Set the rules found in the json file on disk""" - if self.policy_path: - rules = self._read_policy_file() - rule_type = "" - else: - rules = DEFAULT_RULES - rule_type = "default " - - text_rules = dict((k, str(v)) for k, v in rules.items()) - LOG.debug(_('Loaded %(rule_type)spolicy rules: %(text_rules)s') % - locals()) - - self.set_rules(rules) - - @staticmethod - def _find_policy_file(): - """Locate the policy json data file""" - policy_file = CONF.find_file(CONF.policy_file) - if policy_file: - return policy_file - else: - LOG.warn(_('Unable to find policy file')) - return None - - def _read_policy_file(self): - """Read contents of the policy file - - This re-caches policy data if the file has been changed. - """ - mtime = os.path.getmtime(self.policy_path) - if not self.policy_file_contents or mtime != self.policy_file_mtime: - LOG.debug(_("Loading policy from %s") % self.policy_path) - with open(self.policy_path) as fap: - raw_contents = fap.read() - rules_dict = json.loads(raw_contents) - self.policy_file_contents = dict( - (k, policy.parse_rule(v)) - for k, v in rules_dict.items()) - self.policy_file_mtime = mtime - return self.policy_file_contents - - def _check(self, context, rule, target, *args, **kwargs): - """Verifies that the action is valid on the target in this context. - - :param context: Barbican request context - :param rule: String representing the action to be checked - :param object: Dictionary representing the object of the action. - :raises: `barbican.common.exception.Forbidden` - :returns: A non-False value if access is allowed. - """ - self.load_rules() - - credentials = { - 'roles': context.roles, - 'user': context.user, - 'tenant': context.tenant, - } - - return policy.check(rule, target, credentials, *args, **kwargs) - - def enforce(self, context, action, target): - """Verifies that the action is valid on the target in this context. - - :param context: Barbican request context - :param action: String representing the action to be checked - :param object: Dictionary representing the object of the action. - :raises: `barbican.common.exception.Forbidden` - :returns: A non-False value if access is allowed. - """ - LOG.debug("== policy.enforce satisfied ==") - return self._check(context, action, target, - exception.Forbidden, action=action) - - def check(self, context, action, target): - """Verifies that the action is valid on the target in this context. - - :param context: Barbican request context - :param action: String representing the action to be checked - :param object: Dictionary representing the object of the action. - :returns: A non-False value if access is allowed. - """ - return self._check(context, action, target) - - def check_is_admin(self, context): - """Check if the given context is associated with an admin role, - as defined via the 'context_is_admin' RBAC rule. - - :param context: request context - :returns: A non-False value if context role is admin. - """ - target = context.to_dict() - return self.check(context, 'context_is_admin', target) diff --git a/barbican/api/resources.py b/barbican/api/resources.py index c6f29e878..7e58baacc 100644 --- a/barbican/api/resources.py +++ b/barbican/api/resources.py @@ -21,19 +21,19 @@ import base64 import falcon from barbican import api -from barbican.api.policy import Enforcer -from barbican.common import resources as res -from barbican.common import utils, validators -from barbican.crypto.mime_types import augment_fields_with_content_types -from barbican.model import models -from barbican.model import repositories as repo from barbican.common import exception +from barbican.common import resources as res +from barbican.common import utils +from barbican.common import validators from barbican.crypto import extension_manager as em from barbican.crypto import mime_types -from barbican.openstack.common.gettextutils import _ +from barbican.model import models +from barbican.model import repositories as repo +from barbican.openstack.common import gettextutils as u from barbican.openstack.common import jsonutils as json -from barbican.queue import get_queue_api -from barbican.version import __version__ +from barbican.openstack.common import policy +from barbican import queue +from barbican import version LOG = utils.getLogger(__name__) @@ -41,95 +41,92 @@ LOG = utils.getLogger(__name__) def _general_failure(message, req, resp): """Throw exception a general processing failure.""" - api.abort(falcon.HTTP_500, _(message), req, resp) + api.abort(falcon.HTTP_500, message, req, resp) def _secret_not_found(req, resp): """Throw exception indicating secret not found.""" - api.abort(falcon.HTTP_404, _('Unable to locate secret.'), req, resp) + api.abort(falcon.HTTP_404, u._('Unable to locate secret.'), req, resp) def _order_not_found(req, resp): """Throw exception indicating order not found.""" - api.abort(falcon.HTTP_404, _('Unable to locate order.'), req, resp) + api.abort(falcon.HTTP_404, u._('Unable to locate order.'), req, resp) def _put_accept_incorrect(ct, req, resp): """Throw exception indicating request content-type is not supported.""" api.abort(falcon.HTTP_415, - _("Content-Type of '{0}' is not supported.").format(ct), + u._("Content-Type of '{0}' is not supported.").format(ct), req, resp) def _get_accept_not_supported(accept, req, resp): """Throw exception indicating request's accept is not supported.""" api.abort(falcon.HTTP_406, - _("Accept of '{0}' is not supported.").format(accept), + u._("Accept of '{0}' is not supported.").format(accept), req, resp) def _get_secret_info_not_found(mime_type, req, resp): """Throw exception indicating request's accept is not supported.""" api.abort(falcon.HTTP_404, - _("Secret information of type '{0}' not available for " - "decryption.").format(mime_type), + u._("Secret information of type '{0}' not available for " + "decryption.").format(mime_type), req, resp) def _secret_mime_type_not_supported(mt, req, resp): """Throw exception indicating secret mime-type is not supported.""" api.abort(falcon.HTTP_400, - _("Mime-type of '{0}' is not supported.").format(mt), req, resp) + u._("Mime-type of '{0}' " + "is not supported.").format(mt), req, resp) def _secret_data_too_large(req, resp): """Throw exception indicating plain-text was too big.""" api.abort(falcon.HTTP_413, - _("Could not add secret data as it was too large"), req, resp) + u._("Could not add secret data as it was too large"), req, resp) def _secret_plain_text_empty(req, resp): """Throw exception indicating empty plain-text was supplied.""" api.abort(falcon.HTTP_400, - _("Could not add secret with empty 'plain_text'"), req, resp) + u._("Could not add secret with empty 'plain_text'"), req, resp) def _failed_to_create_encrypted_datum(req, resp): - """ - Throw exception we could not create an EncryptedDatum - record for the secret. - """ + """Throw exception could not create EncryptedDatum record for secret.""" api.abort(falcon.HTTP_400, - _("Could not add secret data to Barbican."), req, resp) + u._("Could not add secret data to Barbican."), req, resp) def _failed_to_decrypt_data(req, resp): """Throw exception if failed to decrypt secret information.""" api.abort(falcon.HTTP_500, - _("Problem decrypting secret information."), req, resp) + u._("Problem decrypting secret information."), req, resp) def _secret_already_has_data(req, resp): - """ - Throw exception that the secret already has data. - """ + """Throw exception that the secret already has data.""" api.abort(falcon.HTTP_409, - _("Secret already has data, cannot modify it."), req, resp) + u._("Secret already has data, cannot modify it."), req, resp) def _secret_not_in_order(req, resp): - """ - Throw exception that secret information is not available in the order. - """ + """Throw exception that secret info is not available in the order.""" api.abort(falcon.HTTP_400, - _("Secret metadata expected but not received."), req, resp) + u._("Secret metadata expected but not received."), req, resp) def _secret_create_failed(req, resp): - """ - Throw exception that secret creation attempt failed. - """ - api.abort(falcon.HTTP_500, _("Unabled to create secret."), req, resp) + """Throw exception that secret creation attempt failed.""" + api.abort(falcon.HTTP_500, u._("Unabled to create secret."), req, resp) + + +def _authorization_failed(message, req, resp): + """Throw exception that authorization failed.""" + api.abort(falcon.HTTP_401, message, req, resp) def json_handler(obj): @@ -138,7 +135,7 @@ def json_handler(obj): def convert_secret_to_href(keystone_id, secret_id): - """Convert the tenant/secret IDs to a HATEOS-style href""" + """Convert the tenant/secret IDs to a HATEOS-style href.""" if secret_id: resource = 'secrets/' + secret_id else: @@ -147,7 +144,7 @@ def convert_secret_to_href(keystone_id, secret_id): def convert_order_to_href(keystone_id, order_id): - """Convert the tenant/order IDs to a HATEOS-style href""" + """Convert the tenant/order IDs to a HATEOS-style href.""" if order_id: resource = 'orders/' + order_id else: @@ -156,7 +153,7 @@ def convert_order_to_href(keystone_id, order_id): def convert_to_hrefs(keystone_id, fields): - """Convert id's within a fields dict to HATEOS-style hrefs""" + """Convert id's within a fields dict to HATEOS-style hrefs.""" if 'secret_id' in fields: fields['secret_ref'] = convert_secret_to_href(keystone_id, fields['secret_id']) @@ -169,7 +166,8 @@ def convert_to_hrefs(keystone_id, fields): def convert_list_to_href(resources_name, keystone_id, offset, limit): - """ + """Supports pretty output of paged-list hrefs. + Convert the tenant ID and offset/limit info to a HATEOS-style href suitable for use in a list navigation paging interface. """ @@ -179,7 +177,8 @@ def convert_list_to_href(resources_name, keystone_id, offset, limit): def previous_href(resources_name, keystone_id, offset, limit): - """ + """Supports pretty output of previous-page hrefs. + Create a HATEOS-style 'previous' href suitable for use in a list navigation paging interface, assuming the provided values are the currently viewed page. @@ -189,7 +188,8 @@ def previous_href(resources_name, keystone_id, offset, limit): def next_href(resources_name, keystone_id, offset, limit): - """ + """Supports pretty output of next-page hrefs. + Create a HATEOS-style 'next' href suitable for use in a list navigation paging interface, assuming the provided values are the currently viewed page. @@ -213,7 +213,64 @@ def add_nav_hrefs(resources_name, keystone_id, offset, limit, return data -def handle_exceptions(operation_name=_('System')): +def is_json_request_accept(req): + """Test if http request 'accept' header configured for JSON response. + + :param req: HTTP request + :return: True if need to return JSON response. + """ + return not req.accept or req.accept == 'application/json' \ + or req.accept == '*/*' + + +def enforce_rbac(req, resp, action_name, keystone_id=None): + """Enforce RBAC based on 'request' information.""" + if action_name and 'barbican.context' in req.env: + + # Prepare credentials information. + ctx = req.env['barbican.context'] # Placed here by context.py + # middleware + credentials = { + 'roles': ctx.roles, + 'user': ctx.user, + 'tenant': ctx.tenant, + } + + # Verify keystone_id matches the tenant ID. + if keystone_id and keystone_id != ctx.tenant: + _authorization_failed(u._("URI tenant does not match " + "authenticated tenant."), req, resp) + + # Enforce special case: secret GET decryption + if 'secret:get' == action_name and not is_json_request_accept(req): + action_name = 'secret:decrypt' # Override to perform special rules + + # Enforce access controls. + ctx.policy_enforcer.enforce(action_name, {}, credentials, + do_raise=True) + + +def handle_rbac(action_name='default'): + """ + Decorator that handles RBAC enforcement on behalf of REST verb methods. + """ + + def rbac_decorator(fn): + def enforcer(inst, req, resp, *args, **kwargs): + + # Enforce RBAC rules. + enforce_rbac(req, resp, action_name, + keystone_id=kwargs.get('keystone_id')) + + # Execute guarded method now. + fn(inst, req, resp, *args, **kwargs) + + return enforcer + + return rbac_decorator + + +def handle_exceptions(operation_name=u._('System')): """ Handle general exceptions to avoid a response code of 0 back to clients. @@ -226,9 +283,15 @@ def handle_exceptions(operation_name=_('System')): except falcon.HTTPError as f: LOG.exception('Falcon error seen') raise f # Already converted to Falcon exception, just reraise + except policy.PolicyNotAuthorized: + message = u._('{0} attempt was not authorized - ' + 'please review your ' + 'user/tenant privileges').format(operation_name) + LOG.exception(message) + _authorization_failed(message, req, resp) except Exception: - message = _('{0} failure seen - please contact site ' - 'administrator').format(operation_name) + message = u._('{0} failure seen - please contact site ' + 'administrator').format(operation_name) LOG.exception(message) _general_failure(message, req, resp) @@ -251,14 +314,15 @@ class PerformanceResource(api.ApiResource): class VersionResource(api.ApiResource): """Returns service and build version information""" - def __init__(self, policy_enforcer=None): + def __init__(self): LOG.debug('=== Creating VersionResource ===') - self.policy = policy_enforcer or Enforcer() + @handle_exceptions(u._('Version retrieval')) + @handle_rbac('version:get') def on_get(self, req, resp): resp.status = falcon.HTTP_200 resp.body = json.dumps({'v1': 'current', - 'build': __version__}) + 'build': version.__version__}) class SecretsResource(api.ApiResource): @@ -266,19 +330,18 @@ class SecretsResource(api.ApiResource): def __init__(self, crypto_manager, tenant_repo=None, secret_repo=None, - tenant_secret_repo=None, datum_repo=None, kek_repo=None, - policy_enforcer=None): + tenant_secret_repo=None, datum_repo=None, kek_repo=None): LOG.debug('Creating SecretsResource') self.tenant_repo = tenant_repo or repo.TenantRepo() self.secret_repo = secret_repo or repo.SecretRepo() self.tenant_secret_repo = tenant_secret_repo or repo.TenantSecretRepo() self.datum_repo = datum_repo or repo.EncryptedDatumRepo() self.kek_repo = kek_repo or repo.KEKDatumRepo() - self.policy = policy_enforcer or Enforcer() self.crypto_manager = crypto_manager self.validator = validators.NewSecretValidator() - @handle_exceptions(_('Secret creation')) + @handle_exceptions(u._('Secret creation')) + @handle_rbac('secrets:post') def on_post(self, req, resp, keystone_id): LOG.debug('Start on_post for tenant-ID {0}:...'.format(keystone_id)) @@ -302,7 +365,8 @@ class SecretsResource(api.ApiResource): _secret_data_too_large(req, resp) except Exception: LOG.exception('Secret creation failed - unknown') - _general_failure('Secret creation failed - unknown', req, resp) + _general_failure(u._('Secret creation failed - unknown'), req, + resp) resp.status = falcon.HTTP_201 resp.set_header('Location', '/{0}/secrets/{1}'.format(keystone_id, @@ -311,7 +375,8 @@ class SecretsResource(api.ApiResource): LOG.debug('URI to secret is {0}'.format(url)) resp.body = json.dumps({'secret_ref': url}) - @handle_exceptions(_('Secret(s) retrieval')) + @handle_exceptions(u._('Secret(s) retrieval')) + @handle_rbac('secrets:get') def on_get(self, req, resp, keystone_id): LOG.debug('Start secrets on_get ' 'for tenant-ID {0}:'.format(keystone_id)) @@ -328,7 +393,8 @@ class SecretsResource(api.ApiResource): if not secrets: secrets_resp_overall = {'secrets': []} else: - secret_fields = lambda s: augment_fields_with_content_types(s) + secret_fields = lambda s: mime_types\ + .augment_fields_with_content_types(s) secrets_resp = [convert_to_hrefs(keystone_id, secret_fields(s)) for s in secrets] secrets_resp_overall = add_nav_hrefs('secrets', keystone_id, @@ -345,17 +411,16 @@ class SecretResource(api.ApiResource): def __init__(self, crypto_manager, tenant_repo=None, secret_repo=None, - tenant_secret_repo=None, datum_repo=None, kek_repo=None, - policy_enforcer=None): + tenant_secret_repo=None, datum_repo=None, kek_repo=None): self.crypto_manager = crypto_manager self.tenant_repo = tenant_repo or repo.TenantRepo() self.repo = secret_repo or repo.SecretRepo() self.tenant_secret_repo = tenant_secret_repo or repo.TenantSecretRepo() self.datum_repo = datum_repo or repo.EncryptedDatumRepo() self.kek_repo = kek_repo or repo.KEKDatumRepo() - self.policy = policy_enforcer or Enforcer() - @handle_exceptions(_('Secret retrieval')) + @handle_exceptions(u._('Secret retrieval')) + @handle_rbac('secret:get') def on_get(self, req, resp, keystone_id, secret_id): secret = self.repo.get(entity_id=secret_id, keystone_id=keystone_id, @@ -365,11 +430,11 @@ class SecretResource(api.ApiResource): resp.status = falcon.HTTP_200 - if not req.accept or req.accept == 'application/json' \ - or req.accept == '*/*': + if is_json_request_accept(req): # Metadata-only response, no decryption necessary. resp.set_header('Content-Type', 'application/json') - secret_fields = augment_fields_with_content_types(secret) + secret_fields = mime_types.augment_fields_with_content_types( + secret) resp.body = json.dumps(convert_to_hrefs(keystone_id, secret_fields), default=json_handler) @@ -412,7 +477,8 @@ class SecretResource(api.ApiResource): ' {0}'.format(str(encodings))) _get_accept_not_supported(str(encodings), req, resp) - @handle_exceptions(_('Secret update')) + @handle_exceptions(u._('Secret update')) + @handle_rbac('secret:put') def on_put(self, req, resp, keystone_id, secret_id): if not req.content_type or req.content_type == 'application/json': @@ -471,7 +537,8 @@ class SecretResource(api.ApiResource): resp.status = falcon.HTTP_200 - @handle_exceptions(_('Secret deletion')) + @handle_exceptions(u._('Secret deletion')) + @handle_rbac('secret:delete') def on_delete(self, req, resp, keystone_id, secret_id): try: @@ -488,16 +555,16 @@ class OrdersResource(api.ApiResource): """Handles Order requests for Secret creation""" def __init__(self, tenant_repo=None, order_repo=None, - queue_resource=None, policy_enforcer=None): + queue_resource=None): LOG.debug('Creating OrdersResource') self.tenant_repo = tenant_repo or repo.TenantRepo() self.order_repo = order_repo or repo.OrderRepo() - self.queue = queue_resource or get_queue_api() - self.policy = policy_enforcer or Enforcer() + self.queue = queue_resource or queue.get_queue_api() self.validator = validators.NewOrderValidator() - @handle_exceptions(_('Order creation')) + @handle_exceptions(u._('Order creation')) + @handle_rbac('orders:post') def on_post(self, req, resp, keystone_id): tenant = res.get_or_create_tenant(keystone_id, self.tenant_repo) @@ -533,7 +600,8 @@ class OrdersResource(api.ApiResource): url = convert_order_to_href(keystone_id, new_order.id) resp.body = json.dumps({'order_ref': url}) - @handle_exceptions(_('Order(s) retrieval')) + @handle_exceptions(u._('Order(s) retrieval')) + @handle_rbac('orders:get') def on_get(self, req, resp, keystone_id): LOG.debug('Start orders on_get ' 'for tenant-ID {0}:'.format(keystone_id)) @@ -564,11 +632,11 @@ class OrdersResource(api.ApiResource): class OrderResource(api.ApiResource): """Handles Order retrieval and deletion requests""" - def __init__(self, order_repo=None, policy_enforcer=None): + def __init__(self, order_repo=None): self.repo = order_repo or repo.OrderRepo() - self.policy = policy_enforcer or Enforcer() - @handle_exceptions(_('Order retrieval')) + @handle_exceptions(u._('Order retrieval')) + @handle_rbac('order:get') def on_get(self, req, resp, keystone_id, order_id): order = self.repo.get(entity_id=order_id, keystone_id=keystone_id, suppress_exception=True) @@ -580,7 +648,8 @@ class OrderResource(api.ApiResource): order.to_dict_fields()), default=json_handler) - @handle_exceptions(_('Order deletion')) + @handle_exceptions(u._('Order deletion')) + @handle_rbac('order:delete') def on_delete(self, req, resp, keystone_id, order_id): try: diff --git a/barbican/context.py b/barbican/context.py index f7332db52..8dce8e582 100644 --- a/barbican/context.py +++ b/barbican/context.py @@ -15,7 +15,7 @@ # License for the specific language governing permissions and limitations # under the License. -from barbican.api import policy +from barbican.openstack.common import policy from barbican.openstack.common import local from barbican.openstack.common import uuidutils @@ -35,7 +35,7 @@ class RequestContext(object): self.tenant = tenant self.roles = roles or [] self.read_only = read_only - self._show_deleted = show_deleted + # TODO(jwood): self._show_deleted = show_deleted # (mkbhanda) possibly domain could be owner # brings us to the key scope question self.owner_is_tenant = owner_is_tenant @@ -43,9 +43,10 @@ class RequestContext(object): self.service_catalog = service_catalog self.policy_enforcer = policy_enforcer or policy.Enforcer() self.is_admin = is_admin - if not self.is_admin: - self.is_admin = \ - self.policy_enforcer.check_is_admin(self) + # TODO(jwood): Is this needed? + # if not self.is_admin: + # self.is_admin = \ + # self.policy_enforcer.check_is_admin(self) if not hasattr(local.store, 'context'): self.update_store() @@ -64,9 +65,8 @@ class RequestContext(object): 'tenant': self.tenant, 'tenant_id': self.tenant, 'project_id': self.tenant, - - 'is_admin': self.is_admin, - 'read_deleted': self.show_deleted, + # TODO(jwood): 'is_admin': self.is_admin, + # TODO(jwood): 'read_deleted': self.show_deleted, 'roles': self.roles, 'auth_token': self.auth_tok, 'service_catalog': self.service_catalog, @@ -84,9 +84,10 @@ class RequestContext(object): """Return the owner to correlate with key.""" return self.tenant if self.owner_is_tenant else self.user - @property - def show_deleted(self): - """Admins can see deleted by default""" - if self._show_deleted or self.is_admin: - return True - return False +# TODO(jwood): +# @property +# def show_deleted(self): +# """Admins can see deleted by default""" +# if self._show_deleted or self.is_admin: +# return True +# return False diff --git a/barbican/openstack/common/context.py b/barbican/openstack/common/context.py index e9cfd73cc..74c6bc0c4 100644 --- a/barbican/openstack/common/context.py +++ b/barbican/openstack/common/context.py @@ -23,16 +23,18 @@ context or provide additional information in their specific WSGI pipeline. """ import itertools -import uuid + +from barbican.openstack.common import uuidutils def generate_request_id(): - return 'req-' + str(uuid.uuid4()) + return 'req-%s' % uuidutils.generate_uuid() class RequestContext(object): - """ + """Helper class to represent useful information about a request context. + Stores information about the security context under which the user accesses the system, as well as additional request information. """ @@ -59,7 +61,7 @@ class RequestContext(object): 'request_id': self.request_id} -def get_admin_context(show_deleted="no"): +def get_admin_context(show_deleted=False): context = RequestContext(None, tenant=None, is_admin=True, diff --git a/barbican/openstack/common/crypto/__init__.py b/barbican/openstack/common/crypto/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/barbican/openstack/common/crypto/utils.py b/barbican/openstack/common/crypto/utils.py new file mode 100644 index 000000000..a2cce4910 --- /dev/null +++ b/barbican/openstack/common/crypto/utils.py @@ -0,0 +1,179 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2013 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import base64 + +from Crypto.Hash import HMAC +from Crypto import Random + +from barbican.openstack.common.gettextutils import _ # noqa +from barbican.openstack.common import importutils + + +class CryptoutilsException(Exception): + """Generic Exception for Crypto utilities.""" + + message = _("An unknown error occurred in crypto utils.") + + +class CipherBlockLengthTooBig(CryptoutilsException): + """The block size is too big.""" + + def __init__(self, requested, permitted): + msg = _("Block size of %(given)d is too big, max = %(maximum)d") + message = msg % {'given': requested, 'maximum': permitted} + super(CryptoutilsException, self).__init__(message) + + +class HKDFOutputLengthTooLong(CryptoutilsException): + """The amount of Key Material asked is too much.""" + + def __init__(self, requested, permitted): + msg = _("Length of %(given)d is too long, max = %(maximum)d") + message = msg % {'given': requested, 'maximum': permitted} + super(CryptoutilsException, self).__init__(message) + + +class HKDF(object): + """An HMAC-based Key Derivation Function implementation (RFC5869) + + This class creates an object that allows to use HKDF to derive keys. + """ + + def __init__(self, hashtype='SHA256'): + self.hashfn = importutils.import_module('Crypto.Hash.' + hashtype) + self.max_okm_length = 255 * self.hashfn.digest_size + + def extract(self, ikm, salt=None): + """An extract function that can be used to derive a robust key given + weak Input Key Material (IKM) which could be a password. + Returns a pseudorandom key (of HashLen octets) + + :param ikm: input keying material (ex a password) + :param salt: optional salt value (a non-secret random value) + """ + if salt is None: + salt = '\x00' * self.hashfn.digest_size + + return HMAC.new(salt, ikm, self.hashfn).digest() + + def expand(self, prk, info, length): + """An expand function that will return arbitrary length output that can + be used as keys. + Returns a buffer usable as key material. + + :param prk: a pseudorandom key of at least HashLen octets + :param info: optional string (can be a zero-length string) + :param length: length of output keying material (<= 255 * HashLen) + """ + if length > self.max_okm_length: + raise HKDFOutputLengthTooLong(length, self.max_okm_length) + + N = (length + self.hashfn.digest_size - 1) / self.hashfn.digest_size + + okm = "" + tmp = "" + for block in range(1, N + 1): + tmp = HMAC.new(prk, tmp + info + chr(block), self.hashfn).digest() + okm += tmp + + return okm[:length] + + +MAX_CB_SIZE = 256 + + +class SymmetricCrypto(object): + """Symmetric Key Crypto object. + + This class creates a Symmetric Key Crypto object that can be used + to encrypt, decrypt, or sign arbitrary data. + + :param enctype: Encryption Cipher name (default: AES) + :param hashtype: Hash/HMAC type name (default: SHA256) + """ + + def __init__(self, enctype='AES', hashtype='SHA256'): + self.cipher = importutils.import_module('Crypto.Cipher.' + enctype) + self.hashfn = importutils.import_module('Crypto.Hash.' + hashtype) + + def new_key(self, size): + return Random.new().read(size) + + def encrypt(self, key, msg, b64encode=True): + """Encrypt the provided msg and returns the cyphertext optionally + base64 encoded. + + Uses AES-128-CBC with a Random IV by default. + + The plaintext is padded to reach blocksize length. + The last byte of the block is the length of the padding. + The length of the padding does not include the length byte itself. + + :param key: The Encryption key. + :param msg: the plain text. + + :returns encblock: a block of encrypted data. + """ + iv = Random.new().read(self.cipher.block_size) + cipher = self.cipher.new(key, self.cipher.MODE_CBC, iv) + + # CBC mode requires a fixed block size. Append padding and length of + # padding. + if self.cipher.block_size > MAX_CB_SIZE: + raise CipherBlockLengthTooBig(self.cipher.block_size, MAX_CB_SIZE) + r = len(msg) % self.cipher.block_size + padlen = self.cipher.block_size - r - 1 + msg += '\x00' * padlen + msg += chr(padlen) + + enc = iv + cipher.encrypt(msg) + if b64encode: + enc = base64.b64encode(enc) + return enc + + def decrypt(self, key, msg, b64decode=True): + """Decrypts the provided ciphertext, optionally base 64 encoded, and + returns the plaintext message, after padding is removed. + + Uses AES-128-CBC with an IV by default. + + :param key: The Encryption key. + :param msg: the ciphetext, the first block is the IV + """ + if b64decode: + msg = base64.b64decode(msg) + iv = msg[:self.cipher.block_size] + cipher = self.cipher.new(key, self.cipher.MODE_CBC, iv) + + padded = cipher.decrypt(msg[self.cipher.block_size:]) + l = ord(padded[-1]) + 1 + plain = padded[:-l] + return plain + + def sign(self, key, msg, b64encode=True): + """Signs a message string and returns a base64 encoded signature. + + Uses HMAC-SHA-256 by default. + + :param key: The Signing key. + :param msg: the message to sign. + """ + h = HMAC.new(key, msg, self.hashfn) + out = h.digest() + if b64encode: + out = base64.b64encode(out) + return out diff --git a/barbican/openstack/common/eventlet_backdoor.py b/barbican/openstack/common/eventlet_backdoor.py index c0ad460fe..a5426b5b1 100644 --- a/barbican/openstack/common/eventlet_backdoor.py +++ b/barbican/openstack/common/eventlet_backdoor.py @@ -16,8 +16,13 @@ # License for the specific language governing permissions and limitations # under the License. +from __future__ import print_function + +import errno import gc +import os import pprint +import socket import sys import traceback @@ -26,18 +31,38 @@ import eventlet.backdoor import greenlet from oslo.config import cfg +from barbican.openstack.common.gettextutils import _ # noqa +from barbican.openstack.common import log as logging + +help_for_backdoor_port = ( + "Acceptable values are 0, , and :, where 0 results " + "in listening on a random tcp port number; results in listening " + "on the specified port number (and not enabling backdoor if that port " + "is in use); and : results in listening on the smallest " + "unused port number within the specified range of port numbers. The " + "chosen port is displayed in the service's log file.") eventlet_backdoor_opts = [ - cfg.IntOpt('backdoor_port', + cfg.StrOpt('backdoor_port', default=None, - help='port for eventlet backdoor to listen') + help="Enable eventlet backdoor. %s" % help_for_backdoor_port) ] CONF = cfg.CONF CONF.register_opts(eventlet_backdoor_opts) +LOG = logging.getLogger(__name__) + + +class EventletBackdoorConfigValueError(Exception): + def __init__(self, port_range, help_msg, ex): + msg = ('Invalid backdoor_port configuration %(range)s: %(ex)s. ' + '%(help)s' % + {'range': port_range, 'ex': ex, 'help': help_msg}) + super(EventletBackdoorConfigValueError, self).__init__(msg) + self.port_range = port_range def _dont_use_this(): - print "Don't use this, just disconnect instead" + print("Don't use this, just disconnect instead") def _find_objects(t): @@ -46,16 +71,43 @@ def _find_objects(t): def _print_greenthreads(): for i, gt in enumerate(_find_objects(greenlet.greenlet)): - print i, gt + print(i, gt) traceback.print_stack(gt.gr_frame) - print + print() def _print_nativethreads(): for threadId, stack in sys._current_frames().items(): - print threadId + print(threadId) traceback.print_stack(stack) - print + print() + + +def _parse_port_range(port_range): + if ':' not in port_range: + start, end = port_range, port_range + else: + start, end = port_range.split(':', 1) + try: + start, end = int(start), int(end) + if end < start: + raise ValueError + return start, end + except ValueError as ex: + raise EventletBackdoorConfigValueError(port_range, ex, + help_for_backdoor_port) + + +def _listen(host, start_port, end_port, listen_func): + try_port = start_port + while True: + try: + return listen_func((host, try_port)) + except socket.error as exc: + if (exc.errno != errno.EADDRINUSE or + try_port >= end_port): + raise + try_port += 1 def initialize_if_enabled(): @@ -70,6 +122,8 @@ def initialize_if_enabled(): if CONF.backdoor_port is None: return None + start_port, end_port = _parse_port_range(str(CONF.backdoor_port)) + # NOTE(johannes): The standard sys.displayhook will print the value of # the last expression and set it to __builtin__._, which overwrites # the __builtin__._ that gettext sets. Let's switch to using pprint @@ -80,8 +134,13 @@ def initialize_if_enabled(): pprint.pprint(val) sys.displayhook = displayhook - sock = eventlet.listen(('localhost', CONF.backdoor_port)) + sock = _listen('localhost', start_port, end_port, eventlet.listen) + + # In the case of backdoor port being zero, a port number is assigned by + # listen(). In any case, pull the port number out here. port = sock.getsockname()[1] + LOG.info(_('Eventlet backdoor listening on %(port)s for process %(pid)d') % + {'port': port, 'pid': os.getpid()}) eventlet.spawn_n(eventlet.backdoor.backdoor_server, sock, locals=backdoor_locals) return port diff --git a/barbican/openstack/common/excutils.py b/barbican/openstack/common/excutils.py index 808d9f3a4..101c86776 100644 --- a/barbican/openstack/common/excutils.py +++ b/barbican/openstack/common/excutils.py @@ -19,16 +19,15 @@ Exception related utilities. """ -import contextlib import logging import sys +import time import traceback -from barbican.openstack.common.gettextutils import _ +from barbican.openstack.common.gettextutils import _ # noqa -@contextlib.contextmanager -def save_and_reraise_exception(): +class save_and_reraise_exception(object): """Save current exception, run some code and then re-raise. In some cases the exception context can be cleared, resulting in None @@ -40,12 +39,61 @@ def save_and_reraise_exception(): To work around this, we save the exception state, run handler code, and then re-raise the original exception. If another exception occurs, the saved exception is logged and the new exception is re-raised. - """ - type_, value, tb = sys.exc_info() - try: - yield + + In some cases the caller may not want to re-raise the exception, and + for those circumstances this context provides a reraise flag that + can be used to suppress the exception. For example: + except Exception: - logging.error(_('Original exception being dropped: %s'), - traceback.format_exception(type_, value, tb)) - raise - raise type_, value, tb + with save_and_reraise_exception() as ctxt: + decide_if_need_reraise() + if not should_be_reraised: + ctxt.reraise = False + """ + def __init__(self): + self.reraise = True + + def __enter__(self): + self.type_, self.value, self.tb, = sys.exc_info() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is not None: + logging.error(_('Original exception being dropped: %s'), + traceback.format_exception(self.type_, + self.value, + self.tb)) + return False + if self.reraise: + raise self.type_, self.value, self.tb + + +def forever_retry_uncaught_exceptions(infunc): + def inner_func(*args, **kwargs): + last_log_time = 0 + last_exc_message = None + exc_count = 0 + while True: + try: + return infunc(*args, **kwargs) + except Exception as exc: + this_exc_message = unicode(exc) + if this_exc_message == last_exc_message: + exc_count += 1 + else: + exc_count = 1 + # Do not log any more frequently than once a minute unless + # the exception message changes + cur_time = int(time.time()) + if (cur_time - last_log_time > 60 or + this_exc_message != last_exc_message): + logging.exception( + _('Unexpected exception occurred %d time(s)... ' + 'retrying.') % exc_count) + last_log_time = cur_time + last_exc_message = this_exc_message + exc_count = 0 + # This should be a very rare event. In case it isn't, do + # a sleep. + time.sleep(1) + return inner_func diff --git a/barbican/openstack/common/fileutils.py b/barbican/openstack/common/fileutils.py new file mode 100644 index 000000000..1325d279c --- /dev/null +++ b/barbican/openstack/common/fileutils.py @@ -0,0 +1,110 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 OpenStack Foundation. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + + +import contextlib +import errno +import os + +from barbican.openstack.common import excutils +from barbican.openstack.common.gettextutils import _ # noqa +from barbican.openstack.common import log as logging + +LOG = logging.getLogger(__name__) + +_FILE_CACHE = {} + + +def ensure_tree(path): + """Create a directory (and any ancestor directories required) + + :param path: Directory to create + """ + try: + os.makedirs(path) + except OSError as exc: + if exc.errno == errno.EEXIST: + if not os.path.isdir(path): + raise + else: + raise + + +def read_cached_file(filename, force_reload=False): + """Read from a file if it has been modified. + + :param force_reload: Whether to reload the file. + :returns: A tuple with a boolean specifying if the data is fresh + or not. + """ + global _FILE_CACHE + + if force_reload and filename in _FILE_CACHE: + del _FILE_CACHE[filename] + + reloaded = False + mtime = os.path.getmtime(filename) + cache_info = _FILE_CACHE.setdefault(filename, {}) + + if not cache_info or mtime > cache_info.get('mtime', 0): + LOG.debug(_("Reloading cached file %s") % filename) + with open(filename) as fap: + cache_info['data'] = fap.read() + cache_info['mtime'] = mtime + reloaded = True + return (reloaded, cache_info['data']) + + +def delete_if_exists(path): + """Delete a file, but ignore file not found error. + + :param path: File to delete + """ + + try: + os.unlink(path) + except OSError as e: + if e.errno == errno.ENOENT: + return + else: + raise + + +@contextlib.contextmanager +def remove_path_on_error(path): + """Protect code that wants to operate on PATH atomically. + Any exception will cause PATH to be removed. + + :param path: File to work with + """ + try: + yield + except Exception: + with excutils.save_and_reraise_exception(): + delete_if_exists(path) + + +def file_open(*args, **kwargs): + """Open file + + see built-in file() documentation for more details + + Note: The reason this is kept in a separate module is to easily + be able to provide a stub module that doesn't alter system + state at all (for unit tests) + """ + return file(*args, **kwargs) diff --git a/barbican/openstack/common/gettextutils.py b/barbican/openstack/common/gettextutils.py index fecbeb7c0..0cfc660d5 100644 --- a/barbican/openstack/common/gettextutils.py +++ b/barbican/openstack/common/gettextutils.py @@ -1,6 +1,7 @@ # vim: tabstop=4 shiftwidth=4 softtabstop=4 # Copyright 2012 Red Hat, Inc. +# Copyright 2013 IBM Corp. # All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may @@ -23,18 +24,27 @@ Usual usage in an openstack.common module: from barbican.openstack.common.gettextutils import _ """ +import copy import gettext +import logging.handlers import os +import re +import UserString + +from babel import localedata +import six _localedir = os.environ.get('barbican'.upper() + '_LOCALEDIR') _t = gettext.translation('barbican', localedir=_localedir, fallback=True) +_AVAILABLE_LANGUAGES = [] + def _(msg): return _t.ugettext(msg) -def install(domain): +def install(domain, lazy=False): """Install a _() function using the given translation domain. Given a translation domain, install a _() function using gettext's @@ -44,7 +54,252 @@ def install(domain): overriding the default localedir (e.g. /usr/share/locale) using a translation-domain-specific environment variable (e.g. NOVA_LOCALEDIR). + + :param domain: the translation domain + :param lazy: indicates whether or not to install the lazy _() function. + The lazy _() introduces a way to do deferred translation + of messages by installing a _ that builds Message objects, + instead of strings, which can then be lazily translated into + any available locale. """ - gettext.install(domain, - localedir=os.environ.get(domain.upper() + '_LOCALEDIR'), - unicode=True) + if lazy: + # NOTE(mrodden): Lazy gettext functionality. + # + # The following introduces a deferred way to do translations on + # messages in OpenStack. We override the standard _() function + # and % (format string) operation to build Message objects that can + # later be translated when we have more information. + # + # Also included below is an example LocaleHandler that translates + # Messages to an associated locale, effectively allowing many logs, + # each with their own locale. + + def _lazy_gettext(msg): + """Create and return a Message object. + + Lazy gettext function for a given domain, it is a factory method + for a project/module to get a lazy gettext function for its own + translation domain (i.e. nova, glance, cinder, etc.) + + Message encapsulates a string so that we can translate + it later when needed. + """ + return Message(msg, domain) + + import __builtin__ + __builtin__.__dict__['_'] = _lazy_gettext + else: + localedir = '%s_LOCALEDIR' % domain.upper() + gettext.install(domain, + localedir=os.environ.get(localedir), + unicode=True) + + +class Message(UserString.UserString, object): + """Class used to encapsulate translatable messages.""" + def __init__(self, msg, domain): + # _msg is the gettext msgid and should never change + self._msg = msg + self._left_extra_msg = '' + self._right_extra_msg = '' + self.params = None + self.locale = None + self.domain = domain + + @property + def data(self): + # NOTE(mrodden): this should always resolve to a unicode string + # that best represents the state of the message currently + + localedir = os.environ.get(self.domain.upper() + '_LOCALEDIR') + if self.locale: + lang = gettext.translation(self.domain, + localedir=localedir, + languages=[self.locale], + fallback=True) + else: + # use system locale for translations + lang = gettext.translation(self.domain, + localedir=localedir, + fallback=True) + + full_msg = (self._left_extra_msg + + lang.ugettext(self._msg) + + self._right_extra_msg) + + if self.params is not None: + full_msg = full_msg % self.params + + return six.text_type(full_msg) + + def _save_dictionary_parameter(self, dict_param): + full_msg = self.data + # look for %(blah) fields in string; + # ignore %% and deal with the + # case where % is first character on the line + keys = re.findall('(?:[^%]|^)?%\((\w*)\)[a-z]', full_msg) + + # if we don't find any %(blah) blocks but have a %s + if not keys and re.findall('(?:[^%]|^)%[a-z]', full_msg): + # apparently the full dictionary is the parameter + params = copy.deepcopy(dict_param) + else: + params = {} + for key in keys: + try: + params[key] = copy.deepcopy(dict_param[key]) + except TypeError: + # cast uncopyable thing to unicode string + params[key] = unicode(dict_param[key]) + + return params + + def _save_parameters(self, other): + # we check for None later to see if + # we actually have parameters to inject, + # so encapsulate if our parameter is actually None + if other is None: + self.params = (other, ) + elif isinstance(other, dict): + self.params = self._save_dictionary_parameter(other) + else: + # fallback to casting to unicode, + # this will handle the problematic python code-like + # objects that cannot be deep-copied + try: + self.params = copy.deepcopy(other) + except TypeError: + self.params = unicode(other) + + return self + + # overrides to be more string-like + def __unicode__(self): + return self.data + + def __str__(self): + return self.data.encode('utf-8') + + def __getstate__(self): + to_copy = ['_msg', '_right_extra_msg', '_left_extra_msg', + 'domain', 'params', 'locale'] + new_dict = self.__dict__.fromkeys(to_copy) + for attr in to_copy: + new_dict[attr] = copy.deepcopy(self.__dict__[attr]) + + return new_dict + + def __setstate__(self, state): + for (k, v) in state.items(): + setattr(self, k, v) + + # operator overloads + def __add__(self, other): + copied = copy.deepcopy(self) + copied._right_extra_msg += other.__str__() + return copied + + def __radd__(self, other): + copied = copy.deepcopy(self) + copied._left_extra_msg += other.__str__() + return copied + + def __mod__(self, other): + # do a format string to catch and raise + # any possible KeyErrors from missing parameters + self.data % other + copied = copy.deepcopy(self) + return copied._save_parameters(other) + + def __mul__(self, other): + return self.data * other + + def __rmul__(self, other): + return other * self.data + + def __getitem__(self, key): + return self.data[key] + + def __getslice__(self, start, end): + return self.data.__getslice__(start, end) + + def __getattribute__(self, name): + # NOTE(mrodden): handle lossy operations that we can't deal with yet + # These override the UserString implementation, since UserString + # uses our __class__ attribute to try and build a new message + # after running the inner data string through the operation. + # At that point, we have lost the gettext message id and can just + # safely resolve to a string instead. + ops = ['capitalize', 'center', 'decode', 'encode', + 'expandtabs', 'ljust', 'lstrip', 'replace', 'rjust', 'rstrip', + 'strip', 'swapcase', 'title', 'translate', 'upper', 'zfill'] + if name in ops: + return getattr(self.data, name) + else: + return UserString.UserString.__getattribute__(self, name) + + +def get_available_languages(domain): + """Lists the available languages for the given translation domain. + + :param domain: the domain to get languages for + """ + if _AVAILABLE_LANGUAGES: + return _AVAILABLE_LANGUAGES + + localedir = '%s_LOCALEDIR' % domain.upper() + find = lambda x: gettext.find(domain, + localedir=os.environ.get(localedir), + languages=[x]) + + # NOTE(mrodden): en_US should always be available (and first in case + # order matters) since our in-line message strings are en_US + _AVAILABLE_LANGUAGES.append('en_US') + # NOTE(luisg): Babel <1.0 used a function called list(), which was + # renamed to locale_identifiers() in >=1.0, the requirements master list + # requires >=0.9.6, uncapped, so defensively work with both. We can remove + # this check when the master list updates to >=1.0, and all projects udpate + list_identifiers = (getattr(localedata, 'list', None) or + getattr(localedata, 'locale_identifiers')) + locale_identifiers = list_identifiers() + for i in locale_identifiers: + if find(i) is not None: + _AVAILABLE_LANGUAGES.append(i) + return _AVAILABLE_LANGUAGES + + +def get_localized_message(message, user_locale): + """Gets a localized version of the given message in the given locale.""" + if (isinstance(message, Message)): + if user_locale: + message.locale = user_locale + return unicode(message) + else: + return message + + +class LocaleHandler(logging.Handler): + """Handler that can have a locale associated to translate Messages. + + A quick example of how to utilize the Message class above. + LocaleHandler takes a locale and a target logging.Handler object + to forward LogRecord objects to after translating the internal Message. + """ + + def __init__(self, locale, target): + """Initialize a LocaleHandler + + :param locale: locale to use for translating messages + :param target: logging.Handler object to forward + LogRecord objects to after translation + """ + logging.Handler.__init__(self) + self.locale = locale + self.target = target + + def emit(self, record): + if isinstance(record.msg, Message): + # set the locale and resolve to a string + record.msg.locale = self.locale + + self.target.emit(record) diff --git a/barbican/openstack/common/importutils.py b/barbican/openstack/common/importutils.py index 3bd277f47..7a303f93f 100644 --- a/barbican/openstack/common/importutils.py +++ b/barbican/openstack/common/importutils.py @@ -24,7 +24,7 @@ import traceback def import_class(import_str): - """Returns a class from a string including module and class""" + """Returns a class from a string including module and class.""" mod_str, _sep, class_str = import_str.rpartition('.') try: __import__(mod_str) @@ -41,8 +41,9 @@ def import_object(import_str, *args, **kwargs): def import_object_ns(name_space, import_str, *args, **kwargs): - """ - Import a class and return an instance of it, first by trying + """Tries to import object from default namespace. + + Imports a class and return an instance of it, first by trying to find the class in a default namespace, then failing back to a full path if not found in the default namespace. """ diff --git a/barbican/openstack/common/jsonutils.py b/barbican/openstack/common/jsonutils.py index af3e03cd3..b93c5b5f2 100644 --- a/barbican/openstack/common/jsonutils.py +++ b/barbican/openstack/common/jsonutils.py @@ -41,6 +41,9 @@ import json import types import xmlrpclib +import netaddr +import six + from barbican.openstack.common import timeutils @@ -93,7 +96,7 @@ def to_primitive(value, convert_instances=False, convert_datetime=True, # value of itertools.count doesn't get caught by nasty_type_tests # and results in infinite loop when list(value) is called. if type(value) == itertools.count: - return unicode(value) + return six.text_type(value) # FIXME(vish): Workaround for LP bug 852095. Without this workaround, # tests that raise an exception in a mocked method that @@ -135,14 +138,16 @@ def to_primitive(value, convert_instances=False, convert_datetime=True, # Likely an instance of something. Watch for cycles. # Ignore class member vars. return recursive(value.__dict__, level=level + 1) + elif isinstance(value, netaddr.IPAddress): + return six.text_type(value) else: if any(test(value) for test in _nasty_type_tests): - return unicode(value) + return six.text_type(value) return value except TypeError: # Class objects are tricky since they may define something like # __iter__ defined but it isn't callable as list(). - return unicode(value) + return six.text_type(value) def dumps(value, default=to_primitive, **kwargs): diff --git a/barbican/openstack/common/local.py b/barbican/openstack/common/local.py index f1bfc824b..e82f17d0f 100644 --- a/barbican/openstack/common/local.py +++ b/barbican/openstack/common/local.py @@ -15,16 +15,15 @@ # License for the specific language governing permissions and limitations # under the License. -"""Greenthread local storage of variables using weak references""" +"""Local storage of variables using weak references""" +import threading import weakref -from eventlet import corolocal - -class WeakLocal(corolocal.local): +class WeakLocal(threading.local): def __getattribute__(self, attr): - rval = corolocal.local.__getattribute__(self, attr) + rval = super(WeakLocal, self).__getattribute__(attr) if rval: # NOTE(mikal): this bit is confusing. What is stored is a weak # reference, not the value itself. We therefore need to lookup @@ -34,7 +33,7 @@ class WeakLocal(corolocal.local): def __setattr__(self, attr, value): value = weakref.ref(value) - return corolocal.local.__setattr__(self, attr, value) + return super(WeakLocal, self).__setattr__(attr, value) # NOTE(mikal): the name "store" should be deprecated in the future @@ -45,4 +44,4 @@ store = WeakLocal() # "strong" store will hold a reference to the object so that it never falls out # of scope. weak_store = WeakLocal() -strong_store = corolocal.local +strong_store = threading.local() diff --git a/barbican/openstack/common/log.py b/barbican/openstack/common/log.py index 85778e544..9d22f95fd 100644 --- a/barbican/openstack/common/log.py +++ b/barbican/openstack/common/log.py @@ -29,27 +29,24 @@ It also allows setting of formatting information through conf. """ -import ConfigParser -import cStringIO import inspect import itertools import logging import logging.config import logging.handlers import os -import stat import sys import traceback from oslo.config import cfg +from six import moves -from barbican.openstack.common.gettextutils import _ +from barbican.openstack.common.gettextutils import _ # noqa +from barbican.openstack.common import importutils from barbican.openstack.common import jsonutils from barbican.openstack.common import local -from barbican.openstack.common import notifier -_DEFAULT_LOG_FORMAT = "%(asctime)s %(levelname)8s [%(name)s] %(message)s" _DEFAULT_LOG_DATE_FORMAT = "%Y-%m-%d %H:%M:%S" common_cli_opts = [ @@ -74,11 +71,14 @@ logging_cli_opts = [ 'documentation for details on logging configuration ' 'files.'), cfg.StrOpt('log-format', - default=_DEFAULT_LOG_FORMAT, + default=None, metavar='FORMAT', - help='A logging.Formatter log message format string which may ' + help='DEPRECATED. ' + 'A logging.Formatter log message format string which may ' 'use any of the available logging.LogRecord attributes. ' - 'Default: %(default)s'), + 'This option is deprecated. Please use ' + 'logging_context_format_string and ' + 'logging_default_format_string instead.'), cfg.StrOpt('log-date-format', default=_DEFAULT_LOG_DATE_FORMAT, metavar='DATE_FORMAT', @@ -104,10 +104,7 @@ logging_cli_opts = [ generic_log_opts = [ cfg.BoolOpt('use_stderr', default=True, - help='Log output to standard error'), - cfg.StrOpt('logfile_mode', - default='0644', - help='Default file mode used when creating log files'), + help='Log output to standard error') ] log_opts = [ @@ -211,7 +208,27 @@ def _get_log_file_path(binary=None): return '%s.log' % (os.path.join(logdir, binary),) -class ContextAdapter(logging.LoggerAdapter): +class BaseLoggerAdapter(logging.LoggerAdapter): + + def audit(self, msg, *args, **kwargs): + self.log(logging.AUDIT, msg, *args, **kwargs) + + +class LazyAdapter(BaseLoggerAdapter): + def __init__(self, name='unknown', version='unknown'): + self._logger = None + self.extra = {} + self.name = name + self.version = version + + @property + def logger(self): + if not self._logger: + self._logger = getLogger(self.name, self.version) + return self._logger + + +class ContextAdapter(BaseLoggerAdapter): warn = logging.LoggerAdapter.warning def __init__(self, logger, project_name, version_string): @@ -219,8 +236,9 @@ class ContextAdapter(logging.LoggerAdapter): self.project = project_name self.version = version_string - def audit(self, msg, *args, **kwargs): - self.log(logging.AUDIT, msg, *args, **kwargs) + @property + def handlers(self): + return self.logger.handlers def deprecated(self, msg, *args, **kwargs): stdmsg = _("Deprecated: %s") % msg @@ -304,17 +322,6 @@ class JSONFormatter(logging.Formatter): return jsonutils.dumps(message) -class PublishErrorsHandler(logging.Handler): - def emit(self, record): - if ('barbican.openstack.common.notifier.log_notifier' in - CONF.notification_driver): - return - notifier.api.notify(None, 'error.publisher', - 'error_notification', - notifier.api.ERROR, - dict(error=record.msg)) - - def _create_logging_excepthook(product_name): def logging_excepthook(type, value, tb): extra = {} @@ -340,7 +347,7 @@ class LogConfigError(Exception): def _load_log_config(log_config): try: logging.config.fileConfig(log_config) - except ConfigParser.Error, exc: + except moves.configparser.Error as exc: raise LogConfigError(log_config, str(exc)) @@ -399,11 +406,6 @@ def _setup_logging_from_conf(): filelog = logging.handlers.WatchedFileHandler(logpath) log_root.addHandler(filelog) - mode = int(CONF.logfile_mode, 8) - st = os.stat(logpath) - if st.st_mode != (stat.S_IFREG | mode): - os.chmod(logpath, mode) - if CONF.use_stderr: streamlog = ColorHandler() log_root.addHandler(streamlog) @@ -415,15 +417,22 @@ def _setup_logging_from_conf(): log_root.addHandler(streamlog) if CONF.publish_errors: - log_root.addHandler(PublishErrorsHandler(logging.ERROR)) + handler = importutils.import_object( + "barbican.openstack.common.log_handler.PublishErrorsHandler", + logging.ERROR) + log_root.addHandler(handler) + datefmt = CONF.log_date_format for handler in log_root.handlers: - datefmt = CONF.log_date_format + # NOTE(alaski): CONF.log_format overrides everything currently. This + # should be deprecated in favor of context aware formatting. if CONF.log_format: handler.setFormatter(logging.Formatter(fmt=CONF.log_format, datefmt=datefmt)) + log_root.info('Deprecated: log_format is now deprecated and will ' + 'be removed in the next release') else: - handler.setFormatter(LegacyFormatter(datefmt=datefmt)) + handler.setFormatter(ContextFormatter(datefmt=datefmt)) if CONF.debug: log_root.setLevel(logging.DEBUG) @@ -432,14 +441,11 @@ def _setup_logging_from_conf(): else: log_root.setLevel(logging.WARNING) - level = logging.NOTSET for pair in CONF.default_log_levels: mod, _sep, level_name = pair.partition('=') level = logging.getLevelName(level_name) logger = logging.getLogger(mod) logger.setLevel(level) - for handler in log_root.handlers: - logger.addHandler(handler) _loggers = {} @@ -452,6 +458,16 @@ def getLogger(name='unknown', version='unknown'): return _loggers[name] +def getLazyLogger(name='unknown', version='unknown'): + """Returns lazy logger. + + Creates a pass-through logger that does not create the real logger + until it is really needed and delegates all calls to the real logger + once it is created. + """ + return LazyAdapter(name, version) + + class WritableLogger(object): """A thin wrapper that responds to `write` and logs.""" @@ -463,7 +479,7 @@ class WritableLogger(object): self.logger.log(self.level, msg) -class LegacyFormatter(logging.Formatter): +class ContextFormatter(logging.Formatter): """A context.RequestContext aware formatter configured through flags. The flags used to set format strings are: logging_context_format_string @@ -504,7 +520,7 @@ class LegacyFormatter(logging.Formatter): if not record: return logging.Formatter.formatException(self, exc_info) - stringbuffer = cStringIO.StringIO() + stringbuffer = moves.StringIO() traceback.print_exception(exc_info[0], exc_info[1], exc_info[2], None, stringbuffer) lines = stringbuffer.getvalue().split('\n') diff --git a/barbican/openstack/common/loopingcall.py b/barbican/openstack/common/loopingcall.py index 1917e6326..4bbbe18f6 100644 --- a/barbican/openstack/common/loopingcall.py +++ b/barbican/openstack/common/loopingcall.py @@ -22,7 +22,7 @@ import sys from eventlet import event from eventlet import greenthread -from barbican.openstack.common.gettextutils import _ +from barbican.openstack.common.gettextutils import _ # noqa from barbican.openstack.common import log as logging from barbican.openstack.common import timeutils @@ -84,7 +84,7 @@ class FixedIntervalLoopingCall(LoopingCallBase): LOG.warn(_('task run outlasted interval by %s sec') % -delay) greenthread.sleep(delay if delay > 0 else 0) - except LoopingCallDone, e: + except LoopingCallDone as e: self.stop() done.send(e.retvalue) except Exception: @@ -131,7 +131,7 @@ class DynamicLoopingCall(LoopingCallBase): LOG.debug(_('Dynamic looping call sleeping for %.02f ' 'seconds'), idle) greenthread.sleep(idle) - except LoopingCallDone, e: + except LoopingCallDone as e: self.stop() done.send(e.retvalue) except Exception: diff --git a/barbican/openstack/common/network_utils.py b/barbican/openstack/common/network_utils.py index 5224e01aa..dbed1ceb4 100644 --- a/barbican/openstack/common/network_utils.py +++ b/barbican/openstack/common/network_utils.py @@ -19,14 +19,12 @@ Network-related utilities and helper functions. """ -import logging - -LOG = logging.getLogger(__name__) +import urlparse def parse_host_port(address, default_port=None): - """ - Interpret a string as a host:port pair. + """Interpret a string as a host:port pair. + An IPv6 address MUST be escaped if accompanied by a port, because otherwise ambiguity ensues: 2001:db8:85a3::8a2e:370:7334 means both [2001:db8:85a3::8a2e:370:7334] and @@ -66,3 +64,18 @@ def parse_host_port(address, default_port=None): port = default_port return (host, None if port is None else int(port)) + + +def urlsplit(url, scheme='', allow_fragments=True): + """Parse a URL using urlparse.urlsplit(), splitting query and fragments. + This function papers over Python issue9374 when needed. + + The parameters are the same as urlparse.urlsplit. + """ + scheme, netloc, path, query, fragment = urlparse.urlsplit( + url, scheme, allow_fragments) + if allow_fragments and '#' in path: + path, fragment = path.split('#', 1) + if '?' in path: + path, query = path.split('?', 1) + return urlparse.SplitResult(scheme, netloc, path, query, fragment) diff --git a/barbican/openstack/common/notifier/api.py b/barbican/openstack/common/notifier/api.py index 96963c148..495f5abdd 100644 --- a/barbican/openstack/common/notifier/api.py +++ b/barbican/openstack/common/notifier/api.py @@ -13,12 +13,13 @@ # License for the specific language governing permissions and limitations # under the License. +import socket import uuid from oslo.config import cfg from barbican.openstack.common import context -from barbican.openstack.common.gettextutils import _ +from barbican.openstack.common.gettextutils import _ # noqa from barbican.openstack.common import importutils from barbican.openstack.common import jsonutils from barbican.openstack.common import log as logging @@ -35,7 +36,7 @@ notifier_opts = [ default='INFO', help='Default notification level for outgoing notifications'), cfg.StrOpt('default_publisher_id', - default='$host', + default=None, help='Default publisher_id for outgoing notifications'), ] @@ -56,7 +57,7 @@ class BadPriorityException(Exception): def notify_decorator(name, fn): - """ decorator for notify which is used from utils.monkey_patch() + """Decorator for notify which is used from utils.monkey_patch(). :param name: name of the function :param function: - object of the function @@ -74,7 +75,7 @@ def notify_decorator(name, fn): ctxt = context.get_context_from_function_and_args(fn, args, kwarg) notify(ctxt, - CONF.default_publisher_id, + CONF.default_publisher_id or socket.gethostname(), name, CONF.default_notification_level, body) @@ -84,7 +85,10 @@ def notify_decorator(name, fn): def publisher_id(service, host=None): if not host: - host = CONF.host + try: + host = CONF.host + except AttributeError: + host = CONF.default_publisher_id or socket.gethostname() return "%s.%s" % (service, host) @@ -153,29 +157,16 @@ def _get_drivers(): if _drivers is None: _drivers = {} for notification_driver in CONF.notification_driver: - add_driver(notification_driver) - + try: + driver = importutils.import_module(notification_driver) + _drivers[notification_driver] = driver + except ImportError: + LOG.exception(_("Failed to load notifier %s. " + "These notifications will not be sent.") % + notification_driver) return _drivers.values() -def add_driver(notification_driver): - """Add a notification driver at runtime.""" - # Make sure the driver list is initialized. - _get_drivers() - if isinstance(notification_driver, basestring): - # Load and add - try: - driver = importutils.import_module(notification_driver) - _drivers[notification_driver] = driver - except ImportError: - LOG.exception(_("Failed to load notifier %s. " - "These notifications will not be sent.") % - notification_driver) - else: - # Driver is already loaded; just add the object. - _drivers[notification_driver] = notification_driver - - def _reset_drivers(): """Used by unit tests to reset the drivers.""" global _drivers diff --git a/barbican/openstack/common/notifier/log_notifier.py b/barbican/openstack/common/notifier/log_notifier.py index 9d33ef957..db6866335 100644 --- a/barbican/openstack/common/notifier/log_notifier.py +++ b/barbican/openstack/common/notifier/log_notifier.py @@ -24,7 +24,9 @@ CONF = cfg.CONF def notify(_context, message): """Notifies the recipient of the desired event given the model. - Log notifications using openstack's default logging system""" + + Log notifications using OpenStack's default logging system. + """ priority = message.get('priority', CONF.default_notification_level) diff --git a/barbican/openstack/common/notifier/no_op_notifier.py b/barbican/openstack/common/notifier/no_op_notifier.py index bc7a56ca7..13d946e36 100644 --- a/barbican/openstack/common/notifier/no_op_notifier.py +++ b/barbican/openstack/common/notifier/no_op_notifier.py @@ -15,5 +15,5 @@ def notify(_context, message): - """Notifies the recipient of the desired event given the model""" + """Notifies the recipient of the desired event given the model.""" pass diff --git a/barbican/openstack/common/notifier/rpc_notifier.py b/barbican/openstack/common/notifier/rpc_notifier.py index 0775bac46..a15229992 100644 --- a/barbican/openstack/common/notifier/rpc_notifier.py +++ b/barbican/openstack/common/notifier/rpc_notifier.py @@ -16,7 +16,7 @@ from oslo.config import cfg from barbican.openstack.common import context as req_context -from barbican.openstack.common.gettextutils import _ +from barbican.openstack.common.gettextutils import _ # noqa from barbican.openstack.common import log as logging from barbican.openstack.common import rpc @@ -24,14 +24,14 @@ LOG = logging.getLogger(__name__) notification_topic_opt = cfg.ListOpt( 'notification_topics', default=['notifications', ], - help='AMQP topic used for openstack notifications') + help='AMQP topic used for OpenStack notifications') CONF = cfg.CONF CONF.register_opt(notification_topic_opt) def notify(context, message): - """Sends a notification via RPC""" + """Sends a notification via RPC.""" if not context: context = req_context.get_admin_context() priority = message.get('priority', diff --git a/barbican/openstack/common/notifier/rpc_notifier2.py b/barbican/openstack/common/notifier/rpc_notifier2.py index 736fb6869..0035aca3b 100644 --- a/barbican/openstack/common/notifier/rpc_notifier2.py +++ b/barbican/openstack/common/notifier/rpc_notifier2.py @@ -18,7 +18,7 @@ from oslo.config import cfg from barbican.openstack.common import context as req_context -from barbican.openstack.common.gettextutils import _ +from barbican.openstack.common.gettextutils import _ # noqa from barbican.openstack.common import log as logging from barbican.openstack.common import rpc @@ -26,7 +26,7 @@ LOG = logging.getLogger(__name__) notification_topic_opt = cfg.ListOpt( 'topics', default=['notifications', ], - help='AMQP topic(s) used for openstack notifications') + help='AMQP topic(s) used for OpenStack notifications') opt_group = cfg.OptGroup(name='rpc_notifier2', title='Options for rpc_notifier2') @@ -37,7 +37,7 @@ CONF.register_opt(notification_topic_opt, opt_group) def notify(context, message): - """Sends a notification via RPC""" + """Sends a notification via RPC.""" if not context: context = req_context.get_admin_context() priority = message.get('priority', diff --git a/barbican/openstack/common/policy.py b/barbican/openstack/common/policy.py index bef488c3c..e2275c6f0 100644 --- a/barbican/openstack/common/policy.py +++ b/barbican/openstack/common/policy.py @@ -1,6 +1,6 @@ # vim: tabstop=4 shiftwidth=4 softtabstop=4 -# Copyright (c) 2012 OpenStack, LLC. +# Copyright (c) 2012 OpenStack Foundation. # All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may @@ -57,33 +57,48 @@ as it allows particular rules to be explicitly disabled. """ import abc -import logging import re import urllib - import urllib2 -from barbican.openstack.common.gettextutils import _ -from barbican.openstack.common import jsonutils +from oslo.config import cfg +import six +from barbican.openstack.common import fileutils +from barbican.openstack.common.gettextutils import _ # noqa +from barbican.openstack.common import jsonutils +from barbican.openstack.common import log as logging + +policy_opts = [ + cfg.StrOpt('policy_file', + default='policy.json', + help=_('JSON file containing policy')), + cfg.StrOpt('policy_default_rule', + default='default', + help=_('Rule enforced when requested rule is not found')), +] + +CONF = cfg.CONF +CONF.register_opts(policy_opts) LOG = logging.getLogger(__name__) - -_rules = None _checks = {} +class PolicyNotAuthorized(Exception): + + def __init__(self, rule): + msg = _("Policy doesn't allow %s to be performed.") % rule + super(PolicyNotAuthorized, self).__init__(msg) + + class Rules(dict): - """ - A store for rules. Handles the default_rule setting directly. - """ + """A store for rules. Handles the default_rule setting directly.""" @classmethod def load_json(cls, data, default_rule=None): - """ - Allow loading of JSON rule data. - """ + """Allow loading of JSON rule data.""" # Suck in the JSON data and parse the rules rules = dict((k, parse_rule(v)) for k, v in @@ -100,12 +115,18 @@ class Rules(dict): def __missing__(self, key): """Implements the default rule handling.""" + if isinstance(self.default_rule, dict): + raise KeyError(key) + # If the default rule isn't actually defined, do something # reasonably intelligent if not self.default_rule or self.default_rule not in self: raise KeyError(key) - return self[self.default_rule] + if isinstance(self.default_rule, BaseCheck): + return self.default_rule + elif isinstance(self.default_rule, six.string_types): + return self[self.default_rule] def __str__(self): """Dumps a string representation of the rules.""" @@ -123,87 +144,157 @@ class Rules(dict): return jsonutils.dumps(out_rules, indent=4) -# Really have to figure out a way to deprecate this -def set_rules(rules): - """Set the rules in use for policy checks.""" +class Enforcer(object): + """Responsible for loading and enforcing rules. - global _rules - - _rules = rules - - -# Ditto -def reset(): - """Clear the rules used for policy checks.""" - - global _rules - - _rules = None - - -def check(rule, target, creds, exc=None, *args, **kwargs): - """ - Checks authorization of a rule against the target and credentials. - - :param rule: The rule to evaluate. - :param target: As much information about the object being operated - on as possible, as a dictionary. - :param creds: As much information about the user performing the - action as possible, as a dictionary. - :param exc: Class of the exception to raise if the check fails. - Any remaining arguments passed to check() (both - positional and keyword arguments) will be passed to - the exception class. If exc is not provided, returns - False. - - :return: Returns False if the policy does not allow the action and - exc is not provided; otherwise, returns a value that - evaluates to True. Note: for rules using the "case" - expression, this True value will be the specified string - from the expression. + :param policy_file: Custom policy file to use, if none is + specified, `CONF.policy_file` will be + used. + :param rules: Default dictionary / Rules to use. It will be + considered just in the first instantiation. If + `load_rules(True)`, `clear()` or `set_rules(True)` + is called this will be overwritten. + :param default_rule: Default rule to use, CONF.default_rule will + be used if none is specified. """ - # Allow the rule to be a Check tree - if isinstance(rule, BaseCheck): - result = rule(target, creds) - elif not _rules: - # No rules to reference means we're going to fail closed - result = False - else: - try: - # Evaluate the rule - result = _rules[rule](target, creds) - except KeyError: - # If the rule doesn't exist, fail closed + def __init__(self, policy_file=None, rules=None, default_rule=None): + self.rules = Rules(rules, default_rule) + self.default_rule = default_rule or CONF.policy_default_rule + + self.policy_path = None + self.policy_file = policy_file or CONF.policy_file + + def set_rules(self, rules, overwrite=True): + """Create a new Rules object based on the provided dict of rules. + + :param rules: New rules to use. It should be an instance of dict. + :param overwrite: Whether to overwrite current rules or update them + with the new rules. + """ + + if not isinstance(rules, dict): + raise TypeError(_("Rules must be an instance of dict or Rules, " + "got %s instead") % type(rules)) + + if overwrite: + self.rules = Rules(rules, self.default_rule) + else: + self.rules.update(rules) + + def clear(self): + """Clears Enforcer rules, policy's cache and policy's path.""" + self.set_rules({}) + self.default_rule = None + self.policy_path = None + + def load_rules(self, force_reload=False): + """Loads policy_path's rules. + + Policy file is cached and will be reloaded if modified. + + :param force_reload: Whether to overwrite current rules. + """ + + if not self.policy_path: + self.policy_path = self._get_policy_path() + + reloaded, data = fileutils.read_cached_file(self.policy_path, + force_reload=force_reload) + if reloaded or not self.rules: + rules = Rules.load_json(data, self.default_rule) + self.set_rules(rules) + LOG.debug(_("Rules successfully reloaded")) + + def _get_policy_path(self): + """Locate the policy json data file. + + :param policy_file: Custom policy file to locate. + + :returns: The policy path + + :raises: ConfigFilesNotFoundError if the file couldn't + be located. + """ + policy_file = CONF.find_file(self.policy_file) + + if policy_file: + return policy_file + + raise cfg.ConfigFilesNotFoundError(path=CONF.policy_file) + + def enforce(self, rule, target, creds, do_raise=False, + exc=None, *args, **kwargs): + """Checks authorization of a rule against the target and credentials. + + :param rule: A string or BaseCheck instance specifying the rule + to evaluate. + :param target: As much information about the object being operated + on as possible, as a dictionary. + :param creds: As much information about the user performing the + action as possible, as a dictionary. + :param do_raise: Whether to raise an exception or not if check + fails. + :param exc: Class of the exception to raise if the check fails. + Any remaining arguments passed to check() (both + positional and keyword arguments) will be passed to + the exception class. If not specified, PolicyNotAuthorized + will be used. + + :return: Returns False if the policy does not allow the action and + exc is not provided; otherwise, returns a value that + evaluates to True. Note: for rules using the "case" + expression, this True value will be the specified string + from the expression. + """ + + # NOTE(flaper87): Not logging target or creds to avoid + # potential security issues. + LOG.debug(_("Rule %s will be now enforced") % rule) + + self.load_rules() + + # Allow the rule to be a Check tree + if isinstance(rule, BaseCheck): + result = rule(target, creds, self) + elif not self.rules: + # No rules to reference means we're going to fail closed result = False + else: + try: + # Evaluate the rule + result = self.rules[rule](target, creds, self) + except KeyError: + LOG.debug(_("Rule [%s] doesn't exist") % rule) + # If the rule doesn't exist, fail closed + result = False - # If it is False, raise the exception if requested - if exc and result is False: - raise exc(*args, **kwargs) + # If it is False, raise the exception if requested + if do_raise and not result: + if exc: + raise exc(*args, **kwargs) - return result + raise PolicyNotAuthorized(rule) + + return result class BaseCheck(object): - """ - Abstract base class for Check classes. - """ + """Abstract base class for Check classes.""" __metaclass__ = abc.ABCMeta @abc.abstractmethod def __str__(self): - """ - Retrieve a string representation of the Check tree rooted at - this node. - """ + """String representation of the Check tree rooted at this node.""" pass @abc.abstractmethod - def __call__(self, target, cred): - """ - Perform the check. Returns False to reject the access or a + def __call__(self, target, cred, enforcer): + """Triggers if instance of the class is called. + + Performs the check. Returns False to reject the access or a true value (not necessary True) to accept the access. """ @@ -211,44 +302,39 @@ class BaseCheck(object): class FalseCheck(BaseCheck): - """ - A policy check that always returns False (disallow). - """ + """A policy check that always returns False (disallow).""" def __str__(self): """Return a string representation of this check.""" return "!" - def __call__(self, target, cred): + def __call__(self, target, cred, enforcer): """Check the policy.""" return False class TrueCheck(BaseCheck): - """ - A policy check that always returns True (allow). - """ + """A policy check that always returns True (allow).""" def __str__(self): """Return a string representation of this check.""" return "@" - def __call__(self, target, cred): + def __call__(self, target, cred, enforcer): """Check the policy.""" return True class Check(BaseCheck): - """ - A base class to allow for user-defined policy checks. - """ + """A base class to allow for user-defined policy checks.""" def __init__(self, kind, match): - """ + """Initiates Check instance. + :param kind: The kind of the check, i.e., the field before the ':'. :param match: The match of the check, i.e., the field after @@ -265,14 +351,13 @@ class Check(BaseCheck): class NotCheck(BaseCheck): - """ + """Implements the "not" logical operator. + A policy check that inverts the result of another policy check. - Implements the "not" operator. """ def __init__(self, rule): - """ - Initialize the 'not' check. + """Initialize the 'not' check. :param rule: The rule to negate. Must be a Check. """ @@ -284,24 +369,23 @@ class NotCheck(BaseCheck): return "not %s" % self.rule - def __call__(self, target, cred): - """ - Check the policy. Returns the logical inverse of the wrapped - check. + def __call__(self, target, cred, enforcer): + """Check the policy. + + Returns the logical inverse of the wrapped check. """ - return not self.rule(target, cred) + return not self.rule(target, cred, enforcer) class AndCheck(BaseCheck): - """ - A policy check that requires that a list of other checks all - return True. Implements the "and" operator. + """Implements the "and" logical operator. + + A policy check that requires that a list of other checks all return True. """ def __init__(self, rules): - """ - Initialize the 'and' check. + """Initialize the 'and' check. :param rules: A list of rules that will be tested. """ @@ -313,20 +397,21 @@ class AndCheck(BaseCheck): return "(%s)" % ' and '.join(str(r) for r in self.rules) - def __call__(self, target, cred): - """ - Check the policy. Requires that all rules accept in order to - return True. + def __call__(self, target, cred, enforcer): + """Check the policy. + + Requires that all rules accept in order to return True. """ for rule in self.rules: - if not rule(target, cred): + if not rule(target, cred, enforcer): return False return True def add_check(self, rule): - """ + """Adds rule to be tested. + Allows addition of another rule to the list of rules that will be tested. Returns the AndCheck object for convenience. """ @@ -336,14 +421,14 @@ class AndCheck(BaseCheck): class OrCheck(BaseCheck): - """ + """Implements the "or" operator. + A policy check that requires that at least one of a list of other - checks returns True. Implements the "or" operator. + checks returns True. """ def __init__(self, rules): - """ - Initialize the 'or' check. + """Initialize the 'or' check. :param rules: A list of rules that will be tested. """ @@ -355,20 +440,21 @@ class OrCheck(BaseCheck): return "(%s)" % ' or '.join(str(r) for r in self.rules) - def __call__(self, target, cred): - """ - Check the policy. Requires that at least one rule accept in - order to return True. + def __call__(self, target, cred, enforcer): + """Check the policy. + + Requires that at least one rule accept in order to return True. """ for rule in self.rules: - if rule(target, cred): + if rule(target, cred, enforcer): return True return False def add_check(self, rule): - """ + """Adds rule to be tested. + Allows addition of another rule to the list of rules that will be tested. Returns the OrCheck object for convenience. """ @@ -378,9 +464,6 @@ class OrCheck(BaseCheck): def _parse_check(rule): - """ - Parse a single base check rule into an appropriate Check object. - """ # Handle the special checks if rule == '!': @@ -391,7 +474,7 @@ def _parse_check(rule): try: kind, match = rule.split(':', 1) except Exception: - LOG.exception(_("Failed to understand rule %(rule)s") % locals()) + LOG.exception(_("Failed to understand rule %s") % rule) # If the rule is invalid, we'll fail closed return FalseCheck() @@ -406,9 +489,9 @@ def _parse_check(rule): def _parse_list_rule(rule): - """ - Provided for backwards compatibility. Translates the old - list-of-lists syntax into a tree of Check objects. + """Translates the old list-of-lists syntax into a tree of Check objects. + + Provided for backwards compatibility. """ # Empty rule defaults to True @@ -436,7 +519,7 @@ def _parse_list_rule(rule): or_list.append(AndCheck(and_list)) # If we have only one check, omit the "or" - if len(or_list) == 0: + if not or_list: return FalseCheck() elif len(or_list) == 1: return or_list[0] @@ -449,8 +532,7 @@ _tokenize_re = re.compile(r'\s+') def _parse_tokenize(rule): - """ - Tokenizer for the policy language. + """Tokenizer for the policy language. Most of the single-character tokens are specified in the _tokenize_re; however, parentheses need to be handled specially, @@ -499,16 +581,16 @@ def _parse_tokenize(rule): class ParseStateMeta(type): - """ - Metaclass for the ParseState class. Facilitates identifying - reduction methods. + """Metaclass for the ParseState class. + + Facilitates identifying reduction methods. """ def __new__(mcs, name, bases, cls_dict): - """ - Create the class. Injects the 'reducers' list, a list of - tuples matching token sequences to the names of the - corresponding reduction methods. + """Create the class. + + Injects the 'reducers' list, a list of tuples matching token sequences + to the names of the corresponding reduction methods. """ reducers = [] @@ -525,10 +607,10 @@ class ParseStateMeta(type): def reducer(*tokens): - """ - Decorator for reduction methods. Arguments are a sequence of - tokens, in order, which should trigger running this reduction - method. + """Decorator for reduction methods. + + Arguments are a sequence of tokens, in order, which should trigger running + this reduction method. """ def decorator(func): @@ -545,10 +627,10 @@ def reducer(*tokens): class ParseState(object): - """ - Implement the core of parsing the policy language. Uses a greedy - reduction algorithm to reduce a sequence of tokens into a single - terminal, the value of which will be the root of the Check tree. + """Implement the core of parsing the policy language. + + Uses a greedy reduction algorithm to reduce a sequence of tokens into + a single terminal, the value of which will be the root of the Check tree. Note: error reporting is rather lacking. The best we can get with this parser formulation is an overall "parse failed" error. @@ -565,11 +647,11 @@ class ParseState(object): self.values = [] def reduce(self): - """ - Perform a greedy reduction of the token stream. If a reducer - method matches, it will be executed, then the reduce() method - will be called recursively to search for any more possible - reductions. + """Perform a greedy reduction of the token stream. + + If a reducer method matches, it will be executed, then the + reduce() method will be called recursively to search for any more + possible reductions. """ for reduction, methname in self.reducers: @@ -599,9 +681,9 @@ class ParseState(object): @property def result(self): - """ - Obtain the final result of the parse. Raises ValueError if - the parse failed to reduce to a single result. + """Obtain the final result of the parse. + + Raises ValueError if the parse failed to reduce to a single result. """ if len(self.values) != 1: @@ -618,35 +700,31 @@ class ParseState(object): @reducer('check', 'and', 'check') def _make_and_expr(self, check1, _and, check2): - """ - Create an 'and_expr' from two checks joined by the 'and' - operator. + """Create an 'and_expr'. + + Join two checks by the 'and' operator. """ return [('and_expr', AndCheck([check1, check2]))] @reducer('and_expr', 'and', 'check') def _extend_and_expr(self, and_expr, _and, check): - """ - Extend an 'and_expr' by adding one more check. - """ + """Extend an 'and_expr' by adding one more check.""" return [('and_expr', and_expr.add_check(check))] @reducer('check', 'or', 'check') def _make_or_expr(self, check1, _or, check2): - """ - Create an 'or_expr' from two checks joined by the 'or' - operator. + """Create an 'or_expr'. + + Join two checks by the 'or' operator. """ return [('or_expr', OrCheck([check1, check2]))] @reducer('or_expr', 'or', 'check') def _extend_or_expr(self, or_expr, _or, check): - """ - Extend an 'or_expr' by adding one more check. - """ + """Extend an 'or_expr' by adding one more check.""" return [('or_expr', or_expr.add_check(check))] @@ -658,7 +736,8 @@ class ParseState(object): def _parse_text_rule(rule): - """ + """Parses policy to the tree. + Translates a policy written in the policy language into a tree of Check objects. """ @@ -676,16 +755,14 @@ def _parse_text_rule(rule): return state.result except ValueError: # Couldn't parse the rule - LOG.exception(_("Failed to understand rule %(rule)r") % locals()) + LOG.exception(_("Failed to understand rule %r") % rule) # Fail closed return FalseCheck() def parse_rule(rule): - """ - Parses a policy rule into a tree of Check objects. - """ + """Parses a policy rule into a tree of Check objects.""" # If the rule is a string, it's in the policy language if isinstance(rule, basestring): @@ -694,8 +771,7 @@ def parse_rule(rule): def register(name, func=None): - """ - Register a function or Check class as a policy check. + """Register a function or Check class as a policy check. :param name: Gives the name of the check type, e.g., 'rule', 'role', etc. If name is None, a default check type @@ -722,13 +798,11 @@ def register(name, func=None): @register("rule") class RuleCheck(Check): - def __call__(self, target, creds): - """ - Recursively checks credentials based on the defined rules. - """ + def __call__(self, target, creds, enforcer): + """Recursively checks credentials based on the defined rules.""" try: - return _rules[self.match](target, creds) + return enforcer.rules[self.match](target, creds, enforcer) except KeyError: # We don't have any matching rule; fail closed return False @@ -736,17 +810,15 @@ class RuleCheck(Check): @register("role") class RoleCheck(Check): - def __call__(self, target, creds): + def __call__(self, target, creds, enforcer): """Check that there is a matching role in the cred dict.""" - return self.match.lower() in [x.lower() for x in creds['roles']] @register('http') class HttpCheck(Check): - def __call__(self, target, creds): - """ - Check http: rules by calling to a remote server. + def __call__(self, target, creds, enforcer): + """Check http: rules by calling to a remote server. This example implementation simply verifies that the response is exactly 'True'. @@ -762,9 +834,8 @@ class HttpCheck(Check): @register(None) class GenericCheck(Check): - def __call__(self, target, creds): - """ - Check an individual match. + def __call__(self, target, creds, enforcer): + """Check an individual match. Matches look like: @@ -775,5 +846,5 @@ class GenericCheck(Check): # TODO(termie): do dict inspection via dot syntax match = self.match % target if self.kind in creds: - return match == unicode(creds[self.kind]) + return match == six.text_type(creds[self.kind]) return False diff --git a/barbican/openstack/common/rpc/__init__.py b/barbican/openstack/common/rpc/__init__.py index 9d05c1b6b..ea205aac2 100644 --- a/barbican/openstack/common/rpc/__init__.py +++ b/barbican/openstack/common/rpc/__init__.py @@ -26,13 +26,13 @@ For some wrappers that add message versioning to rpc, see: """ import inspect -import logging from oslo.config import cfg -from barbican.openstack.common.gettextutils import _ +from barbican.openstack.common.gettextutils import _ # noqa from barbican.openstack.common import importutils from barbican.openstack.common import local +from barbican.openstack.common import log as logging LOG = logging.getLogger(__name__) diff --git a/barbican/openstack/common/rpc/amqp.py b/barbican/openstack/common/rpc/amqp.py index 6a734b9f4..78f6e7a28 100644 --- a/barbican/openstack/common/rpc/amqp.py +++ b/barbican/openstack/common/rpc/amqp.py @@ -34,24 +34,24 @@ from eventlet import greenpool from eventlet import pools from eventlet import queue from eventlet import semaphore -# TODO(pekowsk): Remove import cfg and below comment in Havana. -# This import should no longer be needed when the amqp_rpc_single_reply_queue -# option is removed. from oslo.config import cfg from barbican.openstack.common import excutils -from barbican.openstack.common.gettextutils import _ +from barbican.openstack.common.gettextutils import _ # noqa from barbican.openstack.common import local from barbican.openstack.common import log as logging from barbican.openstack.common.rpc import common as rpc_common -# TODO(pekowski): Remove this option in Havana. amqp_opts = [ - cfg.BoolOpt('amqp_rpc_single_reply_queue', + cfg.BoolOpt('amqp_durable_queues', default=False, - help='Enable a fast single reply queue if using AMQP based ' - 'RPC like RabbitMQ or Qpid.'), + deprecated_name='rabbit_durable_queues', + deprecated_group='DEFAULT', + help='Use durable queues in amqp.'), + cfg.BoolOpt('amqp_auto_delete', + default=False, + help='Auto-delete queues in amqp.'), ] cfg.CONF.register_opts(amqp_opts) @@ -83,7 +83,7 @@ class Pool(pools.Pool): # is the above "while loop" gets all the cached connections from the # pool and closes them, but never returns them to the pool, a pool # leak. The unit tests hang waiting for an item to be returned to the - # pool. The unit tests get here via the teatDown() method. In the run + # pool. The unit tests get here via the tearDown() method. In the run # time code, it gets here via cleanup() and only appears in service.py # just before doing a sys.exit(), so cleanup() only happens once and # the leakage is not a problem. @@ -102,19 +102,19 @@ def get_connection_pool(conf, connection_cls): class ConnectionContext(rpc_common.Connection): - """The class that is actually returned to the caller of - create_connection(). This is essentially a wrapper around - Connection that supports 'with'. It can also return a new - Connection, or one from a pool. The function will also catch - when an instance of this class is to be deleted. With that - we can return Connections to the pool on exceptions and so - forth without making the caller be responsible for catching - them. If possible the function makes sure to return a - connection to the pool. + """The class that is actually returned to the create_connection() caller. + + This is essentially a wrapper around Connection that supports 'with'. + It can also return a new Connection, or one from a pool. + + The function will also catch when an instance of this class is to be + deleted. With that we can return Connections to the pool on exceptions + and so forth without making the caller be responsible for catching them. + If possible the function makes sure to return a connection to the pool. """ def __init__(self, conf, connection_pool, pooled=True, server_params=None): - """Create a new connection, or get one from the pool""" + """Create a new connection, or get one from the pool.""" self.connection = None self.conf = conf self.connection_pool = connection_pool @@ -127,7 +127,7 @@ class ConnectionContext(rpc_common.Connection): self.pooled = pooled def __enter__(self): - """When with ConnectionContext() is used, return self""" + """When with ConnectionContext() is used, return self.""" return self def _done(self): @@ -165,17 +165,19 @@ class ConnectionContext(rpc_common.Connection): def create_worker(self, topic, proxy, pool_name): self.connection.create_worker(topic, proxy, pool_name) - def join_consumer_pool(self, callback, pool_name, topic, exchange_name): + def join_consumer_pool(self, callback, pool_name, topic, exchange_name, + ack_on_error=True): self.connection.join_consumer_pool(callback, pool_name, topic, - exchange_name) + exchange_name, + ack_on_error) def consume_in_thread(self): self.connection.consume_in_thread() def __getattr__(self, key): - """Proxy all other calls to the Connection instance""" + """Proxy all other calls to the Connection instance.""" if self.connection: return getattr(self.connection, key) else: @@ -183,7 +185,7 @@ class ConnectionContext(rpc_common.Connection): class ReplyProxy(ConnectionContext): - """ Connection class for RPC replies / callbacks """ + """Connection class for RPC replies / callbacks.""" def __init__(self, conf, connection_pool): self._call_waiters = {} self._num_call_waiters = 0 @@ -197,8 +199,10 @@ class ReplyProxy(ConnectionContext): msg_id = message_data.pop('_msg_id', None) waiter = self._call_waiters.get(msg_id) if not waiter: - LOG.warn(_('no calling threads waiting for msg_id : %s' - ', message : %s') % (msg_id, message_data)) + LOG.warn(_('No calling threads waiting for msg_id : %(msg_id)s' + ', message : %(data)s'), {'msg_id': msg_id, + 'data': message_data}) + LOG.warn(_('_call_waiters: %s') % str(self._call_waiters)) else: waiter.put(message_data) @@ -231,12 +235,7 @@ def msg_reply(conf, msg_id, reply_q, connection_pool, reply=None, failure = rpc_common.serialize_remote_exception(failure, log_failure) - try: - msg = {'result': reply, 'failure': failure} - except TypeError: - msg = {'result': dict((k, repr(v)) - for k, v in reply.__dict__.iteritems()), - 'failure': failure} + msg = {'result': reply, 'failure': failure} if ending: msg['ending'] = True _add_unique_id(msg) @@ -251,7 +250,7 @@ def msg_reply(conf, msg_id, reply_q, connection_pool, reply=None, class RpcContext(rpc_common.CommonRpcContext): - """Context that supports replying to a rpc.call""" + """Context that supports replying to a rpc.call.""" def __init__(self, **kwargs): self.msg_id = kwargs.pop('msg_id', None) self.reply_q = kwargs.pop('reply_q', None) @@ -301,8 +300,13 @@ def pack_context(msg, context): for args at some point. """ - context_d = dict([('_context_%s' % key, value) - for (key, value) in context.to_dict().iteritems()]) + if isinstance(context, dict): + context_d = dict([('_context_%s' % key, value) + for (key, value) in context.iteritems()]) + else: + context_d = dict([('_context_%s' % key, value) + for (key, value) in context.to_dict().iteritems()]) + msg.update(context_d) @@ -338,8 +342,9 @@ def _add_unique_id(msg): class _ThreadPoolWithWait(object): - """Base class for a delayed invocation manager used by - the Connection class to start up green threads + """Base class for a delayed invocation manager. + + Used by the Connection class to start up green threads to handle incoming messages. """ @@ -354,12 +359,14 @@ class _ThreadPoolWithWait(object): class CallbackWrapper(_ThreadPoolWithWait): - """Wraps a straight callback to allow it to be invoked in a green - thread. + """Wraps a straight callback. + + Allows it to be invoked in a green thread. """ def __init__(self, conf, callback, connection_pool): - """ + """Initiates CallbackWrapper object. + :param conf: cfg.CONF instance :param callback: a callable (probably a function) :param connection_pool: connection pool as returned by @@ -408,15 +415,17 @@ class ProxyCallback(_ThreadPoolWithWait): ctxt = unpack_context(self.conf, message_data) method = message_data.get('method') args = message_data.get('args', {}) - version = message_data.get('version', None) + version = message_data.get('version') + namespace = message_data.get('namespace') if not method: LOG.warn(_('no method for message: %s') % message_data) ctxt.reply(_('No method for message: %s') % message_data, connection_pool=self.connection_pool) return - self.pool.spawn_n(self._process_data, ctxt, version, method, args) + self.pool.spawn_n(self._process_data, ctxt, version, method, + namespace, args) - def _process_data(self, ctxt, version, method, args): + def _process_data(self, ctxt, version, method, namespace, args): """Process a message in a new thread. If the proxy object we have has a dispatch method @@ -427,7 +436,8 @@ class ProxyCallback(_ThreadPoolWithWait): """ ctxt.update_store() try: - rval = self.proxy.dispatch(ctxt, version, method, **args) + rval = self.proxy.dispatch(ctxt, version, method, namespace, + **args) # Check if the result was a generator if inspect.isgenerator(rval): for x in rval: @@ -487,7 +497,7 @@ class MulticallProxyWaiter(object): return result def __iter__(self): - """Return a result until we get a reply with an 'ending" flag""" + """Return a result until we get a reply with an 'ending' flag.""" if self._done: raise StopIteration while True: @@ -509,61 +519,8 @@ class MulticallProxyWaiter(object): yield result -#TODO(pekowski): Remove MulticallWaiter() in Havana. -class MulticallWaiter(object): - def __init__(self, conf, connection, timeout): - self._connection = connection - self._iterator = connection.iterconsume(timeout=timeout or - conf.rpc_response_timeout) - self._result = None - self._done = False - self._got_ending = False - self._conf = conf - self.msg_id_cache = _MsgIdCache() - - def done(self): - if self._done: - return - self._done = True - self._iterator.close() - self._iterator = None - self._connection.close() - - def __call__(self, data): - """The consume() callback will call this. Store the result.""" - self.msg_id_cache.check_duplicate_message(data) - if data['failure']: - failure = data['failure'] - self._result = rpc_common.deserialize_remote_exception(self._conf, - failure) - - elif data.get('ending', False): - self._got_ending = True - else: - self._result = data['result'] - - def __iter__(self): - """Return a result until we get a 'None' response from consumer""" - if self._done: - raise StopIteration - while True: - try: - self._iterator.next() - except Exception: - with excutils.save_and_reraise_exception(): - self.done() - if self._got_ending: - self.done() - raise StopIteration - result = self._result - if isinstance(result, Exception): - self.done() - raise result - yield result - - def create_connection(conf, new, connection_pool): - """Create a connection""" + """Create a connection.""" return ConnectionContext(conf, connection_pool, pooled=not new) @@ -572,14 +529,6 @@ _reply_proxy_create_sem = semaphore.Semaphore() def multicall(conf, context, topic, msg, timeout, connection_pool): """Make a call that returns multiple times.""" - # TODO(pekowski): Remove all these comments in Havana. - # For amqp_rpc_single_reply_queue = False, - # Can't use 'with' for multicall, as it returns an iterator - # that will continue to use the connection. When it's done, - # connection.close() will get called which will put it back into - # the pool - # For amqp_rpc_single_reply_queue = True, - # The 'with' statement is mandatory for closing the connection LOG.debug(_('Making synchronous call on %s ...'), topic) msg_id = uuid.uuid4().hex msg.update({'_msg_id': msg_id}) @@ -587,21 +536,13 @@ def multicall(conf, context, topic, msg, timeout, connection_pool): _add_unique_id(msg) pack_context(msg, context) - # TODO(pekowski): Remove this flag and the code under the if clause - # in Havana. - if not conf.amqp_rpc_single_reply_queue: - conn = ConnectionContext(conf, connection_pool) - wait_msg = MulticallWaiter(conf, conn, timeout) - conn.declare_direct_consumer(msg_id, wait_msg) + with _reply_proxy_create_sem: + if not connection_pool.reply_proxy: + connection_pool.reply_proxy = ReplyProxy(conf, connection_pool) + msg.update({'_reply_q': connection_pool.reply_proxy.get_reply_q()}) + wait_msg = MulticallProxyWaiter(conf, msg_id, timeout, connection_pool) + with ConnectionContext(conf, connection_pool) as conn: conn.topic_send(topic, rpc_common.serialize_msg(msg), timeout) - else: - with _reply_proxy_create_sem: - if not connection_pool.reply_proxy: - connection_pool.reply_proxy = ReplyProxy(conf, connection_pool) - msg.update({'_reply_q': connection_pool.reply_proxy.get_reply_q()}) - wait_msg = MulticallProxyWaiter(conf, msg_id, timeout, connection_pool) - with ConnectionContext(conf, connection_pool) as conn: - conn.topic_send(topic, rpc_common.serialize_msg(msg), timeout) return wait_msg diff --git a/barbican/openstack/common/rpc/common.py b/barbican/openstack/common/rpc/common.py index c6d99a4dc..623b82013 100644 --- a/barbican/openstack/common/rpc/common.py +++ b/barbican/openstack/common/rpc/common.py @@ -22,8 +22,9 @@ import sys import traceback from oslo.config import cfg +import six -from barbican.openstack.common.gettextutils import _ +from barbican.openstack.common.gettextutils import _ # noqa from barbican.openstack.common import importutils from barbican.openstack.common import jsonutils from barbican.openstack.common import local @@ -69,16 +70,18 @@ _RPC_ENVELOPE_VERSION = '2.0' _VERSION_KEY = 'oslo.version' _MESSAGE_KEY = 'oslo.message' +_REMOTE_POSTFIX = '_Remote' + class RPCException(Exception): - message = _("An unknown RPC related exception occurred.") + msg_fmt = _("An unknown RPC related exception occurred.") def __init__(self, message=None, **kwargs): self.kwargs = kwargs if not message: try: - message = self.message % kwargs + message = self.msg_fmt % kwargs except Exception: # kwargs doesn't match a variable in the message @@ -87,7 +90,7 @@ class RPCException(Exception): for name, value in kwargs.iteritems(): LOG.error("%s: %s" % (name, value)) # at least get the core message out if something happened - message = self.message + message = self.msg_fmt super(RPCException, self).__init__(message) @@ -101,7 +104,7 @@ class RemoteError(RPCException): contains all of the relevant info. """ - message = _("Remote error: %(exc_type)s %(value)s\n%(traceback)s.") + msg_fmt = _("Remote error: %(exc_type)s %(value)s\n%(traceback)s.") def __init__(self, exc_type=None, value=None, traceback=None): self.exc_type = exc_type @@ -118,12 +121,13 @@ class Timeout(RPCException): This exception is raised if the rpc_response_timeout is reached while waiting for a response from the remote side. """ - message = _('Timeout while waiting on RPC response - ' + msg_fmt = _('Timeout while waiting on RPC response - ' 'topic: "%(topic)s", RPC method: "%(method)s" ' 'info: "%(info)s"') def __init__(self, info=None, topic=None, method=None): - """ + """Initiates Timeout object. + :param info: Extra info to convey to the user :param topic: The topic that the rpc call was sent to :param rpc_method_name: The name of the rpc method being @@ -140,23 +144,27 @@ class Timeout(RPCException): class DuplicateMessageError(RPCException): - message = _("Found duplicate message(%(msg_id)s). Skipping it.") + msg_fmt = _("Found duplicate message(%(msg_id)s). Skipping it.") class InvalidRPCConnectionReuse(RPCException): - message = _("Invalid reuse of an RPC connection.") + msg_fmt = _("Invalid reuse of an RPC connection.") class UnsupportedRpcVersion(RPCException): - message = _("Specified RPC version, %(version)s, not supported by " + msg_fmt = _("Specified RPC version, %(version)s, not supported by " "this endpoint.") class UnsupportedRpcEnvelopeVersion(RPCException): - message = _("Specified RPC envelope version, %(version)s, " + msg_fmt = _("Specified RPC envelope version, %(version)s, " "not supported by this endpoint.") +class RpcVersionCapError(RPCException): + msg_fmt = _("Specified RPC version cap, %(version_cap)s, is too low") + + class Connection(object): """A connection, returned by rpc.create_connection(). @@ -216,9 +224,9 @@ class Connection(object): raise NotImplementedError() def join_consumer_pool(self, callback, pool_name, topic, exchange_name): - """Register as a member of a group of consumers for a given topic from - the specified exchange. + """Register as a member of a group of consumers. + Uses given topic from the specified exchange. Exactly one member of a given pool will receive each message. A message will be delivered to multiple pools, if more than @@ -253,41 +261,20 @@ class Connection(object): def _safe_log(log_func, msg, msg_data): """Sanitizes the msg_data field before logging.""" - SANITIZE = {'set_admin_password': [('args', 'new_pass')], - 'run_instance': [('args', 'admin_password')], - 'route_message': [('args', 'message', 'args', 'method_info', - 'method_kwargs', 'password'), - ('args', 'message', 'args', 'method_info', - 'method_kwargs', 'admin_password')]} + SANITIZE = ['_context_auth_token', 'auth_token', 'new_pass'] - has_method = 'method' in msg_data and msg_data['method'] in SANITIZE - has_context_token = '_context_auth_token' in msg_data - has_token = 'auth_token' in msg_data + def _fix_passwords(d): + """Sanitizes the password fields in the dictionary.""" + for k in d.iterkeys(): + if k.lower().find('password') != -1: + d[k] = '' + elif k.lower() in SANITIZE: + d[k] = '' + elif isinstance(d[k], dict): + _fix_passwords(d[k]) + return d - if not any([has_method, has_context_token, has_token]): - return log_func(msg, msg_data) - - msg_data = copy.deepcopy(msg_data) - - if has_method: - for arg in SANITIZE.get(msg_data['method'], []): - try: - d = msg_data - for elem in arg[:-1]: - d = d[elem] - d[arg[-1]] = '' - except KeyError, e: - LOG.info(_('Failed to sanitize %(item)s. Key error %(err)s'), - {'item': arg, - 'err': e}) - - if has_context_token: - msg_data['_context_auth_token'] = '' - - if has_token: - msg_data['auth_token'] = '' - - return log_func(msg, msg_data) + return log_func(msg, _fix_passwords(copy.deepcopy(msg_data))) def serialize_remote_exception(failure_info, log_failure=True): @@ -299,17 +286,27 @@ def serialize_remote_exception(failure_info, log_failure=True): tb = traceback.format_exception(*failure_info) failure = failure_info[1] if log_failure: - LOG.error(_("Returning exception %s to caller"), unicode(failure)) + LOG.error(_("Returning exception %s to caller"), + six.text_type(failure)) LOG.error(tb) kwargs = {} if hasattr(failure, 'kwargs'): kwargs = failure.kwargs + # NOTE(matiu): With cells, it's possible to re-raise remote, remote + # exceptions. Lets turn it back into the original exception type. + cls_name = str(failure.__class__.__name__) + mod_name = str(failure.__class__.__module__) + if (cls_name.endswith(_REMOTE_POSTFIX) and + mod_name.endswith(_REMOTE_POSTFIX)): + cls_name = cls_name[:-len(_REMOTE_POSTFIX)] + mod_name = mod_name[:-len(_REMOTE_POSTFIX)] + data = { - 'class': str(failure.__class__.__name__), - 'module': str(failure.__class__.__module__), - 'message': unicode(failure), + 'class': cls_name, + 'module': mod_name, + 'message': six.text_type(failure), 'tb': tb, 'args': failure.args, 'kwargs': kwargs @@ -345,8 +342,9 @@ def deserialize_remote_exception(conf, data): ex_type = type(failure) str_override = lambda self: message - new_ex_type = type(ex_type.__name__ + "_Remote", (ex_type,), + new_ex_type = type(ex_type.__name__ + _REMOTE_POSTFIX, (ex_type,), {'__str__': str_override, '__unicode__': str_override}) + new_ex_type.__module__ = '%s%s' % (module, _REMOTE_POSTFIX) try: # NOTE(ameade): Dynamically create a new exception type and swap it in # as the new type for the exception. This only works on user defined @@ -408,10 +406,11 @@ class CommonRpcContext(object): class ClientException(Exception): - """This encapsulates some actual exception that is expected to be - hit by an RPC proxy object. Merely instantiating it records the - current exception information, which will be passed back to the - RPC client without exceptional logging.""" + """Encapsulates actual exception expected to be hit by a RPC proxy object. + + Merely instantiating it records the current exception information, which + will be passed back to the RPC client without exceptional logging. + """ def __init__(self): self._exc_info = sys.exc_info() @@ -419,7 +418,7 @@ class ClientException(Exception): def catch_client_exception(exceptions, func, *args, **kwargs): try: return func(*args, **kwargs) - except Exception, e: + except Exception as e: if type(e) in exceptions: raise ClientException() else: @@ -428,11 +427,13 @@ def catch_client_exception(exceptions, func, *args, **kwargs): def client_exceptions(*exceptions): """Decorator for manager methods that raise expected exceptions. + Marking a Manager method with this decorator allows the declaration of expected exceptions that the RPC layer should not consider fatal, and not log as if they were generated in a real error scenario. Note that this will cause listed exceptions to be wrapped in a - ClientException, which is used internally by the RPC layer.""" + ClientException, which is used internally by the RPC layer. + """ def outer(func): def inner(*args, **kwargs): return catch_client_exception(exceptions, func, *args, **kwargs) diff --git a/barbican/openstack/common/rpc/dispatcher.py b/barbican/openstack/common/rpc/dispatcher.py index 4d058e3f6..88122584b 100644 --- a/barbican/openstack/common/rpc/dispatcher.py +++ b/barbican/openstack/common/rpc/dispatcher.py @@ -84,6 +84,7 @@ minimum version that supports the new parameter should be specified. """ from barbican.openstack.common.rpc import common as rpc_common +from barbican.openstack.common.rpc import serializer as rpc_serializer class RpcDispatcher(object): @@ -93,23 +94,48 @@ class RpcDispatcher(object): contains a list of underlying managers that have an API_VERSION attribute. """ - def __init__(self, callbacks): + def __init__(self, callbacks, serializer=None): """Initialize the rpc dispatcher. :param callbacks: List of proxy objects that are an instance of a class with rpc methods exposed. Each proxy object should have an RPC_API_VERSION attribute. + :param serializer: The Serializer object that will be used to + deserialize arguments before the method call and + to serialize the result after it returns. """ self.callbacks = callbacks + if serializer is None: + serializer = rpc_serializer.NoOpSerializer() + self.serializer = serializer super(RpcDispatcher, self).__init__() - def dispatch(self, ctxt, version, method, **kwargs): + def _deserialize_args(self, context, kwargs): + """Helper method called to deserialize args before dispatch. + + This calls our serializer on each argument, returning a new set of + args that have been deserialized. + + :param context: The request context + :param kwargs: The arguments to be deserialized + :returns: A new set of deserialized args + """ + new_kwargs = dict() + for argname, arg in kwargs.iteritems(): + new_kwargs[argname] = self.serializer.deserialize_entity(context, + arg) + return new_kwargs + + def dispatch(self, ctxt, version, method, namespace, **kwargs): """Dispatch a message based on a requested version. :param ctxt: The request context :param version: The requested API version from the incoming message :param method: The method requested to be called by the incoming message. + :param namespace: The namespace for the requested method. If None, + the dispatcher will look for a method on a callback + object with no namespace set. :param kwargs: A dict of keyword arguments to be passed to the method. :returns: Whatever is returned by the underlying method that gets @@ -120,17 +146,31 @@ class RpcDispatcher(object): had_compatible = False for proxyobj in self.callbacks: - if hasattr(proxyobj, 'RPC_API_VERSION'): + # Check for namespace compatibility + try: + cb_namespace = proxyobj.RPC_API_NAMESPACE + except AttributeError: + cb_namespace = None + + if namespace != cb_namespace: + continue + + # Check for version compatibility + try: rpc_api_version = proxyobj.RPC_API_VERSION - else: + except AttributeError: rpc_api_version = '1.0' + is_compatible = rpc_common.version_is_compatible(rpc_api_version, version) had_compatible = had_compatible or is_compatible + if not hasattr(proxyobj, method): continue if is_compatible: - return getattr(proxyobj, method)(ctxt, **kwargs) + kwargs = self._deserialize_args(ctxt, kwargs) + result = getattr(proxyobj, method)(ctxt, **kwargs) + return self.serializer.serialize_entity(ctxt, result) if had_compatible: raise AttributeError("No such RPC function '%s'" % method) diff --git a/barbican/openstack/common/rpc/impl_fake.py b/barbican/openstack/common/rpc/impl_fake.py index f7b6ceed2..267972942 100644 --- a/barbican/openstack/common/rpc/impl_fake.py +++ b/barbican/openstack/common/rpc/impl_fake.py @@ -57,13 +57,14 @@ class Consumer(object): self.topic = topic self.proxy = proxy - def call(self, context, version, method, args, timeout): + def call(self, context, version, method, namespace, args, timeout): done = eventlet.event.Event() def _inner(): ctxt = RpcContext.from_dict(context.to_dict()) try: - rval = self.proxy.dispatch(context, version, method, **args) + rval = self.proxy.dispatch(context, version, method, + namespace, **args) res = [] # Caller might have called ctxt.reply() manually for (reply, failure) in ctxt._response: @@ -121,7 +122,7 @@ class Connection(object): def create_connection(conf, new=True): - """Create a connection""" + """Create a connection.""" return Connection() @@ -140,13 +141,15 @@ def multicall(conf, context, topic, msg, timeout=None): return args = msg.get('args', {}) version = msg.get('version', None) + namespace = msg.get('namespace', None) try: consumer = CONSUMERS[topic][0] except (KeyError, IndexError): return iter([None]) else: - return consumer.call(context, version, method, args, timeout) + return consumer.call(context, version, method, namespace, args, + timeout) def call(conf, context, topic, msg, timeout=None): @@ -176,16 +179,17 @@ def cleanup(): def fanout_cast(conf, context, topic, msg): - """Cast to all consumers of a topic""" + """Cast to all consumers of a topic.""" check_serialize(msg) method = msg.get('method') if not method: return args = msg.get('args', {}) version = msg.get('version', None) + namespace = msg.get('namespace', None) for consumer in CONSUMERS.get(topic, []): try: - consumer.call(context, version, method, args, None) + consumer.call(context, version, method, namespace, args, None) except Exception: pass diff --git a/barbican/openstack/common/rpc/impl_kombu.py b/barbican/openstack/common/rpc/impl_kombu.py index bcc481ebb..555d1e252 100644 --- a/barbican/openstack/common/rpc/impl_kombu.py +++ b/barbican/openstack/common/rpc/impl_kombu.py @@ -18,7 +18,6 @@ import functools import itertools import socket import ssl -import sys import time import uuid @@ -30,15 +29,20 @@ import kombu.entity import kombu.messaging from oslo.config import cfg -from barbican.openstack.common.gettextutils import _ +from barbican.openstack.common import excutils +from barbican.openstack.common.gettextutils import _ # noqa from barbican.openstack.common import network_utils from barbican.openstack.common.rpc import amqp as rpc_amqp from barbican.openstack.common.rpc import common as rpc_common +from barbican.openstack.common import sslutils kombu_opts = [ cfg.StrOpt('kombu_ssl_version', default='', - help='SSL version to use (valid only if SSL enabled)'), + help='SSL version to use (valid only if SSL enabled). ' + 'valid values are TLSv1, SSLv23 and SSLv3. SSLv2 may ' + 'be available on some distributions' + ), cfg.StrOpt('kombu_ssl_keyfile', default='', help='SSL key file (valid only if SSL enabled)'), @@ -82,9 +86,6 @@ kombu_opts = [ default=0, help='maximum retries with trying to connect to RabbitMQ ' '(the default of 0 implies an infinite retry count)'), - cfg.BoolOpt('rabbit_durable_queues', - default=False, - help='use durable queues in RabbitMQ'), cfg.BoolOpt('rabbit_ha_queues', default=False, help='use H/A queues in RabbitMQ (x-ha-policy: all).' @@ -129,15 +130,46 @@ class ConsumerBase(object): self.tag = str(tag) self.kwargs = kwargs self.queue = None + self.ack_on_error = kwargs.get('ack_on_error', True) self.reconnect(channel) def reconnect(self, channel): - """Re-declare the queue after a rabbit reconnect""" + """Re-declare the queue after a rabbit reconnect.""" self.channel = channel self.kwargs['channel'] = channel self.queue = kombu.entity.Queue(**self.kwargs) self.queue.declare() + def _callback_handler(self, message, callback): + """Call callback with deserialized message. + + Messages that are processed without exception are ack'ed. + + If the message processing generates an exception, it will be + ack'ed if ack_on_error=True. Otherwise it will be .reject()'ed. + Rejection is better than waiting for the message to timeout. + Rejected messages are immediately requeued. + """ + + ack_msg = False + try: + msg = rpc_common.deserialize_msg(message.payload) + callback(msg) + ack_msg = True + except Exception: + if self.ack_on_error: + ack_msg = True + LOG.exception(_("Failed to process message" + " ... skipping it.")) + else: + LOG.exception(_("Failed to process message" + " ... will requeue.")) + finally: + if ack_msg: + message.ack() + else: + message.reject() + def consume(self, *args, **kwargs): """Actually declare the consumer on the amqp channel. This will start the flow of messages from the queue. Using the @@ -150,8 +182,6 @@ class ConsumerBase(object): If kwargs['nowait'] is True, then this call will block until a message is read. - Messages will automatically be acked if the callback doesn't - raise an exception """ options = {'consumer_tag': self.tag} @@ -162,21 +192,15 @@ class ConsumerBase(object): def _callback(raw_message): message = self.channel.message_to_python(raw_message) - try: - msg = rpc_common.deserialize_msg(message.payload) - callback(msg) - except Exception: - LOG.exception(_("Failed to process message... skipping it.")) - finally: - message.ack() + self._callback_handler(message, callback) self.queue.consume(*args, callback=_callback, **options) def cancel(self): - """Cancel the consuming from the queue, if it has started""" + """Cancel the consuming from the queue, if it has started.""" try: self.queue.cancel(self.tag) - except KeyError, e: + except KeyError as e: # NOTE(comstud): Kludge to get around a amqplib bug if str(e) != "u'%s'" % self.tag: raise @@ -184,7 +208,7 @@ class ConsumerBase(object): class DirectConsumer(ConsumerBase): - """Queue/consumer class for 'direct'""" + """Queue/consumer class for 'direct'.""" def __init__(self, conf, channel, msg_id, callback, tag, **kwargs): """Init a 'direct' queue. @@ -216,7 +240,7 @@ class DirectConsumer(ConsumerBase): class TopicConsumer(ConsumerBase): - """Consumer class for 'topic'""" + """Consumer class for 'topic'.""" def __init__(self, conf, channel, topic, callback, tag, name=None, exchange_name=None, **kwargs): @@ -233,9 +257,9 @@ class TopicConsumer(ConsumerBase): Other kombu options may be passed as keyword arguments """ # Default options - options = {'durable': conf.rabbit_durable_queues, + options = {'durable': conf.amqp_durable_queues, 'queue_arguments': _get_queue_arguments(conf), - 'auto_delete': False, + 'auto_delete': conf.amqp_auto_delete, 'exclusive': False} options.update(kwargs) exchange_name = exchange_name or rpc_amqp.get_control_exchange(conf) @@ -253,7 +277,7 @@ class TopicConsumer(ConsumerBase): class FanoutConsumer(ConsumerBase): - """Consumer class for 'fanout'""" + """Consumer class for 'fanout'.""" def __init__(self, conf, channel, topic, callback, tag, **kwargs): """Init a 'fanout' queue. @@ -286,7 +310,7 @@ class FanoutConsumer(ConsumerBase): class Publisher(object): - """Base Publisher class""" + """Base Publisher class.""" def __init__(self, channel, exchange_name, routing_key, **kwargs): """Init the Publisher class with the exchange_name, routing_key, @@ -298,7 +322,7 @@ class Publisher(object): self.reconnect(channel) def reconnect(self, channel): - """Re-establish the Producer after a rabbit reconnection""" + """Re-establish the Producer after a rabbit reconnection.""" self.exchange = kombu.entity.Exchange(name=self.exchange_name, **self.kwargs) self.producer = kombu.messaging.Producer(exchange=self.exchange, @@ -306,7 +330,7 @@ class Publisher(object): routing_key=self.routing_key) def send(self, msg, timeout=None): - """Send a message""" + """Send a message.""" if timeout: # # AMQP TTL is in milliseconds when set in the header. @@ -317,7 +341,7 @@ class Publisher(object): class DirectPublisher(Publisher): - """Publisher class for 'direct'""" + """Publisher class for 'direct'.""" def __init__(self, conf, channel, msg_id, **kwargs): """init a 'direct' publisher. @@ -333,14 +357,14 @@ class DirectPublisher(Publisher): class TopicPublisher(Publisher): - """Publisher class for 'topic'""" + """Publisher class for 'topic'.""" def __init__(self, conf, channel, topic, **kwargs): """init a 'topic' publisher. Kombu options may be passed as keyword args to override defaults """ - options = {'durable': conf.rabbit_durable_queues, - 'auto_delete': False, + options = {'durable': conf.amqp_durable_queues, + 'auto_delete': conf.amqp_auto_delete, 'exclusive': False} options.update(kwargs) exchange_name = rpc_amqp.get_control_exchange(conf) @@ -352,7 +376,7 @@ class TopicPublisher(Publisher): class FanoutPublisher(Publisher): - """Publisher class for 'fanout'""" + """Publisher class for 'fanout'.""" def __init__(self, conf, channel, topic, **kwargs): """init a 'fanout' publisher. @@ -367,10 +391,10 @@ class FanoutPublisher(Publisher): class NotifyPublisher(TopicPublisher): - """Publisher class for 'notify'""" + """Publisher class for 'notify'.""" def __init__(self, conf, channel, topic, **kwargs): - self.durable = kwargs.pop('durable', conf.rabbit_durable_queues) + self.durable = kwargs.pop('durable', conf.amqp_durable_queues) self.queue_arguments = _get_queue_arguments(conf) super(NotifyPublisher, self).__init__(conf, channel, topic, **kwargs) @@ -447,13 +471,15 @@ class Connection(object): self.reconnect() def _fetch_ssl_params(self): - """Handles fetching what ssl params - should be used for the connection (if any)""" + """Handles fetching what ssl params should be used for the connection + (if any). + """ ssl_params = dict() # http://docs.python.org/library/ssl.html - ssl.wrap_socket if self.conf.kombu_ssl_version: - ssl_params['ssl_version'] = self.conf.kombu_ssl_version + ssl_params['ssl_version'] = sslutils.validate_ssl_version( + self.conf.kombu_ssl_version) if self.conf.kombu_ssl_keyfile: ssl_params['keyfile'] = self.conf.kombu_ssl_keyfile if self.conf.kombu_ssl_certfile: @@ -464,12 +490,8 @@ class Connection(object): # future with this? ssl_params['cert_reqs'] = ssl.CERT_REQUIRED - if not ssl_params: - # Just have the default behavior - return True - else: - # Return the extended behavior - return ssl_params + # Return the extended behavior or just have the default behavior + return ssl_params or True def _connect(self, params): """Connect to rabbit. Re-establish any queues that may have @@ -520,7 +542,7 @@ class Connection(object): return except (IOError, self.connection_errors) as e: pass - except Exception, e: + except Exception as e: # NOTE(comstud): Unfortunately it's possible for amqplib # to return an error not covered by its transport # connection_errors in the case of a timeout waiting for @@ -536,13 +558,11 @@ class Connection(object): log_info.update(params) if self.max_retries and attempt == self.max_retries: - LOG.error(_('Unable to connect to AMQP server on ' - '%(hostname)s:%(port)d after %(max_retries)d ' - 'tries: %(err_str)s') % log_info) - # NOTE(comstud): Copied from original code. There's - # really no better recourse because if this was a queue we - # need to consume on, we have no way to consume anymore. - sys.exit(1) + msg = _('Unable to connect to AMQP server on ' + '%(hostname)s:%(port)d after %(max_retries)d ' + 'tries: %(err_str)s') % log_info + LOG.error(msg) + raise rpc_common.RPCException(msg) if attempt == 1: sleep_time = self.interval_start or 1 @@ -561,10 +581,10 @@ class Connection(object): while True: try: return method(*args, **kwargs) - except (self.connection_errors, socket.timeout, IOError), e: + except (self.connection_errors, socket.timeout, IOError) as e: if error_callback: error_callback(e) - except Exception, e: + except Exception as e: # NOTE(comstud): Unfortunately it's possible for amqplib # to return an error not covered by its transport # connection_errors in the case of a timeout waiting for @@ -578,18 +598,18 @@ class Connection(object): self.reconnect() def get_channel(self): - """Convenience call for bin/clear_rabbit_queues""" + """Convenience call for bin/clear_rabbit_queues.""" return self.channel def close(self): - """Close/release this connection""" + """Close/release this connection.""" self.cancel_consumer_thread() self.wait_on_proxy_callbacks() self.connection.release() self.connection = None def reset(self): - """Reset a connection so it can be used again""" + """Reset a connection so it can be used again.""" self.cancel_consumer_thread() self.wait_on_proxy_callbacks() self.channel.close() @@ -618,7 +638,7 @@ class Connection(object): return self.ensure(_connect_error, _declare_consumer) def iterconsume(self, limit=None, timeout=None): - """Return an iterator that will consume from all queues/consumers""" + """Return an iterator that will consume from all queues/consumers.""" info = {'do_consume': True} @@ -634,8 +654,8 @@ class Connection(object): def _consume(): if info['do_consume']: - queues_head = self.consumers[:-1] - queues_tail = self.consumers[-1] + queues_head = self.consumers[:-1] # not fanout. + queues_tail = self.consumers[-1] # fanout for queue in queues_head: queue.consume(nowait=True) queues_tail.consume(nowait=False) @@ -648,7 +668,7 @@ class Connection(object): yield self.ensure(_error_callback, _consume) def cancel_consumer_thread(self): - """Cancel a consumer thread""" + """Cancel a consumer thread.""" if self.consumer_thread is not None: self.consumer_thread.kill() try: @@ -663,7 +683,7 @@ class Connection(object): proxy_cb.wait() def publisher_send(self, cls, topic, msg, timeout=None, **kwargs): - """Send to a publisher based on the publisher class""" + """Send to a publisher based on the publisher class.""" def _error_callback(exc): log_info = {'topic': topic, 'err_str': str(exc)} @@ -684,36 +704,37 @@ class Connection(object): self.declare_consumer(DirectConsumer, topic, callback) def declare_topic_consumer(self, topic, callback=None, queue_name=None, - exchange_name=None): + exchange_name=None, ack_on_error=True): """Create a 'topic' consumer.""" self.declare_consumer(functools.partial(TopicConsumer, name=queue_name, exchange_name=exchange_name, + ack_on_error=ack_on_error, ), topic, callback) def declare_fanout_consumer(self, topic, callback): - """Create a 'fanout' consumer""" + """Create a 'fanout' consumer.""" self.declare_consumer(FanoutConsumer, topic, callback) def direct_send(self, msg_id, msg): - """Send a 'direct' message""" + """Send a 'direct' message.""" self.publisher_send(DirectPublisher, msg_id, msg) def topic_send(self, topic, msg, timeout=None): - """Send a 'topic' message""" + """Send a 'topic' message.""" self.publisher_send(TopicPublisher, topic, msg, timeout) def fanout_send(self, topic, msg): - """Send a 'fanout' message""" + """Send a 'fanout' message.""" self.publisher_send(FanoutPublisher, topic, msg) def notify_send(self, topic, msg, **kwargs): - """Send a notify message on a topic""" + """Send a notify message on a topic.""" self.publisher_send(NotifyPublisher, topic, msg, None, **kwargs) def consume(self, limit=None): - """Consume from all queues/consumers""" + """Consume from all queues/consumers.""" it = self.iterconsume(limit=limit) while True: try: @@ -722,7 +743,8 @@ class Connection(object): return def consume_in_thread(self): - """Consumer from all queues/consumers in a greenthread""" + """Consumer from all queues/consumers in a greenthread.""" + @excutils.forever_retry_uncaught_exceptions def _consumer_thread(): try: self.consume() @@ -733,7 +755,7 @@ class Connection(object): return self.consumer_thread def create_consumer(self, topic, proxy, fanout=False): - """Create a consumer that calls a method in a proxy object""" + """Create a consumer that calls a method in a proxy object.""" proxy_cb = rpc_amqp.ProxyCallback( self.conf, proxy, rpc_amqp.get_connection_pool(self.conf, Connection)) @@ -745,7 +767,7 @@ class Connection(object): self.declare_topic_consumer(topic, proxy_cb) def create_worker(self, topic, proxy, pool_name): - """Create a worker that calls a method in a proxy object""" + """Create a worker that calls a method in a proxy object.""" proxy_cb = rpc_amqp.ProxyCallback( self.conf, proxy, rpc_amqp.get_connection_pool(self.conf, Connection)) @@ -753,7 +775,7 @@ class Connection(object): self.declare_topic_consumer(topic, proxy_cb, pool_name) def join_consumer_pool(self, callback, pool_name, topic, - exchange_name=None): + exchange_name=None, ack_on_error=True): """Register as a member of a group of consumers for a given topic from the specified exchange. @@ -774,11 +796,12 @@ class Connection(object): topic=topic, exchange_name=exchange_name, callback=callback_wrapper, + ack_on_error=ack_on_error, ) def create_connection(conf, new=True): - """Create a connection""" + """Create a connection.""" return rpc_amqp.create_connection( conf, new, rpc_amqp.get_connection_pool(conf, Connection)) diff --git a/barbican/openstack/common/rpc/impl_qpid.py b/barbican/openstack/common/rpc/impl_qpid.py index f8a53a5a1..ed35e1b22 100644 --- a/barbican/openstack/common/rpc/impl_qpid.py +++ b/barbican/openstack/common/rpc/impl_qpid.py @@ -24,13 +24,15 @@ import eventlet import greenlet from oslo.config import cfg -from barbican.openstack.common.gettextutils import _ +from barbican.openstack.common import excutils +from barbican.openstack.common.gettextutils import _ # noqa from barbican.openstack.common import importutils from barbican.openstack.common import jsonutils from barbican.openstack.common import log as logging from barbican.openstack.common.rpc import amqp as rpc_amqp from barbican.openstack.common.rpc import common as rpc_common +qpid_codec = importutils.try_import("qpid.codec010") qpid_messaging = importutils.try_import("qpid.messaging") qpid_exceptions = importutils.try_import("qpid.messaging.exceptions") @@ -69,6 +71,8 @@ qpid_opts = [ cfg.CONF.register_opts(qpid_opts) +JSON_CONTENT_TYPE = 'application/json; charset=utf8' + class ConsumerBase(object): """Consumer base class.""" @@ -115,31 +119,59 @@ class ConsumerBase(object): self.address = "%s ; %s" % (node_name, jsonutils.dumps(addr_opts)) - self.reconnect(session) + self.connect(session) + + def connect(self, session): + """Declare the reciever on connect.""" + self._declare_receiver(session) def reconnect(self, session): - """Re-declare the receiver after a qpid reconnect""" + """Re-declare the receiver after a qpid reconnect.""" + self._declare_receiver(session) + + def _declare_receiver(self, session): self.session = session self.receiver = session.receiver(self.address) self.receiver.capacity = 1 + def _unpack_json_msg(self, msg): + """Load the JSON data in msg if msg.content_type indicates that it + is necessary. Put the loaded data back into msg.content and + update msg.content_type appropriately. + + A Qpid Message containing a dict will have a content_type of + 'amqp/map', whereas one containing a string that needs to be converted + back from JSON will have a content_type of JSON_CONTENT_TYPE. + + :param msg: a Qpid Message object + :returns: None + """ + if msg.content_type == JSON_CONTENT_TYPE: + msg.content = jsonutils.loads(msg.content) + msg.content_type = 'amqp/map' + def consume(self): - """Fetch the message and pass it to the callback object""" + """Fetch the message and pass it to the callback object.""" message = self.receiver.fetch() try: + self._unpack_json_msg(message) msg = rpc_common.deserialize_msg(message.content) self.callback(msg) except Exception: LOG.exception(_("Failed to process message... skipping it.")) finally: + # TODO(sandy): Need support for optional ack_on_error. self.session.acknowledge(message) def get_receiver(self): return self.receiver + def get_node_name(self): + return self.address.split(';')[0] + class DirectConsumer(ConsumerBase): - """Queue/consumer class for 'direct'""" + """Queue/consumer class for 'direct'.""" def __init__(self, conf, session, msg_id, callback): """Init a 'direct' queue. @@ -149,15 +181,20 @@ class DirectConsumer(ConsumerBase): 'callback' is the callback to call when messages are received """ - super(DirectConsumer, self).__init__(session, callback, - "%s/%s" % (msg_id, msg_id), - {"type": "direct"}, - msg_id, - {"exclusive": True}) + super(DirectConsumer, self).__init__( + session, callback, + "%s/%s" % (msg_id, msg_id), + {"type": "direct"}, + msg_id, + { + "auto-delete": conf.amqp_auto_delete, + "exclusive": True, + "durable": conf.amqp_durable_queues, + }) class TopicConsumer(ConsumerBase): - """Consumer class for 'topic'""" + """Consumer class for 'topic'.""" def __init__(self, conf, session, topic, callback, name=None, exchange_name=None): @@ -171,13 +208,18 @@ class TopicConsumer(ConsumerBase): """ exchange_name = exchange_name or rpc_amqp.get_control_exchange(conf) - super(TopicConsumer, self).__init__(session, callback, - "%s/%s" % (exchange_name, topic), - {}, name or topic, {}) + super(TopicConsumer, self).__init__( + session, callback, + "%s/%s" % (exchange_name, topic), + {}, name or topic, + { + "auto-delete": conf.amqp_auto_delete, + "durable": conf.amqp_durable_queues, + }) class FanoutConsumer(ConsumerBase): - """Consumer class for 'fanout'""" + """Consumer class for 'fanout'.""" def __init__(self, conf, session, topic, callback): """Init a 'fanout' queue. @@ -186,6 +228,7 @@ class FanoutConsumer(ConsumerBase): 'topic' is the topic to listen on 'callback' is the callback to call when messages are received """ + self.conf = conf super(FanoutConsumer, self).__init__( session, callback, @@ -194,9 +237,21 @@ class FanoutConsumer(ConsumerBase): "%s_fanout_%s" % (topic, uuid.uuid4().hex), {"exclusive": True}) + def reconnect(self, session): + topic = self.get_node_name().rpartition('_fanout')[0] + params = { + 'session': session, + 'topic': topic, + 'callback': self.callback, + } + + self.__init__(conf=self.conf, **params) + + super(FanoutConsumer, self).reconnect(session) + class Publisher(object): - """Base Publisher class""" + """Base Publisher class.""" def __init__(self, session, node_name, node_opts=None): """Init the Publisher class with the exchange_name, routing_key, @@ -225,24 +280,51 @@ class Publisher(object): self.reconnect(session) def reconnect(self, session): - """Re-establish the Sender after a reconnection""" + """Re-establish the Sender after a reconnection.""" self.sender = session.sender(self.address) + def _pack_json_msg(self, msg): + """Qpid cannot serialize dicts containing strings longer than 65535 + characters. This function dumps the message content to a JSON + string, which Qpid is able to handle. + + :param msg: May be either a Qpid Message object or a bare dict. + :returns: A Qpid Message with its content field JSON encoded. + """ + try: + msg.content = jsonutils.dumps(msg.content) + except AttributeError: + # Need to have a Qpid message so we can set the content_type. + msg = qpid_messaging.Message(jsonutils.dumps(msg)) + msg.content_type = JSON_CONTENT_TYPE + return msg + def send(self, msg): - """Send a message""" + """Send a message.""" + try: + # Check if Qpid can encode the message + check_msg = msg + if not hasattr(check_msg, 'content_type'): + check_msg = qpid_messaging.Message(msg) + content_type = check_msg.content_type + enc, dec = qpid_messaging.message.get_codec(content_type) + enc(check_msg.content) + except qpid_codec.CodecException: + # This means the message couldn't be serialized as a dict. + msg = self._pack_json_msg(msg) self.sender.send(msg) class DirectPublisher(Publisher): - """Publisher class for 'direct'""" + """Publisher class for 'direct'.""" def __init__(self, conf, session, msg_id): """Init a 'direct' publisher.""" super(DirectPublisher, self).__init__(session, msg_id, - {"type": "Direct"}) + {"type": "direct"}) class TopicPublisher(Publisher): - """Publisher class for 'topic'""" + """Publisher class for 'topic'.""" def __init__(self, conf, session, topic): """init a 'topic' publisher. """ @@ -252,7 +334,7 @@ class TopicPublisher(Publisher): class FanoutPublisher(Publisher): - """Publisher class for 'fanout'""" + """Publisher class for 'fanout'.""" def __init__(self, conf, session, topic): """init a 'fanout' publisher. """ @@ -262,7 +344,7 @@ class FanoutPublisher(Publisher): class NotifyPublisher(Publisher): - """Publisher class for notifications""" + """Publisher class for notifications.""" def __init__(self, conf, session, topic): """init a 'topic' publisher. """ @@ -330,23 +412,24 @@ class Connection(object): return self.consumers[str(receiver)] def reconnect(self): - """Handles reconnecting and re-establishing sessions and queues""" - if self.connection.opened(): - try: - self.connection.close() - except qpid_exceptions.ConnectionError: - pass - + """Handles reconnecting and re-establishing sessions and queues.""" attempt = 0 delay = 1 while True: + # Close the session if necessary + if self.connection.opened(): + try: + self.connection.close() + except qpid_exceptions.ConnectionError: + pass + broker = self.brokers[attempt % len(self.brokers)] attempt += 1 try: self.connection_create(broker) self.connection.open() - except qpid_exceptions.ConnectionError, e: + except qpid_exceptions.ConnectionError as e: msg_dict = dict(e=e, delay=delay) msg = _("Unable to connect to AMQP server: %(e)s. " "Sleeping %(delay)s seconds") % msg_dict @@ -374,20 +457,26 @@ class Connection(object): try: return method(*args, **kwargs) except (qpid_exceptions.Empty, - qpid_exceptions.ConnectionError), e: + qpid_exceptions.ConnectionError) as e: if error_callback: error_callback(e) self.reconnect() def close(self): - """Close/release this connection""" + """Close/release this connection.""" self.cancel_consumer_thread() self.wait_on_proxy_callbacks() - self.connection.close() + try: + self.connection.close() + except Exception: + # NOTE(dripton) Logging exceptions that happen during cleanup just + # causes confusion; there's really nothing useful we can do with + # them. + pass self.connection = None def reset(self): - """Reset a connection so it can be used again""" + """Reset a connection so it can be used again.""" self.cancel_consumer_thread() self.wait_on_proxy_callbacks() self.session.close() @@ -411,7 +500,7 @@ class Connection(object): return self.ensure(_connect_error, _declare_consumer) def iterconsume(self, limit=None, timeout=None): - """Return an iterator that will consume from all queues/consumers""" + """Return an iterator that will consume from all queues/consumers.""" def _error_callback(exc): if isinstance(exc, qpid_exceptions.Empty): @@ -435,7 +524,7 @@ class Connection(object): yield self.ensure(_error_callback, _consume) def cancel_consumer_thread(self): - """Cancel a consumer thread""" + """Cancel a consumer thread.""" if self.consumer_thread is not None: self.consumer_thread.kill() try: @@ -450,7 +539,7 @@ class Connection(object): proxy_cb.wait() def publisher_send(self, cls, topic, msg): - """Send to a publisher based on the publisher class""" + """Send to a publisher based on the publisher class.""" def _connect_error(exc): log_info = {'topic': topic, 'err_str': str(exc)} @@ -480,15 +569,15 @@ class Connection(object): topic, callback) def declare_fanout_consumer(self, topic, callback): - """Create a 'fanout' consumer""" + """Create a 'fanout' consumer.""" self.declare_consumer(FanoutConsumer, topic, callback) def direct_send(self, msg_id, msg): - """Send a 'direct' message""" + """Send a 'direct' message.""" self.publisher_send(DirectPublisher, msg_id, msg) def topic_send(self, topic, msg, timeout=None): - """Send a 'topic' message""" + """Send a 'topic' message.""" # # We want to create a message with attributes, e.g. a TTL. We # don't really need to keep 'msg' in its JSON format any longer @@ -503,15 +592,15 @@ class Connection(object): self.publisher_send(TopicPublisher, topic, qpid_message) def fanout_send(self, topic, msg): - """Send a 'fanout' message""" + """Send a 'fanout' message.""" self.publisher_send(FanoutPublisher, topic, msg) def notify_send(self, topic, msg, **kwargs): - """Send a notify message on a topic""" + """Send a notify message on a topic.""" self.publisher_send(NotifyPublisher, topic, msg) def consume(self, limit=None): - """Consume from all queues/consumers""" + """Consume from all queues/consumers.""" it = self.iterconsume(limit=limit) while True: try: @@ -520,7 +609,8 @@ class Connection(object): return def consume_in_thread(self): - """Consumer from all queues/consumers in a greenthread""" + """Consumer from all queues/consumers in a greenthread.""" + @excutils.forever_retry_uncaught_exceptions def _consumer_thread(): try: self.consume() @@ -531,7 +621,7 @@ class Connection(object): return self.consumer_thread def create_consumer(self, topic, proxy, fanout=False): - """Create a consumer that calls a method in a proxy object""" + """Create a consumer that calls a method in a proxy object.""" proxy_cb = rpc_amqp.ProxyCallback( self.conf, proxy, rpc_amqp.get_connection_pool(self.conf, Connection)) @@ -547,7 +637,7 @@ class Connection(object): return consumer def create_worker(self, topic, proxy, pool_name): - """Create a worker that calls a method in a proxy object""" + """Create a worker that calls a method in a proxy object.""" proxy_cb = rpc_amqp.ProxyCallback( self.conf, proxy, rpc_amqp.get_connection_pool(self.conf, Connection)) @@ -561,7 +651,7 @@ class Connection(object): return consumer def join_consumer_pool(self, callback, pool_name, topic, - exchange_name=None): + exchange_name=None, ack_on_error=True): """Register as a member of a group of consumers for a given topic from the specified exchange. @@ -590,7 +680,7 @@ class Connection(object): def create_connection(conf, new=True): - """Create a connection""" + """Create a connection.""" return rpc_amqp.create_connection( conf, new, rpc_amqp.get_connection_pool(conf, Connection)) diff --git a/barbican/openstack/common/rpc/impl_zmq.py b/barbican/openstack/common/rpc/impl_zmq.py index 19d7d4e94..e68c002f9 100644 --- a/barbican/openstack/common/rpc/impl_zmq.py +++ b/barbican/openstack/common/rpc/impl_zmq.py @@ -27,10 +27,9 @@ import greenlet from oslo.config import cfg from barbican.openstack.common import excutils -from barbican.openstack.common.gettextutils import _ +from barbican.openstack.common.gettextutils import _ # noqa from barbican.openstack.common import importutils from barbican.openstack.common import jsonutils -from barbican.openstack.common import processutils as utils from barbican.openstack.common.rpc import common as rpc_common zmq = importutils.try_import('eventlet.green.zmq') @@ -85,8 +84,8 @@ matchmaker = None # memoized matchmaker object def _serialize(data): - """ - Serialization wrapper + """Serialization wrapper. + We prefer using JSON, but it cannot encode all types. Error if a developer passes us bad data. """ @@ -98,18 +97,15 @@ def _serialize(data): def _deserialize(data): - """ - Deserialization wrapper - """ + """Deserialization wrapper.""" LOG.debug(_("Deserializing: %s"), data) return jsonutils.loads(data) class ZmqSocket(object): - """ - A tiny wrapper around ZeroMQ to simplify the send/recv protocol - and connection management. + """A tiny wrapper around ZeroMQ. + Simplifies the send/recv protocol and connection management. Can be used as a Context (supports the 'with' statement). """ @@ -180,7 +176,7 @@ class ZmqSocket(object): return # We must unsubscribe, or we'll leak descriptors. - if len(self.subscriptions) > 0: + if self.subscriptions: for f in self.subscriptions: try: self.sock.setsockopt(zmq.UNSUBSCRIBE, f) @@ -199,26 +195,24 @@ class ZmqSocket(object): LOG.error("ZeroMQ socket could not be closed.") self.sock = None - def recv(self): + def recv(self, **kwargs): if not self.can_recv: raise RPCException(_("You cannot recv on this socket.")) - return self.sock.recv_multipart() + return self.sock.recv_multipart(**kwargs) - def send(self, data): + def send(self, data, **kwargs): if not self.can_send: raise RPCException(_("You cannot send on this socket.")) - self.sock.send_multipart(data) + self.sock.send_multipart(data, **kwargs) class ZmqClient(object): """Client for ZMQ sockets.""" - def __init__(self, addr, socket_type=None, bind=False): - if socket_type is None: - socket_type = zmq.PUSH - self.outq = ZmqSocket(addr, socket_type, bind=bind) + def __init__(self, addr): + self.outq = ZmqSocket(addr, zmq.PUSH, bind=False) - def cast(self, msg_id, topic, data, envelope=False): + def cast(self, msg_id, topic, data, envelope): msg_id = msg_id or 0 if not envelope: @@ -276,12 +270,13 @@ class InternalContext(object): try: result = proxy.dispatch( - ctx, data['version'], data['method'], **data['args']) + ctx, data['version'], data['method'], + data.get('namespace'), **data['args']) return ConsumerBase.normalize_reply(result, ctx.replies) except greenlet.GreenletExit: # ignore these since they are just from shutdowns pass - except rpc_common.ClientException, e: + except rpc_common.ClientException as e: LOG.debug(_("Expected exception during message handling (%s)") % e._exc_info[1]) return {'exc': @@ -351,20 +346,18 @@ class ConsumerBase(object): return proxy.dispatch(ctx, data['version'], - data['method'], **data['args']) + data['method'], data.get('namespace'), **data['args']) class ZmqBaseReactor(ConsumerBase): - """ - A consumer class implementing a - centralized casting broker (PULL-PUSH) - for RoundRobin requests. + """A consumer class implementing a centralized casting broker (PULL-PUSH). + + Used for RoundRobin requests. """ def __init__(self, conf): super(ZmqBaseReactor, self).__init__() - self.mapping = {} self.proxies = {} self.threads = [] self.sockets = [] @@ -372,9 +365,8 @@ class ZmqBaseReactor(ConsumerBase): self.pool = eventlet.greenpool.GreenPool(conf.rpc_thread_pool_size) - def register(self, proxy, in_addr, zmq_type_in, out_addr=None, - zmq_type_out=None, in_bind=True, out_bind=True, - subscribe=None): + def register(self, proxy, in_addr, zmq_type_in, + in_bind=True, subscribe=None): LOG.info(_("Registering reactor")) @@ -390,21 +382,6 @@ class ZmqBaseReactor(ConsumerBase): LOG.info(_("In reactor registered")) - if not out_addr: - return - - if zmq_type_out not in (zmq.PUSH, zmq.PUB): - raise RPCException("Bad output socktype") - - # Items push out. - outq = ZmqSocket(out_addr, zmq_type_out, bind=out_bind) - - self.mapping[inq] = outq - self.mapping[outq] = inq - self.sockets.append(outq) - - LOG.info(_("Out reactor registered")) - def consume_in_thread(self): def _consume(sock): LOG.info(_("Consuming socket")) @@ -429,10 +406,9 @@ class ZmqBaseReactor(ConsumerBase): class ZmqProxy(ZmqBaseReactor): - """ - A consumer class implementing a - topic-based proxy, forwarding to - IPC sockets. + """A consumer class implementing a topic-based proxy. + + Forwards to IPC sockets. """ def __init__(self, conf): @@ -445,11 +421,8 @@ class ZmqProxy(ZmqBaseReactor): def consume(self, sock): ipc_dir = CONF.rpc_zmq_ipc_dir - #TODO(ewindisch): use zero-copy (i.e. references, not copying) - data = sock.recv() - topic = data[1] - - LOG.debug(_("CONSUMER GOT %s"), ' '.join(map(pformat, data))) + data = sock.recv(copy=False) + topic = data[1].bytes if topic.startswith('fanout~'): sock_type = zmq.PUB @@ -491,9 +464,7 @@ class ZmqProxy(ZmqBaseReactor): while(True): data = self.topic_proxy[topic].get() - out_sock.send(data) - LOG.debug(_("ROUTER RELAY-OUT SUCCEEDED %(data)s") % - {'data': data}) + out_sock.send(data, copy=False) wait_sock_creation = eventlet.event.Event() eventlet.spawn(publisher, wait_sock_creation) @@ -506,37 +477,34 @@ class ZmqProxy(ZmqBaseReactor): try: self.topic_proxy[topic].put_nowait(data) - LOG.debug(_("ROUTER RELAY-OUT QUEUED %(data)s") % - {'data': data}) except eventlet.queue.Full: LOG.error(_("Local per-topic backlog buffer full for topic " "%(topic)s. Dropping message.") % {'topic': topic}) def consume_in_thread(self): - """Runs the ZmqProxy service""" + """Runs the ZmqProxy service.""" ipc_dir = CONF.rpc_zmq_ipc_dir consume_in = "tcp://%s:%s" % \ (CONF.rpc_zmq_bind_address, CONF.rpc_zmq_port) consumption_proxy = InternalContext(None) - if not os.path.isdir(ipc_dir): - try: - utils.execute('mkdir', '-p', ipc_dir, run_as_root=True) - utils.execute('chown', "%s:%s" % (os.getuid(), os.getgid()), - ipc_dir, run_as_root=True) - utils.execute('chmod', '750', ipc_dir, run_as_root=True) - except utils.ProcessExecutionError: + try: + os.makedirs(ipc_dir) + except os.error: + if not os.path.isdir(ipc_dir): with excutils.save_and_reraise_exception(): - LOG.error(_("Could not create IPC directory %s") % - (ipc_dir, )) - + LOG.error(_("Required IPC directory does not exist at" + " %s") % (ipc_dir, )) try: self.register(consumption_proxy, consume_in, - zmq.PULL, - out_bind=True) + zmq.PULL) except zmq.ZMQError: + if os.access(ipc_dir, os.X_OK): + with excutils.save_and_reraise_exception(): + LOG.error(_("Permission denied to IPC directory at" + " %s") % (ipc_dir, )) with excutils.save_and_reraise_exception(): LOG.error(_("Could not create ZeroMQ receiver daemon. " "Socket may already be in use.")) @@ -546,8 +514,9 @@ class ZmqProxy(ZmqBaseReactor): def unflatten_envelope(packenv): """Unflattens the RPC envelope. - Takes a list and returns a dictionary. - i.e. [1,2,3,4] => {1: 2, 3: 4} + + Takes a list and returns a dictionary. + i.e. [1,2,3,4] => {1: 2, 3: 4} """ i = iter(packenv) h = {} @@ -560,10 +529,9 @@ def unflatten_envelope(packenv): class ZmqReactor(ZmqBaseReactor): - """ - A consumer class implementing a - consumer for messages. Can also be - used as a 1:1 proxy + """A consumer class implementing a consumer for messages. + + Can also be used as a 1:1 proxy """ def __init__(self, conf): @@ -573,11 +541,6 @@ class ZmqReactor(ZmqBaseReactor): #TODO(ewindisch): use zero-copy (i.e. references, not copying) data = sock.recv() LOG.debug(_("CONSUMER RECEIVED DATA: %s"), data) - if sock in self.mapping: - LOG.debug(_("ROUTER RELAY-OUT %(data)s") % { - 'data': data}) - self.mapping[sock].send(data) - return proxy = self.proxies[sock] @@ -750,10 +713,9 @@ def _call(addr, context, topic, msg, timeout=None, def _multi_send(method, context, topic, msg, timeout=None, envelope=False, _msg_id=None): - """ - Wraps the sending of messages, - dispatches to the matchmaker and sends - message to all relevant hosts. + """Wraps the sending of messages. + + Dispatches to the matchmaker and sends message to all relevant hosts. """ conf = CONF LOG.debug(_("%(msg)s") % {'msg': ' '.join(map(pformat, (topic, msg)))}) @@ -762,7 +724,7 @@ def _multi_send(method, context, topic, msg, timeout=None, LOG.debug(_("Sending message(s) to: %s"), queues) # Don't stack if we have no matchmaker results - if len(queues) == 0: + if not queues: LOG.warn(_("No matchmaker results. Not casting.")) # While not strictly a timeout, callers know how to handle # this exception and a timeout isn't too big a lie. @@ -810,8 +772,8 @@ def fanout_cast(conf, context, topic, msg, **kwargs): def notify(conf, context, topic, msg, envelope): - """ - Send notification event. + """Send notification event. + Notifications are sent to topic-priority. This differs from the AMQP drivers which send to topic.priority. """ @@ -845,6 +807,11 @@ def _get_ctxt(): def _get_matchmaker(*args, **kwargs): global matchmaker if not matchmaker: - matchmaker = importutils.import_object( - CONF.rpc_zmq_matchmaker, *args, **kwargs) + mm = CONF.rpc_zmq_matchmaker + if mm.endswith('matchmaker.MatchMakerRing'): + mm.replace('matchmaker', 'matchmaker_ring') + LOG.warn(_('rpc_zmq_matchmaker = %(orig)s is deprecated; use' + ' %(new)s instead') % dict( + orig=CONF.rpc_zmq_matchmaker, new=mm)) + matchmaker = importutils.import_object(mm, *args, **kwargs) return matchmaker diff --git a/barbican/openstack/common/rpc/matchmaker.py b/barbican/openstack/common/rpc/matchmaker.py index 53ddac4c9..c2274c93e 100644 --- a/barbican/openstack/common/rpc/matchmaker.py +++ b/barbican/openstack/common/rpc/matchmaker.py @@ -19,21 +19,15 @@ return keys for direct exchanges, per (approximate) AMQP parlance. """ import contextlib -import itertools -import json import eventlet from oslo.config import cfg -from barbican.openstack.common.gettextutils import _ +from barbican.openstack.common.gettextutils import _ # noqa from barbican.openstack.common import log as logging matchmaker_opts = [ - # Matchmaker ring file - cfg.StrOpt('matchmaker_ringfile', - default='/etc/nova/matchmaker_ring.json', - help='Matchmaker ring file (JSON)'), cfg.IntOpt('matchmaker_heartbeat_freq', default=300, help='Heartbeat frequency'), @@ -54,8 +48,8 @@ class MatchMakerException(Exception): class Exchange(object): - """ - Implements lookups. + """Implements lookups. + Subclass this to support hashtables, dns, etc. """ def __init__(self): @@ -66,9 +60,7 @@ class Exchange(object): class Binding(object): - """ - A binding on which to perform a lookup. - """ + """A binding on which to perform a lookup.""" def __init__(self): pass @@ -77,10 +69,10 @@ class Binding(object): class MatchMakerBase(object): - """ - Match Maker Base Class. - Build off HeartbeatMatchMakerBase if building a - heartbeat-capable MatchMaker. + """Match Maker Base Class. + + Build off HeartbeatMatchMakerBase if building a heartbeat-capable + MatchMaker. """ def __init__(self): # Array of tuples. Index [2] toggles negation, [3] is last-if-true @@ -90,58 +82,47 @@ class MatchMakerBase(object): 'registration or heartbeat.') def register(self, key, host): - """ - Register a host on a backend. + """Register a host on a backend. + Heartbeats, if applicable, may keepalive registration. """ pass def ack_alive(self, key, host): - """ - Acknowledge that a key.host is alive. - Used internally for updating heartbeats, - but may also be used publically to acknowledge - a system is alive (i.e. rpc message successfully - sent to host) + """Acknowledge that a key.host is alive. + + Used internally for updating heartbeats, but may also be used + publically to acknowledge a system is alive (i.e. rpc message + successfully sent to host) """ pass def is_alive(self, topic, host): - """ - Checks if a host is alive. - """ + """Checks if a host is alive.""" pass def expire(self, topic, host): - """ - Explicitly expire a host's registration. - """ + """Explicitly expire a host's registration.""" pass def send_heartbeats(self): - """ - Send all heartbeats. + """Send all heartbeats. + Use start_heartbeat to spawn a heartbeat greenthread, which loops this method. """ pass def unregister(self, key, host): - """ - Unregister a topic. - """ + """Unregister a topic.""" pass def start_heartbeat(self): - """ - Spawn heartbeat greenthread. - """ + """Spawn heartbeat greenthread.""" pass def stop_heartbeat(self): - """ - Destroys the heartbeat greenthread. - """ + """Destroys the heartbeat greenthread.""" pass def add_binding(self, binding, rule, last=True): @@ -168,10 +149,10 @@ class MatchMakerBase(object): class HeartbeatMatchMakerBase(MatchMakerBase): - """ - Base for a heart-beat capable MatchMaker. - Provides common methods for registering, - unregistering, and maintaining heartbeats. + """Base for a heart-beat capable MatchMaker. + + Provides common methods for registering, unregistering, and maintaining + heartbeats. """ def __init__(self): self.hosts = set() @@ -181,8 +162,8 @@ class HeartbeatMatchMakerBase(MatchMakerBase): super(HeartbeatMatchMakerBase, self).__init__() def send_heartbeats(self): - """ - Send all heartbeats. + """Send all heartbeats. + Use start_heartbeat to spawn a heartbeat greenthread, which loops this method. """ @@ -190,32 +171,31 @@ class HeartbeatMatchMakerBase(MatchMakerBase): self.ack_alive(key, host) def ack_alive(self, key, host): - """ - Acknowledge that a host.topic is alive. - Used internally for updating heartbeats, - but may also be used publically to acknowledge - a system is alive (i.e. rpc message successfully - sent to host) + """Acknowledge that a host.topic is alive. + + Used internally for updating heartbeats, but may also be used + publically to acknowledge a system is alive (i.e. rpc message + successfully sent to host) """ raise NotImplementedError("Must implement ack_alive") def backend_register(self, key, host): - """ - Implements registration logic. + """Implements registration logic. + Called by register(self,key,host) """ raise NotImplementedError("Must implement backend_register") def backend_unregister(self, key, key_host): - """ - Implements de-registration logic. + """Implements de-registration logic. + Called by unregister(self,key,host) """ raise NotImplementedError("Must implement backend_unregister") def register(self, key, host): - """ - Register a host on a backend. + """Register a host on a backend. + Heartbeats, if applicable, may keepalive registration. """ self.hosts.add(host) @@ -227,25 +207,24 @@ class HeartbeatMatchMakerBase(MatchMakerBase): self.ack_alive(key, host) def unregister(self, key, host): - """ - Unregister a topic. - """ + """Unregister a topic.""" if (key, host) in self.host_topic: del self.host_topic[(key, host)] self.hosts.discard(host) self.backend_unregister(key, '.'.join((key, host))) - LOG.info(_("Matchmaker unregistered: %s, %s" % (key, host))) + LOG.info(_("Matchmaker unregistered: %(key)s, %(host)s"), + {'key': key, 'host': host}) def start_heartbeat(self): - """ - Implementation of MatchMakerBase.start_heartbeat + """Implementation of MatchMakerBase.start_heartbeat. + Launches greenthread looping send_heartbeats(), yielding for CONF.matchmaker_heartbeat_freq seconds between iterations. """ - if len(self.hosts) == 0: + if not self.hosts: raise MatchMakerException( _("Register before starting heartbeat.")) @@ -257,45 +236,37 @@ class HeartbeatMatchMakerBase(MatchMakerBase): self._heart = eventlet.spawn(do_heartbeat) def stop_heartbeat(self): - """ - Destroys the heartbeat greenthread. - """ + """Destroys the heartbeat greenthread.""" if self._heart: self._heart.kill() class DirectBinding(Binding): - """ - Specifies a host in the key via a '.' character + """Specifies a host in the key via a '.' character. + Although dots are used in the key, the behavior here is that it maps directly to a host, thus direct. """ def test(self, key): - if '.' in key: - return True - return False + return '.' in key class TopicBinding(Binding): - """ - Where a 'bare' key without dots. + """Where a 'bare' key without dots. + AMQP generally considers topic exchanges to be those *with* dots, but we deviate here in terminology as the behavior here matches that of a topic exchange (whereas where there are dots, behavior matches that of a direct exchange. """ def test(self, key): - if '.' not in key: - return True - return False + return '.' not in key class FanoutBinding(Binding): """Match on fanout keys, where key starts with 'fanout.' string.""" def test(self, key): - if key.startswith('fanout~'): - return True - return False + return key.startswith('fanout~') class StubExchange(Exchange): @@ -304,67 +275,6 @@ class StubExchange(Exchange): return [(key, None)] -class RingExchange(Exchange): - """ - Match Maker where hosts are loaded from a static file containing - a hashmap (JSON formatted). - - __init__ takes optional ring dictionary argument, otherwise - loads the ringfile from CONF.mathcmaker_ringfile. - """ - def __init__(self, ring=None): - super(RingExchange, self).__init__() - - if ring: - self.ring = ring - else: - fh = open(CONF.matchmaker_ringfile, 'r') - self.ring = json.load(fh) - fh.close() - - self.ring0 = {} - for k in self.ring.keys(): - self.ring0[k] = itertools.cycle(self.ring[k]) - - def _ring_has(self, key): - if key in self.ring0: - return True - return False - - -class RoundRobinRingExchange(RingExchange): - """A Topic Exchange based on a hashmap.""" - def __init__(self, ring=None): - super(RoundRobinRingExchange, self).__init__(ring) - - def run(self, key): - if not self._ring_has(key): - LOG.warn( - _("No key defining hosts for topic '%s', " - "see ringfile") % (key, ) - ) - return [] - host = next(self.ring0[key]) - return [(key + '.' + host, host)] - - -class FanoutRingExchange(RingExchange): - """Fanout Exchange based on a hashmap.""" - def __init__(self, ring=None): - super(FanoutRingExchange, self).__init__(ring) - - def run(self, key): - # Assume starts with "fanout~", strip it for lookup. - nkey = key.split('fanout~')[1:][0] - if not self._ring_has(nkey): - LOG.warn( - _("No key defining hosts for topic '%s', " - "see ringfile") % (nkey, ) - ) - return [] - return map(lambda x: (key + '.' + x, x), self.ring[nkey]) - - class LocalhostExchange(Exchange): """Exchange where all direct topics are local.""" def __init__(self, host='localhost'): @@ -376,8 +286,8 @@ class LocalhostExchange(Exchange): class DirectExchange(Exchange): - """ - Exchange where all topic keys are split, sending to second half. + """Exchange where all topic keys are split, sending to second half. + i.e. "compute.host" sends a message to "compute.host" running on "host" """ def __init__(self): @@ -388,20 +298,9 @@ class DirectExchange(Exchange): return [(key, e)] -class MatchMakerRing(MatchMakerBase): - """ - Match Maker where hosts are loaded from a static hashmap. - """ - def __init__(self, ring=None): - super(MatchMakerRing, self).__init__() - self.add_binding(FanoutBinding(), FanoutRingExchange(ring)) - self.add_binding(DirectBinding(), DirectExchange()) - self.add_binding(TopicBinding(), RoundRobinRingExchange(ring)) - - class MatchMakerLocalhost(MatchMakerBase): - """ - Match Maker where all bare topics resolve to localhost. + """Match Maker where all bare topics resolve to localhost. + Useful for testing. """ def __init__(self, host='localhost'): @@ -412,13 +311,13 @@ class MatchMakerLocalhost(MatchMakerBase): class MatchMakerStub(MatchMakerBase): - """ - Match Maker where topics are untouched. + """Match Maker where topics are untouched. + Useful for testing, or for AMQP/brokered queues. Will not work where knowledge of hosts is known (i.e. zeromq) """ def __init__(self): - super(MatchMakerLocalhost, self).__init__() + super(MatchMakerStub, self).__init__() self.add_binding(FanoutBinding(), StubExchange()) self.add_binding(DirectBinding(), StubExchange()) diff --git a/barbican/openstack/common/rpc/matchmaker_redis.py b/barbican/openstack/common/rpc/matchmaker_redis.py index 767407d81..f9e6796fb 100644 --- a/barbican/openstack/common/rpc/matchmaker_redis.py +++ b/barbican/openstack/common/rpc/matchmaker_redis.py @@ -55,8 +55,8 @@ class RedisExchange(mm_common.Exchange): class RedisTopicExchange(RedisExchange): - """ - Exchange where all topic keys are split, sending to second half. + """Exchange where all topic keys are split, sending to second half. + i.e. "compute.host" sends a message to "compute" running on "host" """ def run(self, topic): @@ -77,9 +77,7 @@ class RedisTopicExchange(RedisExchange): class RedisFanoutExchange(RedisExchange): - """ - Return a list of all hosts. - """ + """Return a list of all hosts.""" def run(self, topic): topic = topic.split('~', 1)[1] hosts = self.redis.smembers(topic) @@ -90,9 +88,7 @@ class RedisFanoutExchange(RedisExchange): class MatchMakerRedis(mm_common.HeartbeatMatchMakerBase): - """ - MatchMaker registering and looking-up hosts with a Redis server. - """ + """MatchMaker registering and looking-up hosts with a Redis server.""" def __init__(self): super(MatchMakerRedis, self).__init__() diff --git a/barbican/openstack/common/rpc/matchmaker_ring.py b/barbican/openstack/common/rpc/matchmaker_ring.py new file mode 100644 index 000000000..07dd25a16 --- /dev/null +++ b/barbican/openstack/common/rpc/matchmaker_ring.py @@ -0,0 +1,108 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011-2013 Cloudscaling Group, Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +""" +The MatchMaker classes should except a Topic or Fanout exchange key and +return keys for direct exchanges, per (approximate) AMQP parlance. +""" + +import itertools +import json + +from oslo.config import cfg + +from barbican.openstack.common.gettextutils import _ # noqa +from barbican.openstack.common import log as logging +from barbican.openstack.common.rpc import matchmaker as mm + + +matchmaker_opts = [ + # Matchmaker ring file + cfg.StrOpt('ringfile', + deprecated_name='matchmaker_ringfile', + deprecated_group='DEFAULT', + default='/etc/oslo/matchmaker_ring.json', + help='Matchmaker ring file (JSON)'), +] + +CONF = cfg.CONF +CONF.register_opts(matchmaker_opts, 'matchmaker_ring') +LOG = logging.getLogger(__name__) + + +class RingExchange(mm.Exchange): + """Match Maker where hosts are loaded from a static JSON formatted file. + + __init__ takes optional ring dictionary argument, otherwise + loads the ringfile from CONF.mathcmaker_ringfile. + """ + def __init__(self, ring=None): + super(RingExchange, self).__init__() + + if ring: + self.ring = ring + else: + fh = open(CONF.matchmaker_ring.ringfile, 'r') + self.ring = json.load(fh) + fh.close() + + self.ring0 = {} + for k in self.ring.keys(): + self.ring0[k] = itertools.cycle(self.ring[k]) + + def _ring_has(self, key): + return key in self.ring0 + + +class RoundRobinRingExchange(RingExchange): + """A Topic Exchange based on a hashmap.""" + def __init__(self, ring=None): + super(RoundRobinRingExchange, self).__init__(ring) + + def run(self, key): + if not self._ring_has(key): + LOG.warn( + _("No key defining hosts for topic '%s', " + "see ringfile") % (key, ) + ) + return [] + host = next(self.ring0[key]) + return [(key + '.' + host, host)] + + +class FanoutRingExchange(RingExchange): + """Fanout Exchange based on a hashmap.""" + def __init__(self, ring=None): + super(FanoutRingExchange, self).__init__(ring) + + def run(self, key): + # Assume starts with "fanout~", strip it for lookup. + nkey = key.split('fanout~')[1:][0] + if not self._ring_has(nkey): + LOG.warn( + _("No key defining hosts for topic '%s', " + "see ringfile") % (nkey, ) + ) + return [] + return map(lambda x: (key + '.' + x, x), self.ring[nkey]) + + +class MatchMakerRing(mm.MatchMakerBase): + """Match Maker where hosts are loaded from a static hashmap.""" + def __init__(self, ring=None): + super(MatchMakerRing, self).__init__() + self.add_binding(mm.FanoutBinding(), FanoutRingExchange(ring)) + self.add_binding(mm.DirectBinding(), mm.DirectExchange()) + self.add_binding(mm.TopicBinding(), RoundRobinRingExchange(ring)) diff --git a/barbican/openstack/common/rpc/proxy.py b/barbican/openstack/common/rpc/proxy.py index 0a8e70432..d4d036107 100644 --- a/barbican/openstack/common/rpc/proxy.py +++ b/barbican/openstack/common/rpc/proxy.py @@ -1,6 +1,6 @@ # vim: tabstop=4 shiftwidth=4 softtabstop=4 -# Copyright 2012 Red Hat, Inc. +# Copyright 2012-2013 Red Hat, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. You may obtain @@ -23,6 +23,8 @@ For more information about rpc API version numbers, see: from barbican.openstack.common import rpc +from barbican.openstack.common.rpc import common as rpc_common +from barbican.openstack.common.rpc import serializer as rpc_serializer class RpcProxy(object): @@ -34,16 +36,28 @@ class RpcProxy(object): rpc API. """ - def __init__(self, topic, default_version): + # The default namespace, which can be overriden in a subclass. + RPC_API_NAMESPACE = None + + def __init__(self, topic, default_version, version_cap=None, + serializer=None): """Initialize an RpcProxy. :param topic: The topic to use for all messages. :param default_version: The default API version to request in all outgoing messages. This can be overridden on a per-message basis. + :param version_cap: Optionally cap the maximum version used for sent + messages. + :param serializer: Optionaly (de-)serialize entities with a + provided helper. """ self.topic = topic self.default_version = default_version + self.version_cap = version_cap + if serializer is None: + serializer = rpc_serializer.NoOpSerializer() + self.serializer = serializer super(RpcProxy, self).__init__() def _set_version(self, msg, vers): @@ -52,15 +66,44 @@ class RpcProxy(object): :param msg: The message having a version added to it. :param vers: The version number to add to the message. """ - msg['version'] = vers if vers else self.default_version + v = vers if vers else self.default_version + if (self.version_cap and not + rpc_common.version_is_compatible(self.version_cap, v)): + raise rpc_common.RpcVersionCapError(version_cap=self.version_cap) + msg['version'] = v def _get_topic(self, topic): """Return the topic to use for a message.""" return topic if topic else self.topic + def can_send_version(self, version): + """Check to see if a version is compatible with the version cap.""" + return (not self.version_cap or + rpc_common.version_is_compatible(self.version_cap, version)) + @staticmethod - def make_msg(method, **kwargs): - return {'method': method, 'args': kwargs} + def make_namespaced_msg(method, namespace, **kwargs): + return {'method': method, 'namespace': namespace, 'args': kwargs} + + def make_msg(self, method, **kwargs): + return self.make_namespaced_msg(method, self.RPC_API_NAMESPACE, + **kwargs) + + def _serialize_msg_args(self, context, kwargs): + """Helper method called to serialize message arguments. + + This calls our serializer on each argument, returning a new + set of args that have been serialized. + + :param context: The request context + :param kwargs: The arguments to serialize + :returns: A new set of serialized arguments + """ + new_kwargs = dict() + for argname, arg in kwargs.iteritems(): + new_kwargs[argname] = self.serializer.serialize_entity(context, + arg) + return new_kwargs def call(self, context, msg, topic=None, version=None, timeout=None): """rpc.call() a remote method. @@ -77,9 +120,11 @@ class RpcProxy(object): :returns: The return value from the remote method. """ self._set_version(msg, version) + msg['args'] = self._serialize_msg_args(context, msg['args']) real_topic = self._get_topic(topic) try: - return rpc.call(context, real_topic, msg, timeout) + result = rpc.call(context, real_topic, msg, timeout) + return self.serializer.deserialize_entity(context, result) except rpc.common.Timeout as exc: raise rpc.common.Timeout( exc.info, real_topic, msg.get('method')) @@ -100,9 +145,11 @@ class RpcProxy(object): from the remote method as they arrive. """ self._set_version(msg, version) + msg['args'] = self._serialize_msg_args(context, msg['args']) real_topic = self._get_topic(topic) try: - return rpc.multicall(context, real_topic, msg, timeout) + result = rpc.multicall(context, real_topic, msg, timeout) + return self.serializer.deserialize_entity(context, result) except rpc.common.Timeout as exc: raise rpc.common.Timeout( exc.info, real_topic, msg.get('method')) @@ -120,6 +167,7 @@ class RpcProxy(object): remote method. """ self._set_version(msg, version) + msg['args'] = self._serialize_msg_args(context, msg['args']) rpc.cast(context, self._get_topic(topic), msg) def fanout_cast(self, context, msg, topic=None, version=None): @@ -135,6 +183,7 @@ class RpcProxy(object): from the remote method. """ self._set_version(msg, version) + msg['args'] = self._serialize_msg_args(context, msg['args']) rpc.fanout_cast(context, self._get_topic(topic), msg) def cast_to_server(self, context, server_params, msg, topic=None, @@ -153,6 +202,7 @@ class RpcProxy(object): return values. """ self._set_version(msg, version) + msg['args'] = self._serialize_msg_args(context, msg['args']) rpc.cast_to_server(context, server_params, self._get_topic(topic), msg) def fanout_cast_to_server(self, context, server_params, msg, topic=None, @@ -171,5 +221,6 @@ class RpcProxy(object): return values. """ self._set_version(msg, version) + msg['args'] = self._serialize_msg_args(context, msg['args']) rpc.fanout_cast_to_server(context, server_params, self._get_topic(topic), msg) diff --git a/barbican/openstack/common/rpc/securemessage.py b/barbican/openstack/common/rpc/securemessage.py new file mode 100644 index 000000000..7996d2c24 --- /dev/null +++ b/barbican/openstack/common/rpc/securemessage.py @@ -0,0 +1,521 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2013 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import base64 +import collections +import os +import struct +import time + +import requests + +from oslo.config import cfg + +from barbican.openstack.common.crypto import utils as cryptoutils +from barbican.openstack.common import jsonutils +from barbican.openstack.common import log as logging + +secure_message_opts = [ + cfg.BoolOpt('enabled', default=True, + help='Whether Secure Messaging (Signing) is enabled,' + ' defaults to enabled'), + cfg.BoolOpt('enforced', default=False, + help='Whether Secure Messaging (Signing) is enforced,' + ' defaults to not enforced'), + cfg.BoolOpt('encrypt', default=False, + help='Whether Secure Messaging (Encryption) is enabled,' + ' defaults to not enabled'), + cfg.StrOpt('secret_keys_file', + help='Path to the file containing the keys, takes precedence' + ' over secret_key'), + cfg.MultiStrOpt('secret_key', + help='A list of keys: (ex: name:),' + ' ignored if secret_keys_file is set'), + cfg.StrOpt('kds_endpoint', + help='KDS endpoint (ex: http://kds.example.com:35357/v3)'), +] +secure_message_group = cfg.OptGroup('secure_messages', + title='Secure Messaging options') + +LOG = logging.getLogger(__name__) + + +class SecureMessageException(Exception): + """Generic Exception for Secure Messages.""" + + msg = "An unknown Secure Message related exception occurred." + + def __init__(self, msg=None): + if msg is None: + msg = self.msg + super(SecureMessageException, self).__init__(msg) + + +class SharedKeyNotFound(SecureMessageException): + """No shared key was found and no other external authentication mechanism + is available. + """ + + msg = "Shared Key for [%s] Not Found. (%s)" + + def __init__(self, name, errmsg): + super(SharedKeyNotFound, self).__init__(self.msg % (name, errmsg)) + + +class InvalidMetadata(SecureMessageException): + """The metadata is invalid.""" + + msg = "Invalid metadata: %s" + + def __init__(self, err): + super(InvalidMetadata, self).__init__(self.msg % err) + + +class InvalidSignature(SecureMessageException): + """Signature validation failed.""" + + msg = "Failed to validate signature (source=%s, destination=%s)" + + def __init__(self, src, dst): + super(InvalidSignature, self).__init__(self.msg % (src, dst)) + + +class UnknownDestinationName(SecureMessageException): + """The Destination name is unknown to us.""" + + msg = "Invalid destination name (%s)" + + def __init__(self, name): + super(UnknownDestinationName, self).__init__(self.msg % name) + + +class InvalidEncryptedTicket(SecureMessageException): + """The Encrypted Ticket could not be successfully handled.""" + + msg = "Invalid Ticket (source=%s, destination=%s)" + + def __init__(self, src, dst): + super(InvalidEncryptedTicket, self).__init__(self.msg % (src, dst)) + + +class InvalidExpiredTicket(SecureMessageException): + """The ticket received is already expired.""" + + msg = "Expired ticket (source=%s, destination=%s)" + + def __init__(self, src, dst): + super(InvalidExpiredTicket, self).__init__(self.msg % (src, dst)) + + +class CommunicationError(SecureMessageException): + """The Communication with the KDS failed.""" + + msg = "Communication Error (target=%s): %s" + + def __init__(self, target, errmsg): + super(CommunicationError, self).__init__(self.msg % (target, errmsg)) + + +class InvalidArgument(SecureMessageException): + """Bad initialization argument.""" + + msg = "Invalid argument: %s" + + def __init__(self, errmsg): + super(InvalidArgument, self).__init__(self.msg % errmsg) + + +Ticket = collections.namedtuple('Ticket', ['skey', 'ekey', 'esek']) + + +class KeyStore(object): + """A storage class for Signing and Encryption Keys. + + This class creates an object that holds Generic Keys like Signing + Keys, Encryption Keys, Encrypted SEK Tickets ... + """ + + def __init__(self): + self._kvps = dict() + + def _get_key_name(self, source, target, ktype): + return (source, target, ktype) + + def _put(self, src, dst, ktype, expiration, data): + name = self._get_key_name(src, dst, ktype) + self._kvps[name] = (expiration, data) + + def _get(self, src, dst, ktype): + name = self._get_key_name(src, dst, ktype) + if name in self._kvps: + expiration, data = self._kvps[name] + if expiration > time.time(): + return data + else: + del self._kvps[name] + + return None + + def clear(self): + """Wipes the store clear of all data.""" + self._kvps.clear() + + def put_ticket(self, source, target, skey, ekey, esek, expiration): + """Puts a sek pair in the cache. + + :param source: Client name + :param target: Target name + :param skey: The Signing Key + :param ekey: The Encription Key + :param esek: The token encrypted with the target key + :param expiration: Expiration time in seconds since Epoch + """ + keys = Ticket(skey, ekey, esek) + self._put(source, target, 'ticket', expiration, keys) + + def get_ticket(self, source, target): + """Returns a Ticket (skey, ekey, esek) namedtuple for the + source/target pair. + """ + return self._get(source, target, 'ticket') + + +_KEY_STORE = KeyStore() + + +class _KDSClient(object): + + USER_AGENT = 'oslo-incubator/rpc' + + def __init__(self, endpoint=None, timeout=None): + """A KDS Client class.""" + + self._endpoint = endpoint + if timeout is not None: + self.timeout = float(timeout) + else: + self.timeout = None + + def _do_get(self, url, request): + req_kwargs = dict() + req_kwargs['headers'] = dict() + req_kwargs['headers']['User-Agent'] = self.USER_AGENT + req_kwargs['headers']['Content-Type'] = 'application/json' + req_kwargs['data'] = jsonutils.dumps({'request': request}) + if self.timeout is not None: + req_kwargs['timeout'] = self.timeout + + try: + resp = requests.get(url, **req_kwargs) + except requests.ConnectionError as e: + err = "Unable to establish connection. %s" % e + raise CommunicationError(url, err) + + return resp + + def _get_reply(self, url, resp): + if resp.text: + try: + body = jsonutils.loads(resp.text) + reply = body['reply'] + except (KeyError, TypeError, ValueError): + msg = "Failed to decode reply: %s" % resp.text + raise CommunicationError(url, msg) + else: + msg = "No reply data was returned." + raise CommunicationError(url, msg) + + return reply + + def _get_ticket(self, request, url=None, redirects=10): + """Send an HTTP request. + + Wraps around 'requests' to handle redirects and common errors. + """ + if url is None: + if not self._endpoint: + raise CommunicationError(url, 'Endpoint not configured') + url = self._endpoint + '/kds/ticket' + + while redirects: + resp = self._do_get(url, request) + if resp.status_code in (301, 302, 305): + # Redirected. Reissue the request to the new location. + url = resp.headers['location'] + redirects -= 1 + continue + elif resp.status_code != 200: + msg = "Request returned failure status: %s (%s)" + err = msg % (resp.status_code, resp.text) + raise CommunicationError(url, err) + + return self._get_reply(url, resp) + + raise CommunicationError(url, "Too many redirections, giving up!") + + def get_ticket(self, source, target, crypto, key): + + # prepare metadata + md = {'requestor': source, + 'target': target, + 'timestamp': time.time(), + 'nonce': struct.unpack('Q', os.urandom(8))[0]} + metadata = base64.b64encode(jsonutils.dumps(md)) + + # sign metadata + signature = crypto.sign(key, metadata) + + # HTTP request + reply = self._get_ticket({'metadata': metadata, + 'signature': signature}) + + # verify reply + signature = crypto.sign(key, (reply['metadata'] + reply['ticket'])) + if signature != reply['signature']: + raise InvalidEncryptedTicket(md['source'], md['destination']) + md = jsonutils.loads(base64.b64decode(reply['metadata'])) + if ((md['source'] != source or + md['destination'] != target or + md['expiration'] < time.time())): + raise InvalidEncryptedTicket(md['source'], md['destination']) + + # return ticket data + tkt = jsonutils.loads(crypto.decrypt(key, reply['ticket'])) + + return tkt, md['expiration'] + + +# we need to keep a global nonce, as this value should never repeat non +# matter how many SecureMessage objects we create +_NONCE = None + + +def _get_nonce(): + """We keep a single counter per instance, as it is so huge we can't + possibly cycle through within 1/100 of a second anyway. + """ + + global _NONCE + # Lazy initialize, for now get a random value, multiply by 2^32 and + # use it as the nonce base. The counter itself will rotate after + # 2^32 increments. + if _NONCE is None: + _NONCE = [struct.unpack('I', os.urandom(4))[0], 0] + + # Increment counter and wrap at 2^32 + _NONCE[1] += 1 + if _NONCE[1] > 0xffffffff: + _NONCE[1] = 0 + + # Return base + counter + return long((_NONCE[0] * 0xffffffff)) + _NONCE[1] + + +class SecureMessage(object): + """A Secure Message object. + + This class creates a signing/encryption facility for RPC messages. + It encapsulates all the necessary crypto primitives to insulate + regular code from the intricacies of message authentication, validation + and optionally encryption. + + :param topic: The topic name of the queue + :param host: The server name, together with the topic it forms a unique + name that is used to source signing keys, and verify + incoming messages. + :param conf: a ConfigOpts object + :param key: (optional) explicitly pass in endpoint private key. + If not provided it will be sourced from the service config + :param key_store: (optional) Storage class for local caching + :param encrypt: (defaults to False) Whether to encrypt messages + :param enctype: (defaults to AES) Cipher to use + :param hashtype: (defaults to SHA256) Hash function to use for signatures + """ + + def __init__(self, topic, host, conf, key=None, key_store=None, + encrypt=None, enctype='AES', hashtype='SHA256'): + + conf.register_group(secure_message_group) + conf.register_opts(secure_message_opts, group='secure_messages') + + self._name = '%s.%s' % (topic, host) + self._key = key + self._conf = conf.secure_messages + self._encrypt = self._conf.encrypt if (encrypt is None) else encrypt + self._crypto = cryptoutils.SymmetricCrypto(enctype, hashtype) + self._hkdf = cryptoutils.HKDF(hashtype) + self._kds = _KDSClient(self._conf.kds_endpoint) + + if self._key is None: + self._key = self._init_key(topic, self._name) + if self._key is None: + err = "Secret Key (or key file) is missing or malformed" + raise SharedKeyNotFound(self._name, err) + + self._key_store = key_store or _KEY_STORE + + def _init_key(self, topic, name): + keys = None + if self._conf.secret_keys_file: + with open(self._conf.secret_keys_file, 'r') as f: + keys = f.readlines() + elif self._conf.secret_key: + keys = self._conf.secret_key + + if keys is None: + return None + + for k in keys: + if k[0] == '#': + continue + if ':' not in k: + break + svc, key = k.split(':', 1) + if svc == topic or svc == name: + return base64.b64decode(key) + + return None + + def _split_key(self, key, size): + sig_key = key[:size] + enc_key = key[size:] + return sig_key, enc_key + + def _decode_esek(self, key, source, target, timestamp, esek): + """This function decrypts the esek buffer passed in and returns a + KeyStore to be used to check and decrypt the received message. + + :param key: The key to use to decrypt the ticket (esek) + :param source: The name of the source service + :param traget: The name of the target service + :param timestamp: The incoming message timestamp + :param esek: a base64 encoded encrypted block containing a JSON string + """ + rkey = None + + try: + s = self._crypto.decrypt(key, esek) + j = jsonutils.loads(s) + + rkey = base64.b64decode(j['key']) + expiration = j['timestamp'] + j['ttl'] + if j['timestamp'] > timestamp or timestamp > expiration: + raise InvalidExpiredTicket(source, target) + + except Exception: + raise InvalidEncryptedTicket(source, target) + + info = '%s,%s,%s' % (source, target, str(j['timestamp'])) + + sek = self._hkdf.expand(rkey, info, len(key) * 2) + + return self._split_key(sek, len(key)) + + def _get_ticket(self, target): + """This function will check if we already have a SEK for the specified + target in the cache, or will go and try to fetch a new SEK from the key + server. + + :param target: The name of the target service + """ + ticket = self._key_store.get_ticket(self._name, target) + + if ticket is not None: + return ticket + + tkt, expiration = self._kds.get_ticket(self._name, target, + self._crypto, self._key) + + self._key_store.put_ticket(self._name, target, + base64.b64decode(tkt['skey']), + base64.b64decode(tkt['ekey']), + tkt['esek'], expiration) + return self._key_store.get_ticket(self._name, target) + + def encode(self, version, target, json_msg): + """This is the main encoding function. + + It takes a target and a message and returns a tuple consisting of a + JSON serialized metadata object, a JSON serialized (and optionally + encrypted) message, and a signature. + + :param version: the current envelope version + :param target: The name of the target service (usually with hostname) + :param json_msg: a serialized json message object + """ + ticket = self._get_ticket(target) + + metadata = jsonutils.dumps({'source': self._name, + 'destination': target, + 'timestamp': time.time(), + 'nonce': _get_nonce(), + 'esek': ticket.esek, + 'encryption': self._encrypt}) + + message = json_msg + if self._encrypt: + message = self._crypto.encrypt(ticket.ekey, message) + + signature = self._crypto.sign(ticket.skey, + version + metadata + message) + + return (metadata, message, signature) + + def decode(self, version, metadata, message, signature): + """This is the main decoding function. + + It takes a version, metadata, message and signature strings and + returns a tuple with a (decrypted) message and metadata or raises + an exception in case of error. + + :param version: the current envelope version + :param metadata: a JSON serialized object with metadata for validation + :param message: a JSON serialized (base64 encoded encrypted) message + :param signature: a base64 encoded signature + """ + md = jsonutils.loads(metadata) + + check_args = ('source', 'destination', 'timestamp', + 'nonce', 'esek', 'encryption') + for arg in check_args: + if arg not in md: + raise InvalidMetadata('Missing metadata "%s"' % arg) + + if md['destination'] != self._name: + # TODO(simo) handle group keys by checking target + raise UnknownDestinationName(md['destination']) + + try: + skey, ekey = self._decode_esek(self._key, + md['source'], md['destination'], + md['timestamp'], md['esek']) + except InvalidExpiredTicket: + raise + except Exception: + raise InvalidMetadata('Failed to decode ESEK for %s/%s' % ( + md['source'], md['destination'])) + + sig = self._crypto.sign(skey, version + metadata + message) + + if sig != signature: + raise InvalidSignature(md['source'], md['destination']) + + if md['encryption'] is True: + msg = self._crypto.decrypt(ekey, message) + else: + msg = message + + return (md, msg) diff --git a/barbican/openstack/common/rpc/serializer.py b/barbican/openstack/common/rpc/serializer.py new file mode 100644 index 000000000..76c683103 --- /dev/null +++ b/barbican/openstack/common/rpc/serializer.py @@ -0,0 +1,52 @@ +# Copyright 2013 IBM Corp. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Provides the definition of an RPC serialization handler""" + +import abc + + +class Serializer(object): + """Generic (de-)serialization definition base class.""" + __metaclass__ = abc.ABCMeta + + @abc.abstractmethod + def serialize_entity(self, context, entity): + """Serialize something to primitive form. + + :param context: Security context + :param entity: Entity to be serialized + :returns: Serialized form of entity + """ + pass + + @abc.abstractmethod + def deserialize_entity(self, context, entity): + """Deserialize something from primitive form. + + :param context: Security context + :param entity: Primitive to be deserialized + :returns: Deserialized form of entity + """ + pass + + +class NoOpSerializer(Serializer): + """A serializer that does nothing.""" + + def serialize_entity(self, context, entity): + return entity + + def deserialize_entity(self, context, entity): + return entity diff --git a/barbican/openstack/common/rpc/service.py b/barbican/openstack/common/rpc/service.py index 7c239d444..0cec9e9cc 100644 --- a/barbican/openstack/common/rpc/service.py +++ b/barbican/openstack/common/rpc/service.py @@ -17,7 +17,7 @@ # License for the specific language governing permissions and limitations # under the License. -from barbican.openstack.common.gettextutils import _ +from barbican.openstack.common.gettextutils import _ # noqa from barbican.openstack.common import log as logging from barbican.openstack.common import rpc from barbican.openstack.common.rpc import dispatcher as rpc_dispatcher @@ -30,11 +30,13 @@ LOG = logging.getLogger(__name__) class Service(service.Service): """Service object for binaries running on hosts. - A service enables rpc by listening to queues based on topic and host.""" - def __init__(self, host, topic, manager=None): + A service enables rpc by listening to queues based on topic and host. + """ + def __init__(self, host, topic, manager=None, serializer=None): super(Service, self).__init__() self.host = host self.topic = topic + self.serializer = serializer if manager is None: self.manager = self else: @@ -47,7 +49,8 @@ class Service(service.Service): LOG.debug(_("Creating Consumer connection for Service %s") % self.topic) - dispatcher = rpc_dispatcher.RpcDispatcher([self.manager]) + dispatcher = rpc_dispatcher.RpcDispatcher([self.manager], + self.serializer) # Share this same connection for these Consumers self.conn.create_consumer(self.topic, dispatcher, fanout=False) diff --git a/barbican/openstack/common/service.py b/barbican/openstack/common/service.py index 57d8e2a30..948f52ad5 100644 --- a/barbican/openstack/common/service.py +++ b/barbican/openstack/common/service.py @@ -27,11 +27,12 @@ import sys import time import eventlet +from eventlet import event import logging as std_logging from oslo.config import cfg from barbican.openstack.common import eventlet_backdoor -from barbican.openstack.common.gettextutils import _ +from barbican.openstack.common.gettextutils import _ # noqa from barbican.openstack.common import importutils from barbican.openstack.common import log as logging from barbican.openstack.common import threadgroup @@ -51,19 +52,8 @@ class Launcher(object): :returns: None """ - self._services = threadgroup.ThreadGroup() - eventlet_backdoor.initialize_if_enabled() - - @staticmethod - def run_service(service): - """Start and wait for a service to finish. - - :param service: service to run and wait for. - :returns: None - - """ - service.start() - service.wait() + self.services = Services() + self.backdoor_port = eventlet_backdoor.initialize_if_enabled() def launch_service(self, service): """Load and start the given service. @@ -72,7 +62,8 @@ class Launcher(object): :returns: None """ - self._services.add_thread(self.run_service, service) + service.backdoor_port = self.backdoor_port + self.services.add(service) def stop(self): """Stop all services which are currently running. @@ -80,7 +71,7 @@ class Launcher(object): :returns: None """ - self._services.stop() + self.services.stop() def wait(self): """Waits until all services have been stopped, and then returns. @@ -88,7 +79,16 @@ class Launcher(object): :returns: None """ - self._services.wait() + self.services.wait() + + def restart(self): + """Reload config files and restart service. + + :returns: None + + """ + cfg.CONF.reload_config_files() + self.services.restart() class SignalExit(SystemExit): @@ -102,31 +102,51 @@ class ServiceLauncher(Launcher): # Allow the process to be killed again and die from natural causes signal.signal(signal.SIGTERM, signal.SIG_DFL) signal.signal(signal.SIGINT, signal.SIG_DFL) + signal.signal(signal.SIGHUP, signal.SIG_DFL) raise SignalExit(signo) - def wait(self): + def handle_signal(self): signal.signal(signal.SIGTERM, self._handle_signal) signal.signal(signal.SIGINT, self._handle_signal) + signal.signal(signal.SIGHUP, self._handle_signal) + + def _wait_for_exit_or_signal(self): + status = None + signo = 0 LOG.debug(_('Full set of CONF:')) CONF.log_opt_values(LOG, std_logging.DEBUG) - status = None try: super(ServiceLauncher, self).wait() except SignalExit as exc: signame = {signal.SIGTERM: 'SIGTERM', - signal.SIGINT: 'SIGINT'}[exc.signo] + signal.SIGINT: 'SIGINT', + signal.SIGHUP: 'SIGHUP'}[exc.signo] LOG.info(_('Caught %s, exiting'), signame) status = exc.code + signo = exc.signo except SystemExit as exc: status = exc.code finally: - if rpc: - rpc.cleanup() self.stop() - return status + if rpc: + try: + rpc.cleanup() + except Exception: + # We're shutting down, so it doesn't matter at this point. + LOG.exception(_('Exception during rpc cleanup.')) + + return status, signo + + def wait(self): + while True: + self.handle_signal() + status, signo = self._wait_for_exit_or_signal() + if signo != signal.SIGHUP: + return status + self.restart() class ServiceWrapper(object): @@ -144,9 +164,12 @@ class ProcessLauncher(object): self.running = True rfd, self.writepipe = os.pipe() self.readpipe = eventlet.greenio.GreenPipe(rfd, 'r') + self.handle_signal() + def handle_signal(self): signal.signal(signal.SIGTERM, self._handle_signal) signal.signal(signal.SIGINT, self._handle_signal) + signal.signal(signal.SIGHUP, self._handle_signal) def _handle_signal(self, signo, frame): self.sigcaught = signo @@ -155,6 +178,7 @@ class ProcessLauncher(object): # Allow the process to be killed again and die from natural causes signal.signal(signal.SIGTERM, signal.SIG_DFL) signal.signal(signal.SIGINT, signal.SIG_DFL) + signal.signal(signal.SIGHUP, signal.SIG_DFL) def _pipe_watcher(self): # This will block until the write end is closed when the parent @@ -165,16 +189,47 @@ class ProcessLauncher(object): sys.exit(1) - def _child_process(self, service): + def _child_process_handle_signal(self): # Setup child signal handlers differently def _sigterm(*args): signal.signal(signal.SIGTERM, signal.SIG_DFL) raise SignalExit(signal.SIGTERM) + def _sighup(*args): + signal.signal(signal.SIGHUP, signal.SIG_DFL) + raise SignalExit(signal.SIGHUP) + signal.signal(signal.SIGTERM, _sigterm) + signal.signal(signal.SIGHUP, _sighup) # Block SIGINT and let the parent send us a SIGTERM signal.signal(signal.SIGINT, signal.SIG_IGN) + def _child_wait_for_exit_or_signal(self, launcher): + status = None + signo = 0 + + try: + launcher.wait() + except SignalExit as exc: + signame = {signal.SIGTERM: 'SIGTERM', + signal.SIGINT: 'SIGINT', + signal.SIGHUP: 'SIGHUP'}[exc.signo] + LOG.info(_('Caught %s, exiting'), signame) + status = exc.code + signo = exc.signo + except SystemExit as exc: + status = exc.code + except BaseException: + LOG.exception(_('Unhandled exception')) + status = 2 + finally: + launcher.stop() + + return status, signo + + def _child_process(self, service): + self._child_process_handle_signal() + # Reopen the eventlet hub to make sure we don't share an epoll # fd with parent and/or siblings, which would be bad eventlet.hubs.use_hub() @@ -188,7 +243,8 @@ class ProcessLauncher(object): random.seed() launcher = Launcher() - launcher.run_service(service) + launcher.launch_service(service) + return launcher def _start_child(self, wrap): if len(wrap.forktimes) > wrap.workers: @@ -209,21 +265,13 @@ class ProcessLauncher(object): # NOTE(johannes): All exceptions are caught to ensure this # doesn't fallback into the loop spawning children. It would # be bad for a child to spawn more children. - status = 0 - try: - self._child_process(wrap.service) - except SignalExit as exc: - signame = {signal.SIGTERM: 'SIGTERM', - signal.SIGINT: 'SIGINT'}[exc.signo] - LOG.info(_('Caught %s, exiting'), signame) - status = exc.code - except SystemExit as exc: - status = exc.code - except BaseException: - LOG.exception(_('Unhandled exception')) - status = 2 - finally: - wrap.service.stop() + launcher = self._child_process(wrap.service) + while True: + self._child_process_handle_signal() + status, signo = self._child_wait_for_exit_or_signal(launcher) + if signo != signal.SIGHUP: + break + launcher.restart() os._exit(status) @@ -269,12 +317,7 @@ class ProcessLauncher(object): wrap.children.remove(pid) return wrap - def wait(self): - """Loop waiting on children to die and respawning as necessary""" - - LOG.debug(_('Full set of CONF:')) - CONF.log_opt_values(LOG, std_logging.DEBUG) - + def _respawn_children(self): while self.running: wrap = self._wait_child() if not wrap: @@ -283,14 +326,30 @@ class ProcessLauncher(object): # (see bug #1095346) eventlet.greenthread.sleep(.01) continue - while self.running and len(wrap.children) < wrap.workers: self._start_child(wrap) - if self.sigcaught: - signame = {signal.SIGTERM: 'SIGTERM', - signal.SIGINT: 'SIGINT'}[self.sigcaught] - LOG.info(_('Caught %s, stopping children'), signame) + def wait(self): + """Loop waiting on children to die and respawning as necessary.""" + + LOG.debug(_('Full set of CONF:')) + CONF.log_opt_values(LOG, std_logging.DEBUG) + + while True: + self.handle_signal() + self._respawn_children() + if self.sigcaught: + signame = {signal.SIGTERM: 'SIGTERM', + signal.SIGINT: 'SIGINT', + signal.SIGHUP: 'SIGHUP'}[self.sigcaught] + LOG.info(_('Caught %s, stopping children'), signame) + if self.sigcaught != signal.SIGHUP: + break + + for pid in self.children: + os.kill(pid, signal.SIGHUP) + self.running = True + self.sigcaught = None for pid in self.children: try: @@ -312,15 +371,74 @@ class Service(object): def __init__(self, threads=1000): self.tg = threadgroup.ThreadGroup(threads) + # signal that the service is done shutting itself down: + self._done = event.Event() + + def reset(self): + # NOTE(Fengqian): docs for Event.reset() recommend against using it + self._done = event.Event() + def start(self): pass def stop(self): self.tg.stop() + self.tg.wait() + # Signal that service cleanup is done: + if not self._done.ready(): + self._done.send() + + def wait(self): + self._done.wait() + + +class Services(object): + + def __init__(self): + self.services = [] + self.tg = threadgroup.ThreadGroup() + self.done = event.Event() + + def add(self, service): + self.services.append(service) + self.tg.add_thread(self.run_service, service, self.done) + + def stop(self): + # wait for graceful shutdown of services: + for service in self.services: + service.stop() + service.wait() + + # Each service has performed cleanup, now signal that the run_service + # wrapper threads can now die: + if not self.done.ready(): + self.done.send() + + # reap threads: + self.tg.stop() def wait(self): self.tg.wait() + def restart(self): + self.stop() + self.done = event.Event() + for restart_service in self.services: + restart_service.reset() + self.tg.add_thread(self.run_service, restart_service, self.done) + + @staticmethod + def run_service(service, done): + """Service start wrapper. + + :param service: service to run + :param done: event to wait on until a shutdown is triggered + :returns: None + + """ + service.start() + done.wait() + def launch(service, workers=None): if workers: diff --git a/barbican/openstack/common/sslutils.py b/barbican/openstack/common/sslutils.py new file mode 100644 index 000000000..1d03532b3 --- /dev/null +++ b/barbican/openstack/common/sslutils.py @@ -0,0 +1,100 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2013 IBM Corp. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import os +import ssl + +from oslo.config import cfg + +from barbican.openstack.common.gettextutils import _ # noqa + + +ssl_opts = [ + cfg.StrOpt('ca_file', + default=None, + help="CA certificate file to use to verify " + "connecting clients"), + cfg.StrOpt('cert_file', + default=None, + help="Certificate file to use when starting " + "the server securely"), + cfg.StrOpt('key_file', + default=None, + help="Private key file to use when starting " + "the server securely"), +] + + +CONF = cfg.CONF +CONF.register_opts(ssl_opts, "ssl") + + +def is_enabled(): + cert_file = CONF.ssl.cert_file + key_file = CONF.ssl.key_file + ca_file = CONF.ssl.ca_file + use_ssl = cert_file or key_file + + if cert_file and not os.path.exists(cert_file): + raise RuntimeError(_("Unable to find cert_file : %s") % cert_file) + + if ca_file and not os.path.exists(ca_file): + raise RuntimeError(_("Unable to find ca_file : %s") % ca_file) + + if key_file and not os.path.exists(key_file): + raise RuntimeError(_("Unable to find key_file : %s") % key_file) + + if use_ssl and (not cert_file or not key_file): + raise RuntimeError(_("When running server in SSL mode, you must " + "specify both a cert_file and key_file " + "option value in your configuration file")) + + return use_ssl + + +def wrap(sock): + ssl_kwargs = { + 'server_side': True, + 'certfile': CONF.ssl.cert_file, + 'keyfile': CONF.ssl.key_file, + 'cert_reqs': ssl.CERT_NONE, + } + + if CONF.ssl.ca_file: + ssl_kwargs['ca_certs'] = CONF.ssl.ca_file + ssl_kwargs['cert_reqs'] = ssl.CERT_REQUIRED + + return ssl.wrap_socket(sock, **ssl_kwargs) + + +_SSL_PROTOCOLS = { + "tlsv1": ssl.PROTOCOL_TLSv1, + "sslv23": ssl.PROTOCOL_SSLv23, + "sslv3": ssl.PROTOCOL_SSLv3 +} + +try: + _SSL_PROTOCOLS["sslv2"] = ssl.PROTOCOL_SSLv2 +except AttributeError: + pass + + +def validate_ssl_version(version): + key = version.lower() + try: + return _SSL_PROTOCOLS[key] + except KeyError: + raise RuntimeError(_("Invalid SSL version : %s") % version) diff --git a/barbican/openstack/common/threadgroup.py b/barbican/openstack/common/threadgroup.py index 1e5295a7c..9f25728c5 100644 --- a/barbican/openstack/common/threadgroup.py +++ b/barbican/openstack/common/threadgroup.py @@ -14,7 +14,7 @@ # License for the specific language governing permissions and limitations # under the License. -from eventlet import greenlet +import eventlet from eventlet import greenpool from eventlet import greenthread @@ -26,7 +26,7 @@ LOG = logging.getLogger(__name__) def _thread_done(gt, *args, **kwargs): - """ Callback function to be passed to GreenThread.link() when we spawn() + """Callback function to be passed to GreenThread.link() when we spawn() Calls the :class:`ThreadGroup` to notify if. """ @@ -34,7 +34,7 @@ def _thread_done(gt, *args, **kwargs): class Thread(object): - """ Wrapper around a greenthread, that holds a reference to the + """Wrapper around a greenthread, that holds a reference to the :class:`ThreadGroup`. The Thread will notify the :class:`ThreadGroup` when it has done so it can be removed from the threads list. """ @@ -50,7 +50,7 @@ class Thread(object): class ThreadGroup(object): - """ The point of the ThreadGroup classis to: + """The point of the ThreadGroup classis to: * keep track of timers and greenthreads (making it easier to stop them when need be). @@ -61,6 +61,13 @@ class ThreadGroup(object): self.threads = [] self.timers = [] + def add_dynamic_timer(self, callback, initial_delay=None, + periodic_interval_max=None, *args, **kwargs): + timer = loopingcall.DynamicLoopingCall(callback, *args, **kwargs) + timer.start(initial_delay=initial_delay, + periodic_interval_max=periodic_interval_max) + self.timers.append(timer) + def add_timer(self, interval, callback, initial_delay=None, *args, **kwargs): pulse = loopingcall.FixedIntervalLoopingCall(callback, *args, **kwargs) @@ -98,7 +105,7 @@ class ThreadGroup(object): for x in self.timers: try: x.wait() - except greenlet.GreenletExit: + except eventlet.greenlet.GreenletExit: pass except Exception as ex: LOG.exception(ex) @@ -108,7 +115,7 @@ class ThreadGroup(object): continue try: x.wait() - except greenlet.GreenletExit: + except eventlet.greenlet.GreenletExit: pass except Exception as ex: LOG.exception(ex) diff --git a/barbican/openstack/common/timeutils.py b/barbican/openstack/common/timeutils.py index 609436590..aa9f70807 100644 --- a/barbican/openstack/common/timeutils.py +++ b/barbican/openstack/common/timeutils.py @@ -23,6 +23,7 @@ import calendar import datetime import iso8601 +import six # ISO 8601 extended time format with microseconds @@ -32,7 +33,7 @@ PERFECT_TIME_FORMAT = _ISO8601_TIME_FORMAT_SUBSECOND def isotime(at=None, subsecond=False): - """Stringify time in ISO 8601 format""" + """Stringify time in ISO 8601 format.""" if not at: at = utcnow() st = at.strftime(_ISO8601_TIME_FORMAT @@ -44,13 +45,13 @@ def isotime(at=None, subsecond=False): def parse_isotime(timestr): - """Parse time from ISO 8601 format""" + """Parse time from ISO 8601 format.""" try: return iso8601.parse_date(timestr) except iso8601.ParseError as e: - raise ValueError(e.message) + raise ValueError(unicode(e)) except TypeError as e: - raise ValueError(e.message) + raise ValueError(unicode(e)) def strtime(at=None, fmt=PERFECT_TIME_FORMAT): @@ -66,7 +67,7 @@ def parse_strtime(timestr, fmt=PERFECT_TIME_FORMAT): def normalize_time(timestamp): - """Normalize time in arbitrary timezone to UTC naive object""" + """Normalize time in arbitrary timezone to UTC naive object.""" offset = timestamp.utcoffset() if offset is None: return timestamp @@ -75,14 +76,14 @@ def normalize_time(timestamp): def is_older_than(before, seconds): """Return True if before is older than seconds.""" - if isinstance(before, basestring): + if isinstance(before, six.string_types): before = parse_strtime(before).replace(tzinfo=None) return utcnow() - before > datetime.timedelta(seconds=seconds) def is_newer_than(after, seconds): """Return True if after is newer than seconds.""" - if isinstance(after, basestring): + if isinstance(after, six.string_types): after = parse_strtime(after).replace(tzinfo=None) return after - utcnow() > datetime.timedelta(seconds=seconds) @@ -103,7 +104,7 @@ def utcnow(): def iso8601_from_timestamp(timestamp): - """Returns a iso8601 formated date from timestamp""" + """Returns a iso8601 formated date from timestamp.""" return isotime(datetime.datetime.utcfromtimestamp(timestamp)) @@ -111,9 +112,9 @@ utcnow.override_time = None def set_time_override(override_time=datetime.datetime.utcnow()): - """ - Override utils.utcnow to return a constant time or a list thereof, - one at a time. + """Overrides utils.utcnow. + + Make it return a constant time or a list thereof, one at a time. """ utcnow.override_time = override_time @@ -141,7 +142,8 @@ def clear_time_override(): def marshall_now(now=None): """Make an rpc-safe datetime with microseconds. - Note: tzinfo is stripped, but not required for relative times.""" + Note: tzinfo is stripped, but not required for relative times. + """ if not now: now = utcnow() return dict(day=now.day, month=now.month, year=now.year, hour=now.hour, @@ -161,7 +163,8 @@ def unmarshall_time(tyme): def delta_seconds(before, after): - """ + """Return the difference between two timing objects. + Compute the difference in seconds between two date, time, or datetime objects (as a float, to microsecond resolution). """ @@ -174,8 +177,7 @@ def delta_seconds(before, after): def is_soon(dt, window): - """ - Determines if time is going to happen in the next window seconds. + """Determines if time is going to happen in the next window seconds. :params dt: the time :params window: minimum seconds to remain to consider the time not soon diff --git a/barbican/tests/api/resources_test.py b/barbican/tests/api/resources_test.py index 22dc425d0..119fdb404 100644 --- a/barbican/tests/api/resources_test.py +++ b/barbican/tests/api/resources_test.py @@ -13,6 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +This test module focuses on typical-flow business logic tests with the API +resource classes. For RBAC tests of these classes, see the +'resources_policy_test.py' module. +""" + import base64 import json import unittest @@ -85,12 +91,9 @@ def validate_datum(test, datum): class WhenTestingVersionResource(unittest.TestCase): def setUp(self): - self.policy = MagicMock() - self.req = MagicMock() self.resp = MagicMock() - self.policy = MagicMock() - self.resource = res.VersionResource(self.policy) + self.resource = res.VersionResource() def test_should_return_200_on_get(self): self.resource.on_get(self.req, self.resp) @@ -146,8 +149,6 @@ class WhenCreatingSecretsUsingSecretsResource(unittest.TestCase): self.kek_repo = MagicMock() self.kek_repo.find_or_create_kek_metadata.return_value = self.kek_datum - self.policy = MagicMock() - self.stream = MagicMock() self.stream.read.return_value = self.json @@ -165,8 +166,7 @@ class WhenCreatingSecretsUsingSecretsResource(unittest.TestCase): self.secret_repo, self.tenant_secret_repo, self.datum_repo, - self.kek_repo, - self.policy) + self.kek_repo) def test_should_add_new_secret(self): self.resource.on_post(self.req, self.resp, self.keystone_id) @@ -414,9 +414,6 @@ class WhenGettingSecretsListUsingSecretsResource(unittest.TestCase): self.kek_repo = MagicMock() - self.policy = MagicMock() - self.policy.read.return_value = None - self.conf = MagicMock() self.conf.crypto.namespace = 'barbican.test.crypto.plugin' self.conf.crypto.enabled_crypto_plugins = ['test_crypto'] @@ -431,8 +428,7 @@ class WhenGettingSecretsListUsingSecretsResource(unittest.TestCase): self.secret_repo, self.tenant_secret_repo, self.datum_repo, - self.kek_repo, - self.policy) + self.kek_repo) def test_should_get_list_secrets(self): self.resource.on_get(self.req, self.resp, self.keystone_id) @@ -541,8 +537,6 @@ class WhenGettingPuttingOrDeletingSecretUsingSecretResource(unittest.TestCase): self.kek_repo = MagicMock() - self.policy = MagicMock() - self.req = MagicMock() self.req.accept = 'application/json' self.resp = MagicMock() @@ -552,14 +546,12 @@ class WhenGettingPuttingOrDeletingSecretUsingSecretResource(unittest.TestCase): self.conf.crypto.enabled_crypto_plugins = ['test_crypto'] self.crypto_mgr = CryptoExtensionManager(conf=self.conf) - self.policy = MagicMock() self.resource = res.SecretResource(self.crypto_mgr, self.tenant_repo, self.secret_repo, self.tenant_secret_repo, self.datum_repo, - self.kek_repo, - self.policy) + self.kek_repo) def test_should_get_secret_as_json(self): self.resource.on_get(self.req, self.resp, self.keystone_id, @@ -845,8 +837,6 @@ class WhenCreatingOrdersUsingOrdersResource(unittest.TestCase): self.queue_resource = MagicMock() self.queue_resource.process_order.return_value = None - self.policy = MagicMock() - self.stream = MagicMock() order_req = {'secret': {'name': self.secret_name, @@ -862,9 +852,8 @@ class WhenCreatingOrdersUsingOrdersResource(unittest.TestCase): self.req.stream = self.stream self.resp = MagicMock() - self.policy = MagicMock() self.resource = res.OrdersResource(self.tenant_repo, self.order_repo, - self.queue_resource, self.policy) + self.queue_resource) def test_should_add_new_order(self): self.resource.on_post(self.req, self.resp, self.tenant_keystone_id) @@ -931,14 +920,12 @@ class WhenGettingOrdersListUsingOrdersResource(unittest.TestCase): self.queue_resource = MagicMock() self.queue_resource.process_order.return_value = None - self.policy = MagicMock() - self.req = MagicMock() self.req.accept = 'application/json' self.req._params = self.params self.resp = MagicMock() self.resource = res.OrdersResource(self.tenant_repo, self.order_repo, - self.queue_resource, self.policy) + self.queue_resource) def test_should_get_list_orders(self): self.resource.on_get(self.req, self.resp, self.keystone_id) @@ -1007,14 +994,10 @@ class WhenGettingOrDeletingOrderUsingOrderResource(unittest.TestCase): self.order_repo.get.return_value = self.order self.order_repo.delete_entity_by_id.return_value = None - self.policy = MagicMock() - self.req = MagicMock() self.resp = MagicMock() - self.policy = MagicMock() - - self.resource = res.OrderResource(self.order_repo, self.policy) + self.resource = res.OrderResource(self.order_repo) def test_should_get_order(self): self.resource.on_get(self.req, self.resp, self.tenant_keystone_id, diff --git a/barbican/tests/api/test_resources_policy.py b/barbican/tests/api/test_resources_policy.py new file mode 100644 index 000000000..258d6940a --- /dev/null +++ b/barbican/tests/api/test_resources_policy.py @@ -0,0 +1,376 @@ +# Copyright (c) 2013 Rackspace, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This test module focuses on RBAC interactions with the API resource classes. +For typical-flow business logic tests of these classes, see the +'resources_test.py' module. +""" + +import os +import unittest + +import falcon +import mock +from oslo.config import cfg + +from barbican.api import resources as res +from barbican import context +from barbican.openstack.common import policy + + +CONF = cfg.CONF + +# Point to the policy.json file located in source control. +TEST_VAR_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), + '../../../etc', 'barbican')) + +ENFORCER = policy.Enforcer() + + +class BaseTestCase(unittest.TestCase): + + def setUp(self): + CONF(args=['--config-dir', TEST_VAR_DIR]) + self.policy_enforcer = ENFORCER + self.policy_enforcer.load_rules(True) + self.resp = mock.MagicMock() + + def _generate_req(self, roles=None, accept=None): + """Generate a fake HTTP request with security context added to it.""" + req = mock.MagicMock() + + kwargs = { + 'user': None, + 'tenant': None, + 'roles': roles or [], + 'policy_enforcer': self.policy_enforcer, + } + req.env = {} + req.env['barbican.context'] = context.RequestContext(**kwargs) + req.accept = accept + + return req + + def _generate_stream_for_exit(self): + """Mock HTTP stream generator, to force RBAC-pass exit. + + Generate a fake HTTP request stream that forces an IOError to + occur, which short circuits API resource processing when RBAC + checks under test here pass. + """ + stream = mock.MagicMock() + read = mock.MagicMock(return_value=None, side_effect=IOError()) + stream.read = read + return stream + + def _assert_post_rbac_exception(self, exception, role): + """Assert that we received the expected RBAC-passed exception.""" + fail_msg = "Expected RBAC pass on role '{0}'".format(role) + self.assertEqual(falcon.HTTP_500, exception.status, msg=fail_msg) + self.assertEquals('Read Error', exception.title, msg=fail_msg) + + def _generate_get_error(self): + """Falcon exception generator to throw from early-exit mocks. + + Creates an exception that should be raised by GET tests that pass + RBAC. This allows such flows to short-circuit normal post-RBAC + processing that is not tested in this module. + + :return: Python exception that should be raised by repo get methods. + """ + # The 'Read Error' clause needs to match that asserted in + # _assert_post_rbac_exception() above. + return falcon.HTTPError(falcon.HTTP_500, 'Read Error') + + def _assert_pass_rbac(self, roles, method_under_test, accept=None): + """Assert that RBAC authorization rules passed for the specified roles. + + :param roles: List of roles to check, one at a time + :param method_under_test: The test method to invoke for each role. + :param accept Optional Accept header to set on the HTTP request + :return: None + """ + for role in roles: + self.req = self._generate_req(roles=[role] if role else [], + accept=accept) + + # Force an exception early past the RBAC passing. + self.req.stream = self._generate_stream_for_exit() + with self.assertRaises(falcon.HTTPError) as cm: + method_under_test() + self._assert_post_rbac_exception(cm.exception, role) + + self.setUp() # Need to re-setup + + def _assert_fail_rbac(self, roles, method_under_test, accept=None): + """Assert that RBAC rules failed for one of the specified roles. + + :param roles: List of roles to check, one at a time + :param method_under_test: The test method to invoke for each role. + :param accept Optional Accept header to set on the HTTP request + :return: None + """ + for role in roles: + self.req = self._generate_req(roles=[role] if role else [], + accept=accept) + + with self.assertRaises(falcon.HTTPError) as cm: + method_under_test() + + exception = cm.exception + self.assertEqual(falcon.HTTP_401, exception.status, + msg="Expected RBAC fail for role '{0}'".format( + role)) + + self.setUp() # Need to re-setup + + +class WhenTestingVersionResource(BaseTestCase): + """RBAC tests for the barbican.api.resources.VersionResource class.""" + def setUp(self): + super(WhenTestingVersionResource, self).setUp() + + self.resource = res.VersionResource() + + def test_rules_should_be_loaded(self): + self.assertIsNotNone(self.policy_enforcer.rules) + + def test_should_pass_get_version(self): + # Can't use base method that short circuits post-RBAC processing here, + # as version GET is trivial + for role in ['admin', 'observer', 'creator', 'audit']: + self.req = self._generate_req(roles=[role] if role else []) + self._invoke_on_get() + self.setUp() # Need to re-setup + + def test_should_fail_get_version(self): + self._assert_fail_rbac([None, 'bunkrolehere'], self._invoke_on_get) + + def test_should_pass_get_version_multiple_roles(self): + self.req = self._generate_req(roles=['admin', 'observer', 'creator', + 'audit']) + self._invoke_on_get() + + def _invoke_on_get(self): + self.resource.on_get(self.req, self.resp) + + +class WhenTestingSecretsResource(BaseTestCase): + """RBAC tests for the barbican.api.resources.SecretsResource class.""" + def setUp(self): + super(WhenTestingSecretsResource, self).setUp() + + self.keystone_id = '12345' + + # Force an error on GET calls that pass RBAC, as we are not testing + # such flows in this test module. + self.secret_repo = mock.MagicMock() + get_by_create_date = mock.MagicMock(return_value=None, + side_effect=self + ._generate_get_error()) + self.secret_repo.get_by_create_date = get_by_create_date + + self.resource = res.SecretsResource(crypto_manager=mock.MagicMock(), + tenant_repo=mock.MagicMock(), + secret_repo=self.secret_repo, + tenant_secret_repo=mock + .MagicMock(), + datum_repo=mock.MagicMock(), + kek_repo=mock.MagicMock()) + + def test_rules_should_be_loaded(self): + self.assertIsNotNone(self.policy_enforcer.rules) + + def test_should_pass_create_secret(self): + self._assert_pass_rbac(['admin', 'creator'], self._invoke_on_post) + + def test_should_fail_create_secret(self): + self._assert_fail_rbac([None, 'audit', 'observer', 'bogus'], + self._invoke_on_post) + + def test_should_pass_get_secrets(self): + self._assert_pass_rbac(['admin', 'observer', 'creator'], + self._invoke_on_get) + + def test_should_fail_get_secrets(self): + self._assert_fail_rbac([None, 'audit', 'bogus'], + self._invoke_on_get) + + def _invoke_on_post(self): + self.resource.on_post(self.req, self.resp, self.keystone_id) + + def _invoke_on_get(self): + self.resource.on_get(self.req, self.resp, self.keystone_id) + + +class WhenTestingSecretResource(BaseTestCase): + """RBAC tests for the barbican.api.resources.SecretResource class.""" + def setUp(self): + super(WhenTestingSecretResource, self).setUp() + + self.keystone_id = '12345tenant' + self.secret_id = '12345secret' + + # Force an error on GET and DELETE calls that pass RBAC, + # as we are not testing such flows in this test module. + self.secret_repo = mock.MagicMock() + fail_method = mock.MagicMock(return_value=None, + side_effect=self._generate_get_error()) + self.secret_repo.get = fail_method + self.secret_repo.delete_entity_by_id = fail_method + + self.resource = res.SecretResource(crypto_manager=mock.MagicMock(), + tenant_repo=mock.MagicMock(), + secret_repo=self.secret_repo, + tenant_secret_repo=mock.MagicMock(), + datum_repo=mock.MagicMock(), + kek_repo=mock.MagicMock()) + + def test_rules_should_be_loaded(self): + self.assertIsNotNone(self.policy_enforcer.rules) + + def test_should_pass_decrypt_secret(self): + self._assert_pass_rbac(['admin', 'observer', 'creator'], + self._invoke_on_get, + accept='notjsonaccepttype') + + def test_should_fail_decrypt_secret(self): + self._assert_fail_rbac([None, 'audit', 'bogus'], + self._invoke_on_get, + accept='notjsonaccepttype') + + def test_should_pass_get_secret(self): + self._assert_pass_rbac(['admin', 'observer', 'creator', 'audit'], + self._invoke_on_get) + + def test_should_fail_get_secret(self): + self._assert_fail_rbac([None, 'bogus'], + self._invoke_on_get) + + def test_should_pass_put_secret(self): + self._assert_pass_rbac(['admin', 'creator'], self._invoke_on_put) + + def test_should_fail_put_secret(self): + self._assert_fail_rbac([None, 'audit', 'observer', 'bogus'], + self._invoke_on_put) + + def test_should_pass_delete_secret(self): + self._assert_pass_rbac(['admin'], self._invoke_on_delete) + + def test_should_fail_delete_secret(self): + self._assert_fail_rbac([None, 'audit', 'observer', 'creator', 'bogus'], + self._invoke_on_delete) + + def _invoke_on_get(self): + self.resource.on_get(self.req, self.resp, + self.keystone_id, self.secret_id) + + def _invoke_on_put(self): + self.resource.on_put(self.req, self.resp, + self.keystone_id, self.secret_id) + + def _invoke_on_delete(self): + self.resource.on_delete(self.req, self.resp, + self.keystone_id, self.secret_id) + + +class WhenTestingOrdersResource(BaseTestCase): + """RBAC tests for the barbican.api.resources.OrdersResource class.""" + def setUp(self): + super(WhenTestingOrdersResource, self).setUp() + + self.keystone_id = '12345' + + # Force an error on GET calls that pass RBAC, as we are not testing + # such flows in this test module. + self.order_repo = mock.MagicMock() + get_by_create_date = mock.MagicMock(return_value=None, + side_effect=self + ._generate_get_error()) + self.order_repo.get_by_create_date = get_by_create_date + + self.resource = res.OrdersResource(tenant_repo=mock.MagicMock(), + order_repo=self.order_repo, + queue_resource=mock.MagicMock()) + + def test_rules_should_be_loaded(self): + self.assertIsNotNone(self.policy_enforcer.rules) + + def test_should_pass_create_order(self): + self._assert_pass_rbac(['admin', 'creator'], self._invoke_on_post) + + def test_should_fail_create_order(self): + self._assert_fail_rbac([None, 'audit', 'observer', 'bogus'], + self._invoke_on_post) + + def test_should_pass_get_orders(self): + self._assert_pass_rbac(['admin', 'observer', 'creator'], + self._invoke_on_get) + + def test_should_fail_get_orders(self): + self._assert_fail_rbac([None, 'audit', 'bogus'], + self._invoke_on_get) + + def _invoke_on_post(self): + self.resource.on_post(self.req, self.resp, self.keystone_id) + + def _invoke_on_get(self): + self.resource.on_get(self.req, self.resp, self.keystone_id) + + +class WhenTestingOrderResource(BaseTestCase): + """RBAC tests for the barbican.api.resources.OrderResource class.""" + def setUp(self): + super(WhenTestingOrderResource, self).setUp() + + self.keystone_id = '12345tenant' + self.order_id = '12345order' + + # Force an error on GET and DELETE calls that pass RBAC, + # as we are not testing such flows in this test module. + self.order_repo = mock.MagicMock() + fail_method = mock.MagicMock(return_value=None, + side_effect=self._generate_get_error()) + self.order_repo.get = fail_method + self.order_repo.delete_entity_by_id = fail_method + + self.resource = res.OrderResource(order_repo=self.order_repo) + + def test_rules_should_be_loaded(self): + self.assertIsNotNone(self.policy_enforcer.rules) + + def test_should_pass_get_order(self): + self._assert_pass_rbac(['admin', 'observer', 'creator', 'audit'], + self._invoke_on_get) + + def test_should_fail_get_order(self): + self._assert_fail_rbac([None, 'bogus'], + self._invoke_on_get) + + def test_should_pass_delete_order(self): + self._assert_pass_rbac(['admin'], self._invoke_on_delete) + + def test_should_fail_delete_order(self): + self._assert_fail_rbac([None, 'audit', 'observer', 'creator', 'bogus'], + self._invoke_on_delete) + + def _invoke_on_get(self): + self.resource.on_get(self.req, self.resp, + self.keystone_id, self.order_id) + + def _invoke_on_delete(self): + self.resource.on_delete(self.req, self.resp, + self.keystone_id, self.order_id) diff --git a/etc/barbican/policy.json b/etc/barbican/policy.json index e53fd9d36..76122bb07 100644 --- a/etc/barbican/policy.json +++ b/etc/barbican/policy.json @@ -1,4 +1,22 @@ { - "default": "", - "manage_key_recycle": "role:admin" + "version:get": "rule:all_users", + "secret:decrypt": "rule:all_but_audit", + "secret:get": "rule:all_users", + "secret:put": "rule:admin_or_creator", + "secret:delete": "rule:admin", + "secrets:post": "rule:admin_or_creator", + "secrets:get": "rule:all_but_audit", + "orders:post": "rule:admin_or_creator", + "orders:get": "rule:all_but_audit", + "order:get": "rule:all_users", + "order:delete": "rule:admin", + "admin": ["role:admin"], + "observer": ["role:observer"], + "creator": ["role:creator"], + "audit": ["rule:audit"], + "admin_or_user_does_not_work": ["project_id:%(project_id)s"], + "admin_or_user": ["role:admin", "project_id:%(project_id)s"], + "admin_or_creator": ["role:admin", "role:creator"], + "all_but_audit": ["role:admin", "role:observer", "role:creator"], + "all_users": ["role:admin", "role:observer", "role:creator", "role:audit"] } \ No newline at end of file diff --git a/openstack-common.conf b/openstack-common.conf index ef03426c3..579c2ce8c 100644 --- a/openstack-common.conf +++ b/openstack-common.conf @@ -1,7 +1,7 @@ [DEFAULT] # The list of modules to copy from openstack-common -modules=gettextutils,jsonutils,log,local,notifier,timeutils,uuidutils,importutils +modules=gettextutils,jsonutils,log,local,notifier,timeutils,uuidutils,importutils,policy # The base module to hold the copy of openstack.common base=barbican diff --git a/tools/pip-requires b/tools/pip-requires index 0c3c30ae1..ba0879fbf 100644 --- a/tools/pip-requires +++ b/tools/pip-requires @@ -9,7 +9,7 @@ kombu>=2.5.9 webob>=1.2.3 PasteDeploy>=1.5.0 Celery>=3.0.19 -python-keystoneclient>=0.2.0 +python-keystoneclient>=0.3.1 stevedore>=0.8 pycrypto>=2.6 python-dateutil>=2.1 @@ -17,4 +17,6 @@ jsonschema>=2.0.0 SQLAlchemy>=0.8.1 alembic>=0.5.0 psycopg2>=2.5.1 +netaddr +Babel>=0.9.6 # TODO: Get this working again...PyKCS11>=1.2.4