diff --git a/terracotta/api/app.py b/terracotta/api/app.py index 7c5e9d2..619ab53 100644 --- a/terracotta/api/app.py +++ b/terracotta/api/app.py @@ -16,6 +16,7 @@ from oslo_config import cfg import pecan from terracotta.api import access_control +from terracotta import context as ctx def get_pecan_config(): diff --git a/terracotta/context.py b/terracotta/context.py new file mode 100644 index 0000000..0091ec7 --- /dev/null +++ b/terracotta/context.py @@ -0,0 +1,218 @@ +# Copyright 2016 Huawei Technologies Co. Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import eventlet +from keystoneclient.v3 import client as keystone_client +from oslo_config import cfg +import oslo_messaging as messaging +from oslo_serialization import jsonutils +import pecan +from pecan import hooks + +from terracotta import exceptions as exc +from terracotta import utils + + +CONF = cfg.CONF + +_CTX_THREAD_LOCAL_NAME = "TERRACOTTA_APP_CTX_THREAD_LOCAL" +ALLOWED_WITHOUT_AUTH = ['/'] + + +class BaseContext(object): + """Container for context variables.""" + + _elements = set() + + def __init__(self, __mapping=None, **kwargs): + if __mapping is None: + self.__values = dict(**kwargs) + else: + if isinstance(__mapping, BaseContext): + __mapping = __mapping.__values + self.__values = dict(__mapping) + self.__values.update(**kwargs) + + bad_keys = set(self.__values) - self._elements + + if bad_keys: + raise TypeError("Only %s keys are supported. %s given" % + (tuple(self._elements), tuple(bad_keys))) + + def __getattr__(self, name): + try: + return self.__values[name] + except KeyError: + if name in self._elements: + return None + else: + raise AttributeError(name) + + def to_dict(self): + return self.__values + + +class TerracottaContext(BaseContext): + # Use set([...]) since set literals are not supported in Python 2.6. + _elements = set([ + "user_id", + "project_id", + "auth_token", + "service_catalog", + "user_name", + "project_name", + "roles", + "is_admin", + "is_trust_scoped", + ]) + + def __repr__(self): + return "TerracottaContext %s" % self.to_dict() + + +def has_ctx(): + return utils.has_thread_local(_CTX_THREAD_LOCAL_NAME) + + +def ctx(): + if not has_ctx(): + raise exc.ApplicationContextNotFoundException() + + return utils.get_thread_local(_CTX_THREAD_LOCAL_NAME) + + +def set_ctx(new_ctx): + utils.set_thread_local(_CTX_THREAD_LOCAL_NAME, new_ctx) + + +def _wrapper(context, thread_desc, thread_group, func, *args, **kwargs): + try: + set_ctx(context) + func(*args, **kwargs) + except Exception as e: + if thread_group and not thread_group.exc: + thread_group.exc = e + thread_group.failed_thread = thread_desc + finally: + if thread_group: + thread_group._on_thread_exit() + + set_ctx(None) + + +def spawn(thread_description, func, *args, **kwargs): + eventlet.spawn(_wrapper, ctx().clone(), thread_description, + None, func, *args, **kwargs) + + +def context_from_headers(headers): + return TerracottaContext( + user_id=headers.get('X-User-Id'), + project_id=headers.get('X-Project-Id'), + auth_token=headers.get('X-Auth-Token'), + service_catalog=headers.get('X-Service-Catalog'), + user_name=headers.get('X-User-Name'), + project_name=headers.get('X-Project-Name'), + roles=headers.get('X-Roles', "").split(","), + is_trust_scoped=False, + ) + + +def context_from_config(): + keystone = keystone_client.Client( + username=CONF.keystone_authtoken.admin_user, + password=CONF.keystone_authtoken.admin_password, + tenant_name=CONF.keystone_authtoken.admin_tenant_name, + auth_url=CONF.keystone_authtoken.auth_uri, + is_trust_scoped=False, + ) + + keystone.authenticate() + + return TerracottaContext( + user_id=keystone.user_id, + project_id=keystone.project_id, + auth_token=keystone.auth_token, + project_name=CONF.keystone_authtoken.admin_tenant_name, + user_name=CONF.keystone_authtoken.admin_user, + is_trust_scoped=False, + ) + + +class JsonPayloadSerializer(messaging.NoOpSerializer): + @staticmethod + def serialize_entity(context, entity): + return jsonutils.to_primitive(entity, convert_instances=True) + + +class RpcContextSerializer(messaging.Serializer): + def __init__(self, base=None): + self._base = base or messaging.NoOpSerializer() + + def serialize_entity(self, context, entity): + if not self._base: + return entity + + return self._base.serialize_entity(context, entity) + + def deserialize_entity(self, context, entity): + if not self._base: + return entity + + return self._base.deserialize_entity(context, entity) + + def serialize_context(self, context): + return context.to_dict() + + def deserialize_context(self, context): + ctx = TerracottaContext(**context) + set_ctx(ctx) + + return ctx + + +class AuthHook(hooks.PecanHook): + def before(self, state): + if state.request.path in ALLOWED_WITHOUT_AUTH: + return + + if CONF.pecan.auth_enable: + identity_status = state.request.headers.get('X-Identity-Status') + service_identity_status = state.request.headers.get( + 'X-Service-Identity-Status' + ) + + if (identity_status == 'Confirmed' + or service_identity_status == 'Confirmed'): + return + + if state.request.headers.get('X-Auth-Token'): + msg = ("Auth token is invalid: %s" + % state.request.headers['X-Auth-Token']) + else: + msg = 'Authentication required' + + pecan.abort( + status_code=401, + detail=msg, + headers={'Server-Error-Message': msg} + ) + + +class ContextHook(hooks.PecanHook): + def before(self, state): + set_ctx(context_from_headers(state.request.headers)) + + def after(self, state): + set_ctx(None) diff --git a/terracotta/utils/__init__.py b/terracotta/utils/__init__.py index e69de29..5e909ba 100644 --- a/terracotta/utils/__init__.py +++ b/terracotta/utils/__init__.py @@ -0,0 +1,329 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 - Huawei Technologies Co. Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import json +import logging +import os +from os import path +import shutil +import six +import socket +import tempfile +import threading +import uuid + +import eventlet +from eventlet import corolocal +from oslo_concurrency import processutils +import pkg_resources as pkg +import random + +from terracotta import exceptions as exc +from terracotta import version + + +# Thread local storage. +_th_loc_storage = threading.local() + + +def generate_unicode_uuid(): + return six.text_type(str(uuid.uuid4())) + + +def _get_greenlet_local_storage(): + greenlet_id = corolocal.get_ident() + + greenlet_locals = getattr(_th_loc_storage, "greenlet_locals", None) + + if not greenlet_locals: + greenlet_locals = {} + _th_loc_storage.greenlet_locals = greenlet_locals + + if greenlet_id in greenlet_locals: + return greenlet_locals[greenlet_id] + else: + return None + + +def has_thread_local(var_name): + gl_storage = _get_greenlet_local_storage() + return gl_storage and var_name in gl_storage + + +def get_thread_local(var_name): + if not has_thread_local(var_name): + return None + + return _get_greenlet_local_storage()[var_name] + + +def set_thread_local(var_name, val): + if not val and has_thread_local(var_name): + gl_storage = _get_greenlet_local_storage() + + # Delete variable from greenlet local storage. + if gl_storage: + del gl_storage[var_name] + + # Delete the entire greenlet local storage from thread local storage. + if gl_storage and len(gl_storage) == 0: + del _th_loc_storage.greenlet_locals[corolocal.get_ident()] + + if val: + gl_storage = _get_greenlet_local_storage() + if not gl_storage: + gl_storage = _th_loc_storage.greenlet_locals[ + corolocal.get_ident()] = {} + + gl_storage[var_name] = val + + +def log_exec(logger, level=logging.DEBUG): + """Decorator for logging function execution. + + By default, target function execution is logged with DEBUG level. + """ + + def _decorator(func): + def _logged(*args, **kw): + params_repr = ("[args=%s, kw=%s]" % (str(args), str(kw)) + if args or kw else "") + + func_repr = ("Called method [name=%s, doc='%s', params=%s]" % + (func.__name__, func.__doc__, params_repr)) + + logger.log(level, func_repr) + + return func(*args, **kw) + + _logged.__doc__ = func.__doc__ + + return _logged + + return _decorator + + +def merge_dicts(left, right, overwrite=True): + """Merges two dictionaries. + + Values of right dictionary recursively get merged into left dictionary. + :param left: Left dictionary. + :param right: Right dictionary. + :param overwrite: If False, left value will not be overwritten if exists. + """ + + if left is None: + return right + + if right is None: + return left + + for k, v in six.iteritems(right): + if k not in left: + left[k] = v + else: + left_v = left[k] + + if isinstance(left_v, dict) and isinstance(v, dict): + merge_dicts(left_v, v, overwrite=overwrite) + elif overwrite: + left[k] = v + + return left + + +def get_file_list(directory): + base_path = pkg.resource_filename( + version.version_info.package, + directory + ) + + return [path.join(base_path, f) for f in os.listdir(base_path) + if path.isfile(path.join(base_path, f))] + + +def cut(data, length=100): + if not data: + return data + + string = str(data) + + if len(string) > length: + return "%s..." % string[:length] + else: + return string + + +def iter_subclasses(cls, _seen=None): + """Generator over all subclasses of a given class in depth first order.""" + + if not isinstance(cls, type): + raise TypeError('iter_subclasses must be called with new-style class' + ', not %.100r' % cls) + _seen = _seen or set() + + try: + subs = cls.__subclasses__() + except TypeError: # fails only when cls is type + subs = cls.__subclasses__(cls) + + for sub in subs: + if sub not in _seen: + _seen.add(sub) + yield sub + for _sub in iter_subclasses(sub, _seen): + yield _sub + + +def random_sleep(limit=1): + """Sleeps for a random period of time not exceeding the given limit. + + Mostly intended to be used by tests to emulate race conditions. + + :param limit: Float number of seconds that a sleep period must not exceed. + """ + + seconds = random.Random().randint(0, limit * 1000) * 0.001 + + print("Sleep: %s sec..." % seconds) + + eventlet.sleep(seconds) + + +class NotDefined(object): + """This class is just a marker of input params without value.""" + + pass + + +def get_dict_from_string(input_string, delimiter=','): + if not input_string: + return {} + + raw_inputs = input_string.split(delimiter) + + inputs = [] + + for raw in raw_inputs: + input = raw.strip() + name_value = input.split('=') + + if len(name_value) > 1: + + try: + value = json.loads(name_value[1]) + except ValueError: + value = name_value[1] + + inputs += [{name_value[0]: value}] + else: + inputs += [name_value[0]] + + return get_input_dict(inputs) + + +def get_input_dict(inputs): + """Transform input list to dictionary. + + Ensure every input param has a default value(it will be a NotDefined + object if it's not provided). + """ + input_dict = {} + for x in inputs: + if isinstance(x, dict): + input_dict.update(x) + else: + # NOTE(xylan): we put a NotDefined class here as the value of + # param without value specified, to distinguish from the valid + # values such as None, ''(empty string), etc. + input_dict[x] = NotDefined + + return input_dict + + +def get_process_identifier(): + """Gets current running process identifier.""" + + return "%s_%s" % (socket.gethostname(), os.getpid()) + + +@contextlib.contextmanager +def tempdir(**kwargs): + argdict = kwargs.copy() + + if 'dir' not in argdict: + argdict['dir'] = '/tmp/' + + tmpdir = tempfile.mkdtemp(**argdict) + + try: + yield tmpdir + finally: + try: + shutil.rmtree(tmpdir) + except OSError as e: + raise exc.DataAccessException( + "Failed to delete temp dir %(dir)s (reason: %(reason)s)" % + {'dir': tmpdir, 'reason': e} + ) + + +def save_text_to(text, file_path, overwrite=False): + if os.path.exists(file_path) and not overwrite: + raise exc.DataAccessException( + "Cannot save data to file. File %s already exists." + ) + + with open(file_path, 'w') as f: + f.write(text) + + +def generate_key_pair(key_length=2048): + """Create RSA key pair with specified number of bits in key. + + Returns tuple of private and public keys. + """ + with tempdir() as tmpdir: + keyfile = os.path.join(tmpdir, 'tempkey') + args = [ + 'ssh-keygen', + '-q', # quiet + '-N', '', # w/o passphrase + '-t', 'rsa', # create key of rsa type + '-f', keyfile, # filename of the key file + '-C', 'Generated-by-Mistral' # key comment + ] + + if key_length is not None: + args.extend(['-b', key_length]) + + processutils.execute(*args) + + if not os.path.exists(keyfile): + raise exc.DataAccessException( + "Private key file hasn't been created" + ) + + private_key = open(keyfile).read() + public_key_path = keyfile + '.pub' + + if not os.path.exists(public_key_path): + raise exc.DataAccessException( + "Public key file hasn't been created" + ) + public_key = open(public_key_path).read() + + return private_key, public_key