Adding in oslo wsgi/pasteapp connection

This commit is contained in:
John Bresnahan 2013-05-16 13:08:01 -10:00
parent fd6e15bca8
commit 1188f5a0e7
31 changed files with 1549 additions and 105 deletions

View File

@ -1,5 +1,5 @@
[pipeline:staccato-api]
pipeline = unauthenticated-context rootapp
pipeline = rootapp
[app:rootapp]
use = egg:Paste#urlmap
@ -7,11 +7,9 @@ use = egg:Paste#urlmap
/v1: apiv1app
[app:apiversions]
paste.app_factory = staccato.api.versions:create_resource
paste.app_factory = staccato.openstack.common.pastedeploy:app_factory
openstack.app_factory = staccato.api.versions:VersionApp
[app:apiv1app]
paste.app_factory = staccato.api.v1.xfer:create_resource
[filter:unauthenticated-context]
paste.filter_factory = staccato.wsgi:Middleware.factory
paste.app_factory = staccato.openstack.common.pastedeploy:app_factory
openstack.app_factory = staccato.api.v1.xfer:XferApp

View File

@ -60,3 +60,6 @@ sql_idle_timeout = 3600
# Should be set to a random string of length 16, 24 or 32 bytes
#metadata_encryption_key = <16, 24 or 32 char registry metadata key>
[paste_deploy]
config_file = /home/jbresnah/Dev/OpenStack/staccato/etc/staccato-api-paste.ini
#

View File

@ -32,9 +32,3 @@ class XferApp(object):
content_type='application/json')
response.body = json.dumps(dict(versions=version_objs))
return response
def create_resource(conf):
# TODO: figure out what this has to be this way
config_obj = conf['CONF']['conf']
return XferApp(conf=config_obj)

View File

@ -2,8 +2,6 @@ import httplib
import json
import webob
from staccato.common import wsgi
class VersionApp(object):
"""
@ -12,10 +10,10 @@ class VersionApp(object):
def __init__(self, conf):
self.conf = conf
@webob.dec.wsgify(RequestClass=wsgi.Request)
@webob.dec.wsgify
def __call__(self, req):
version_info = {
'id': self.conf.id,
'id': self.conf.service_id,
'version': self.conf.version,
'status': 'active'
}
@ -26,9 +24,3 @@ class VersionApp(object):
content_type='application/json')
response.body = json.dumps(dict(versions=version_objs))
return response
def create_resource(conf):
# TODO: figure out what this has to be this way
config_obj = conf['CONF']['conf']
return VersionApp(conf=config_obj)

View File

@ -1,9 +1,11 @@
import eventlet
import gettext
import sys
import time
from staccato.common import utils
from staccato import wsgi
from staccato.common import config
import staccato.openstack.common.wsgi as os_wsgi
import staccato.openstack.common.pastedeploy as os_pastedeploy
# Monkey patch socket and time
eventlet.patcher.monkey_patch(all=False, socket=True, time=True)
@ -16,29 +18,14 @@ def fail(returncode, e):
sys.exit(returncode)
class CONF(object):
def __init__(self):
self.bind_host = "0.0.0.0"
self.bind_port = 9876
self.cert_file = None
self.key_file = None
self.backlog = -1
self.tcp_keepidle = False
self.id = "deadbeef"
self.version = "v1"
def main():
try:
#config.parse_args(sys.argv)
conf = CONF()
wsgi_app = utils.load_paste_app(
'staccato-api',
'/home/jbresnah/Dev/OpenStack/staccato/etc/glance-api-paste.ini',
{'conf': conf})
server = wsgi.Server(CONF=conf)
server.start(wsgi_app, default_port=9292)
conf = config.get_config_object()
paste_file = conf.paste_deploy.config_file
wsgi_app = os_pastedeploy.paste_deploy_app(paste_file, 'staccato-api', conf)
server = os_wsgi.Service(wsgi_app, conf.bind_port)
server.start()
server.wait()
except RuntimeError as e:
fail(1, e)

View File

@ -1,9 +1,20 @@
import logging
from oslo.config import cfg
import json
import logging
from oslo.config import cfg
from staccato.version import version_info as version
paste_deploy_opts = [
cfg.StrOpt('flavor',
help=_('Partial name of a pipeline in your paste configuration '
'file with the service name removed. For example, if '
'your paste section name is '
'[pipeline:glance-api-keystone] use the value '
'"keystone"')),
cfg.StrOpt('config_file',
help=_('Name of the paste configuration file.')),
]
common_opts = [
cfg.ListOpt('protocol_plugins',
default=['staccato.protocols.file.FileProtocol',
@ -36,6 +47,15 @@ common_opts = [
dest='str_log_level'),
cfg.StrOpt('protocol_policy', default='staccato-protocols.json',
help=''),
cfg.StrOpt('service_id', default='staccato1234',
help=''),
]
bind_opts = [
cfg.StrOpt('bind_host', default='0.0.0.0',
help=_('Address to bind the server. Useful when '
'selecting a particular network interface.')),
cfg.IntOpt('bind_port',
help=_('The port on which the server will listen.')),
]
@ -55,6 +75,8 @@ def _log_string_to_val(conf):
def get_config_object(args=None, usage=None, default_config_files=None):
conf = cfg.ConfigOpts()
conf.register_opts(common_opts)
conf.register_opts(bind_opts)
conf.register_opts(paste_deploy_opts, group='paste_deploy')
conf(args=args,
project='staccato',
version=version.cached_version_string(),
@ -71,4 +93,4 @@ def get_protocol_policy(conf):
# TODO log a warning
return {}
policy = json.load(open(protocol_conf_file, 'r'))
return policy
return policy

View File

@ -34,3 +34,7 @@ class StaccatoDataBaseException(StaccatoBaseException):
class StaccatoEventException(StaccatoBaseException):
pass
class StaccatoInvalidStateTransitionException(StaccatoEventException):
pass

View File

@ -34,11 +34,11 @@ class StateMachine(object):
current_state = self._get_current_state(**kwvals)
if current_state not in self._transitions:
raise exceptions.StaccatoParameterError(
raise exceptions.StaccatoInvalidStateTransitionException(
"Undefined event %s at state %s" % (event, current_state))
state_ent = self._transitions[current_state]
if event not in state_ent:
raise exceptions.StaccatoParameterError(
raise exceptions.StaccatoInvalidStateTransitionException(
"Undefined event %s at state %s" % (event, current_state))
next_state, function = state_ent[event]

View File

@ -1,6 +1,8 @@
import logging
from paste import deploy
import re
from paste import deploy
from staccato.common import exceptions
from staccato.openstack.common import importutils

View File

@ -1,23 +0,0 @@
import webob
class Request(webob.Request):
"""Add some Openstack API-specific logic to the base webob.Request."""
def best_match_content_type(self):
"""Determine the requested response content-type."""
supported = ('application/json',)
bm = self.accept.best_match(supported)
return bm or 'application/json'
def get_content_type(self, allowed_content_types):
"""Determine content type of the request body."""
if "Content-Type" not in self.headers:
raise exception.InvalidContentType(content_type=None)
content_type = self.content_type
if content_type not in allowed_content_types:
raise exception.InvalidContentType(content_type=content_type)
else:
return content_type

View File

@ -0,0 +1,142 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2011 OpenStack Foundation.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Exceptions common to OpenStack projects
"""
import logging
from staccato.openstack.common.gettextutils import _
_FATAL_EXCEPTION_FORMAT_ERRORS = False
class Error(Exception):
def __init__(self, message=None):
super(Error, self).__init__(message)
class ApiError(Error):
def __init__(self, message='Unknown', code='Unknown'):
self.message = message
self.code = code
super(ApiError, self).__init__('%s: %s' % (code, message))
class NotFound(Error):
pass
class UnknownScheme(Error):
msg = "Unknown scheme '%s' found in URI"
def __init__(self, scheme):
msg = self.__class__.msg % scheme
super(UnknownScheme, self).__init__(msg)
class BadStoreUri(Error):
msg = "The Store URI %s was malformed. Reason: %s"
def __init__(self, uri, reason):
msg = self.__class__.msg % (uri, reason)
super(BadStoreUri, self).__init__(msg)
class Duplicate(Error):
pass
class NotAuthorized(Error):
pass
class NotEmpty(Error):
pass
class Invalid(Error):
pass
class BadInputError(Exception):
"""Error resulting from a client sending bad input to a server"""
pass
class MissingArgumentError(Error):
pass
class DatabaseMigrationError(Error):
pass
class ClientConnectionError(Exception):
"""Error resulting from a client connecting to a server"""
pass
def wrap_exception(f):
def _wrap(*args, **kw):
try:
return f(*args, **kw)
except Exception as e:
if not isinstance(e, Error):
#exc_type, exc_value, exc_traceback = sys.exc_info()
logging.exception(_('Uncaught exception'))
#logging.error(traceback.extract_stack(exc_traceback))
raise Error(str(e))
raise
_wrap.func_name = f.func_name
return _wrap
class OpenstackException(Exception):
"""
Base Exception
To correctly use this class, inherit from it and define
a 'message' property. That message will get printf'd
with the keyword arguments provided to the constructor.
"""
message = "An unknown exception occurred"
def __init__(self, **kwargs):
try:
self._error_string = self.message % kwargs
except Exception as e:
if _FATAL_EXCEPTION_FORMAT_ERRORS:
raise e
else:
# at least get the core message out if something happened
self._error_string = self.message
def __str__(self):
return self._error_string
class MalformedRequestBody(OpenstackException):
message = "Malformed message body: %(reason)s"
class InvalidContentType(OpenstackException):
message = "Invalid content type %(content_type)s"

View File

@ -43,9 +43,9 @@ import traceback
from oslo.config import cfg
from staccato.openstack.common.gettextutils import _
from staccato.openstack.common import importutils
from staccato.openstack.common import jsonutils
from staccato.openstack.common import local
from staccato.openstack.common import notifier
_DEFAULT_LOG_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
@ -322,17 +322,6 @@ class JSONFormatter(logging.Formatter):
return jsonutils.dumps(message)
class PublishErrorsHandler(logging.Handler):
def emit(self, record):
if ('staccato.openstack.common.notifier.log_notifier' in
CONF.notification_driver):
return
notifier.api.notify(None, 'error.publisher',
'error_notification',
notifier.api.ERROR,
dict(error=record.msg))
def _create_logging_excepthook(product_name):
def logging_excepthook(type, value, tb):
extra = {}
@ -428,7 +417,10 @@ def _setup_logging_from_conf():
log_root.addHandler(streamlog)
if CONF.publish_errors:
log_root.addHandler(PublishErrorsHandler(logging.ERROR))
handler = importutils.import_object(
"staccato.openstack.common.log_handler.PublishErrorsHandler",
logging.ERROR)
log_root.addHandler(handler)
datefmt = CONF.log_date_format
for handler in log_root.handlers:

View File

@ -0,0 +1,64 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2011 OpenStack Foundation.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Middleware that attaches a context to the WSGI request
"""
from staccato.openstack.common import context
from staccato.openstack.common import importutils
from staccato.openstack.common import wsgi
class ContextMiddleware(wsgi.Middleware):
def __init__(self, app, options):
self.options = options
super(ContextMiddleware, self).__init__(app)
def make_context(self, *args, **kwargs):
"""
Create a context with the given arguments.
"""
# Determine the context class to use
ctxcls = context.RequestContext
if 'context_class' in self.options:
ctxcls = importutils.import_class(self.options['context_class'])
return ctxcls(*args, **kwargs)
def process_request(self, req):
"""
Extract any authentication information in the request and
construct an appropriate context from it.
"""
# Use the default empty context, with admin turned on for
# backwards compatibility
req.context = self.make_context(is_admin=True)
def filter_factory(global_conf, **local_conf):
"""
Factory method for paste.deploy
"""
conf = global_conf.copy()
conf.update(local_conf)
def filter(app):
return ContextMiddleware(app, conf)
return filter

View File

@ -0,0 +1,84 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright (c) 2012 Red Hat, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Request Body limiting middleware.
"""
from oslo.config import cfg
import webob.dec
import webob.exc
from staccato.openstack.common.gettextutils import _
from staccato.openstack.common import wsgi
#default request size is 112k
max_req_body_size = cfg.IntOpt('max_request_body_size',
deprecated_name='osapi_max_request_body_size',
default=114688,
help='the maximum body size '
'per each request(bytes)')
CONF = cfg.CONF
CONF.register_opt(max_req_body_size)
class LimitingReader(object):
"""Reader to limit the size of an incoming request."""
def __init__(self, data, limit):
"""
:param data: Underlying data object
:param limit: maximum number of bytes the reader should allow
"""
self.data = data
self.limit = limit
self.bytes_read = 0
def __iter__(self):
for chunk in self.data:
self.bytes_read += len(chunk)
if self.bytes_read > self.limit:
msg = _("Request is too large.")
raise webob.exc.HTTPRequestEntityTooLarge(explanation=msg)
else:
yield chunk
def read(self, i=None):
result = self.data.read(i)
self.bytes_read += len(result)
if self.bytes_read > self.limit:
msg = _("Request is too large.")
raise webob.exc.HTTPRequestEntityTooLarge(explanation=msg)
return result
class RequestBodySizeLimiter(wsgi.Middleware):
"""Limit the size of incoming requests."""
def __init__(self, *args, **kwargs):
super(RequestBodySizeLimiter, self).__init__(*args, **kwargs)
@webob.dec.wsgify(RequestClass=wsgi.Request)
def __call__(self, req):
if req.content_length > CONF.max_request_body_size:
msg = _("Request is too large.")
raise webob.exc.HTTPRequestEntityTooLarge(explanation=msg)
if req.content_length is None and req.is_body_readable:
limiter = LimitingReader(req.body_file,
CONF.max_request_body_size)
req.body_file = limiter
return self.application

View File

@ -0,0 +1,164 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2012 Red Hat, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import sys
from paste import deploy
from staccato.openstack.common import local
class BasePasteFactory(object):
"""A base class for paste app and filter factories.
Sub-classes must override the KEY class attribute and provide
a __call__ method.
"""
KEY = None
def __init__(self, data):
self.data = data
def _import_factory(self, local_conf):
"""Import an app/filter class.
Lookup the KEY from the PasteDeploy local conf and import the
class named there. This class can then be used as an app or
filter factory.
Note we support the <module>:<class> format.
Note also that if you do e.g.
key =
value
then ConfigParser returns a value with a leading newline, so
we strip() the value before using it.
"""
mod_str, _sep, class_str = local_conf[self.KEY].strip().rpartition(':')
del local_conf[self.KEY]
__import__(mod_str)
return getattr(sys.modules[mod_str], class_str)
class AppFactory(BasePasteFactory):
"""A Generic paste.deploy app factory.
This requires openstack.app_factory to be set to a callable which returns a
WSGI app when invoked. The format of the name is <module>:<callable> e.g.
[app:myfooapp]
paste.app_factory = openstack.common.pastedeploy:app_factory
openstack.app_factory = myapp:Foo
The WSGI app constructor must accept a data object and a local config
dict as its two arguments.
"""
KEY = 'openstack.app_factory'
def __call__(self, global_conf, **local_conf):
"""The actual paste.app_factory protocol method."""
factory = self._import_factory(local_conf)
return factory(self.data, **local_conf)
class FilterFactory(AppFactory):
"""A Generic paste.deploy filter factory.
This requires openstack.filter_factory to be set to a callable which
returns a WSGI filter when invoked. The format is <module>:<callable> e.g.
[filter:myfoofilter]
paste.filter_factory = openstack.common.pastedeploy:filter_factory
openstack.filter_factory = myfilter:Foo
The WSGI filter constructor must accept a WSGI app, a data object and
a local config dict as its three arguments.
"""
KEY = 'openstack.filter_factory'
def __call__(self, global_conf, **local_conf):
"""The actual paste.filter_factory protocol method."""
factory = self._import_factory(local_conf)
def filter(app):
return factory(app, self.data, **local_conf)
return filter
def app_factory(global_conf, **local_conf):
"""A paste app factory used with paste_deploy_app()."""
return local.store.app_factory(global_conf, **local_conf)
def filter_factory(global_conf, **local_conf):
"""A paste filter factory used with paste_deploy_app()."""
return local.store.filter_factory(global_conf, **local_conf)
def paste_deploy_app(paste_config_file, app_name, data):
"""Load a WSGI app from a PasteDeploy configuration.
Use deploy.loadapp() to load the app from the PasteDeploy configuration,
ensuring that the supplied data object is passed to the app and filter
factories defined in this module.
To use these factories and the data object, the configuration should look
like this:
[app:myapp]
paste.app_factory = openstack.common.pastedeploy:app_factory
openstack.app_factory = myapp:App
...
[filter:myfilter]
paste.filter_factory = openstack.common.pastedeploy:filter_factory
openstack.filter_factory = myapp:Filter
and then:
myapp.py:
class App(object):
def __init__(self, data):
...
class Filter(object):
def __init__(self, app, data):
...
:param paste_config_file: a PasteDeploy config file
:param app_name: the name of the app/pipeline to load from the file
:param data: a data object to supply to the app and its filters
:returns: the WSGI app
"""
(af, ff) = (AppFactory(data), FilterFactory(data))
local.store.app_factory = af
local.store.filter_factory = ff
try:
return deploy.loadapp("config:%s" % paste_config_file, name=app_name)
finally:
del local.store.app_factory
del local.store.filter_factory

View File

@ -437,7 +437,7 @@ def _parse_list_rule(rule):
or_list.append(AndCheck(and_list))
# If we have only one check, omit the "or"
if len(or_list) == 0:
if not or_list:
return FalseCheck()
elif len(or_list) == 1:
return or_list[0]

View File

@ -123,7 +123,7 @@ def execute(*cmd, **kwargs):
elif isinstance(check_exit_code, int):
check_exit_code = [check_exit_code]
if len(kwargs):
if kwargs:
raise UnknownArgumentError(_('Got unknown keyword args '
'to utils.execute: %r') % kwargs)

View File

@ -158,6 +158,10 @@ class UnsupportedRpcEnvelopeVersion(RPCException):
"not supported by this endpoint.")
class RpcVersionCapError(RPCException):
message = _("Specified RPC version cap, %(version_cap)s, is too low")
class Connection(object):
"""A connection, returned by rpc.create_connection().

View File

@ -375,7 +375,7 @@ class Connection(object):
try:
return method(*args, **kwargs)
except (qpid_exceptions.Empty,
qpid_exceptions.ConnectionError), e:
qpid_exceptions.ConnectionError) as e:
if error_callback:
error_callback(e)
self.reconnect()

View File

@ -180,7 +180,7 @@ class ZmqSocket(object):
return
# We must unsubscribe, or we'll leak descriptors.
if len(self.subscriptions) > 0:
if self.subscriptions:
for f in self.subscriptions:
try:
self.sock.setsockopt(zmq.UNSUBSCRIBE, f)
@ -763,7 +763,7 @@ def _multi_send(method, context, topic, msg, timeout=None,
LOG.debug(_("Sending message(s) to: %s"), queues)
# Don't stack if we have no matchmaker results
if len(queues) == 0:
if not queues:
LOG.warn(_("No matchmaker results. Not casting."))
# While not strictly a timeout, callers know how to handle
# this exception and a timeout isn't too big a lie.

View File

@ -245,7 +245,7 @@ class HeartbeatMatchMakerBase(MatchMakerBase):
yielding for CONF.matchmaker_heartbeat_freq seconds
between iterations.
"""
if len(self.hosts) == 0:
if not self.hosts:
raise MatchMakerException(
_("Register before starting heartbeat."))

View File

@ -1,6 +1,6 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2012 Red Hat, Inc.
# Copyright 2012-2013 Red Hat, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
@ -23,6 +23,7 @@ For more information about rpc API version numbers, see:
from staccato.openstack.common import rpc
from staccato.openstack.common.rpc import common as rpc_common
class RpcProxy(object):
@ -34,16 +35,19 @@ class RpcProxy(object):
rpc API.
"""
def __init__(self, topic, default_version):
def __init__(self, topic, default_version, version_cap=None):
"""Initialize an RpcProxy.
:param topic: The topic to use for all messages.
:param default_version: The default API version to request in all
outgoing messages. This can be overridden on a per-message
basis.
:param version_cap: Optionally cap the maximum version used for sent
messages.
"""
self.topic = topic
self.default_version = default_version
self.version_cap = version_cap
super(RpcProxy, self).__init__()
def _set_version(self, msg, vers):
@ -52,7 +56,11 @@ class RpcProxy(object):
:param msg: The message having a version added to it.
:param vers: The version number to add to the message.
"""
msg['version'] = vers if vers else self.default_version
v = vers if vers else self.default_version
if (self.version_cap and not
rpc_common.version_is_compatible(self.version_cap, v)):
raise rpc_common.RpcVersionCapError(version=self.version_cap)
msg['version'] = v
def _get_topic(self, topic):
"""Return the topic to use for a message."""

View File

@ -52,7 +52,7 @@ class Launcher(object):
"""
self._services = threadgroup.ThreadGroup()
eventlet_backdoor.initialize_if_enabled()
self.backdoor_port = eventlet_backdoor.initialize_if_enabled()
@staticmethod
def run_service(service):
@ -72,6 +72,7 @@ class Launcher(object):
:returns: None
"""
service.backdoor_port = self.backdoor_port
self._services.add_thread(self.run_service, service)
def stop(self):

View File

@ -127,11 +127,9 @@ def _run_shell_command(cmd, throw_on_error=False):
out = output.communicate()
if output.returncode and throw_on_error:
raise Exception("%s returned %d" % cmd, output.returncode)
if len(out) == 0:
if not out:
return None
if len(out[0].strip()) == 0:
return None
return out[0].strip()
return out[0].strip() or None
def _get_git_directory():

View File

@ -0,0 +1,80 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2013 IBM Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import os
import ssl
from oslo.config import cfg
from staccato.openstack.common.gettextutils import _
ssl_opts = [
cfg.StrOpt('ca_file',
default=None,
help="CA certificate file to use to verify "
"connecting clients"),
cfg.StrOpt('cert_file',
default=None,
help="Certificate file to use when starting "
"the server securely"),
cfg.StrOpt('key_file',
default=None,
help="Private key file to use when starting "
"the server securely"),
]
CONF = cfg.CONF
CONF.register_opts(ssl_opts, "ssl")
def is_enabled():
cert_file = CONF.ssl.cert_file
key_file = CONF.ssl.key_file
ca_file = CONF.ssl.ca_file
use_ssl = cert_file or key_file
if cert_file and not os.path.exists(cert_file):
raise RuntimeError(_("Unable to find cert_file : %s") % cert_file)
if ca_file and not os.path.exists(ca_file):
raise RuntimeError(_("Unable to find ca_file : %s") % ca_file)
if key_file and not os.path.exists(key_file):
raise RuntimeError(_("Unable to find key_file : %s") % key_file)
if use_ssl and (not cert_file or not key_file):
raise RuntimeError(_("When running server in SSL mode, you must "
"specify both a cert_file and key_file "
"option value in your configuration file"))
return use_ssl
def wrap(sock):
ssl_kwargs = {
'server_side': True,
'certfile': CONF.ssl.cert_file,
'keyfile': CONF.ssl.key_file,
'cert_reqs': ssl.CERT_NONE,
}
if CONF.ssl.ca_file:
ssl_kwargs['ca_certs'] = CONF.ssl.ca_file
ssl_kwargs['cert_reqs'] = ssl.CERT_REQUIRED
return ssl.wrap_socket(sock, **ssl_kwargs)

View File

@ -61,6 +61,13 @@ class ThreadGroup(object):
self.threads = []
self.timers = []
def add_dynamic_timer(self, callback, initial_delay=None,
periodic_interval_max=None, *args, **kwargs):
timer = loopingcall.DynamicLoopingCall(callback, *args, **kwargs)
timer.start(initial_delay=initial_delay,
periodic_interval_max=periodic_interval_max)
self.timers.append(timer)
def add_timer(self, interval, callback, initial_delay=None,
*args, **kwargs):
pulse = loopingcall.FixedIntervalLoopingCall(callback, *args, **kwargs)

View File

@ -0,0 +1,800 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2011 OpenStack Foundation.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Utility methods for working with WSGI servers."""
from __future__ import print_function
import eventlet
eventlet.patcher.monkey_patch(all=False, socket=True)
import datetime
import errno
import socket
import sys
import time
import eventlet.wsgi
from oslo.config import cfg
import routes
import routes.middleware
import six
import webob.dec
import webob.exc
from xml.dom import minidom
from xml.parsers import expat
from staccato.openstack.common import exception
from staccato.openstack.common.gettextutils import _
from staccato.openstack.common import jsonutils
from staccato.openstack.common import log as logging
from staccato.openstack.common import service
from staccato.openstack.common import sslutils
from staccato.openstack.common import xmlutils
socket_opts = [
cfg.IntOpt('backlog',
default=4096,
help="Number of backlog requests to configure the socket with"),
cfg.IntOpt('tcp_keepidle',
default=600,
help="Sets the value of TCP_KEEPIDLE in seconds for each "
"server socket. Not supported on OS X."),
]
CONF = cfg.CONF
CONF.register_opts(socket_opts)
LOG = logging.getLogger(__name__)
def run_server(application, port, **kwargs):
"""Run a WSGI server with the given application."""
sock = eventlet.listen(('0.0.0.0', port))
eventlet.wsgi.server(sock, application, **kwargs)
class Service(service.Service):
"""
Provides a Service API for wsgi servers.
This gives us the ability to launch wsgi servers with the
Launcher classes in service.py.
"""
def __init__(self, application, port,
host='0.0.0.0', backlog=4096, threads=1000):
self.application = application
self._port = port
self._host = host
self._backlog = backlog if backlog else CONF.backlog
self._socket = self._get_socket(host, port, self._backlog)
super(Service, self).__init__(threads)
def _get_socket(self, host, port, backlog):
# TODO(dims): eventlet's green dns/socket module does not actually
# support IPv6 in getaddrinfo(). We need to get around this in the
# future or monitor upstream for a fix
info = socket.getaddrinfo(host,
port,
socket.AF_UNSPEC,
socket.SOCK_STREAM)[0]
family = info[0]
bind_addr = info[-1]
sock = None
retry_until = time.time() + 30
while not sock and time.time() < retry_until:
try:
sock = eventlet.listen(bind_addr,
backlog=backlog,
family=family)
if sslutils.is_enabled():
sock = sslutils.wrap(sock)
except socket.error as err:
if err.args[0] != errno.EADDRINUSE:
raise
eventlet.sleep(0.1)
if not sock:
raise RuntimeError(_("Could not bind to %(host)s:%(port)s "
"after trying for 30 seconds") %
{'host': host, 'port': port})
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
# sockets can hang around forever without keepalive
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
# This option isn't available in the OS X version of eventlet
if hasattr(socket, 'TCP_KEEPIDLE'):
sock.setsockopt(socket.IPPROTO_TCP,
socket.TCP_KEEPIDLE,
CONF.tcp_keepidle)
return sock
def start(self):
"""Start serving this service using the provided server instance.
:returns: None
"""
super(Service, self).start()
self.tg.add_thread(self._run, self.application, self._socket)
@property
def backlog(self):
return self._backlog
@property
def host(self):
return self._socket.getsockname()[0] if self._socket else self._host
@property
def port(self):
return self._socket.getsockname()[1] if self._socket else self._port
def stop(self):
"""Stop serving this API.
:returns: None
"""
super(Service, self).stop()
def _run(self, application, socket):
"""Start a WSGI server in a new green thread."""
logger = logging.getLogger('eventlet.wsgi')
eventlet.wsgi.server(socket,
application,
custom_pool=self.tg.pool,
log=logging.WritableLogger(logger))
class Middleware(object):
"""
Base WSGI middleware wrapper. These classes require an application to be
initialized that will be called next. By default the middleware will
simply call its wrapped app, or you can override __call__ to customize its
behavior.
"""
def __init__(self, application):
self.application = application
def process_request(self, req):
"""
Called on each request.
If this returns None, the next application down the stack will be
executed. If it returns a response then that response will be returned
and execution will stop here.
"""
return None
def process_response(self, response):
"""Do whatever you'd like to the response."""
return response
@webob.dec.wsgify
def __call__(self, req):
response = self.process_request(req)
if response:
return response
response = req.get_response(self.application)
return self.process_response(response)
class Debug(Middleware):
"""
Helper class that can be inserted into any WSGI application chain
to get information about the request and response.
"""
@webob.dec.wsgify
def __call__(self, req):
print(("*" * 40) + " REQUEST ENVIRON")
for key, value in req.environ.items():
print(key, "=", value)
print()
resp = req.get_response(self.application)
print(("*" * 40) + " RESPONSE HEADERS")
for (key, value) in resp.headers.iteritems():
print(key, "=", value)
print()
resp.app_iter = self.print_generator(resp.app_iter)
return resp
@staticmethod
def print_generator(app_iter):
"""
Iterator that prints the contents of a wrapper string iterator
when iterated.
"""
print(("*" * 40) + " BODY")
for part in app_iter:
sys.stdout.write(part)
sys.stdout.flush()
yield part
print()
class Router(object):
"""
WSGI middleware that maps incoming requests to WSGI apps.
"""
def __init__(self, mapper):
"""
Create a router for the given routes.Mapper.
Each route in `mapper` must specify a 'controller', which is a
WSGI app to call. You'll probably want to specify an 'action' as
well and have your controller be a wsgi.Controller, who will route
the request to the action method.
Examples:
mapper = routes.Mapper()
sc = ServerController()
# Explicit mapping of one route to a controller+action
mapper.connect(None, "/svrlist", controller=sc, action="list")
# Actions are all implicitly defined
mapper.resource("server", "servers", controller=sc)
# Pointing to an arbitrary WSGI app. You can specify the
# {path_info:.*} parameter so the target app can be handed just that
# section of the URL.
mapper.connect(None, "/v1.0/{path_info:.*}", controller=BlogApp())
"""
self.map = mapper
self._router = routes.middleware.RoutesMiddleware(self._dispatch,
self.map)
@webob.dec.wsgify
def __call__(self, req):
"""
Route the incoming request to a controller based on self.map.
If no match, return a 404.
"""
return self._router
@staticmethod
@webob.dec.wsgify
def _dispatch(req):
"""
Called by self._router after matching the incoming request to a route
and putting the information into req.environ. Either returns 404
or the routed WSGI app's response.
"""
match = req.environ['wsgiorg.routing_args'][1]
if not match:
return webob.exc.HTTPNotFound()
app = match['controller']
return app
class Request(webob.Request):
"""Add some Openstack API-specific logic to the base webob.Request."""
default_request_content_types = ('application/json', 'application/xml')
default_accept_types = ('application/json', 'application/xml')
default_accept_type = 'application/json'
def best_match_content_type(self, supported_content_types=None):
"""Determine the requested response content-type.
Based on the query extension then the Accept header.
Defaults to default_accept_type if we don't find a preference
"""
supported_content_types = (supported_content_types or
self.default_accept_types)
parts = self.path.rsplit('.', 1)
if len(parts) > 1:
ctype = 'application/{0}'.format(parts[1])
if ctype in supported_content_types:
return ctype
bm = self.accept.best_match(supported_content_types)
return bm or self.default_accept_type
def get_content_type(self, allowed_content_types=None):
"""Determine content type of the request body.
Does not do any body introspection, only checks header
"""
if "Content-Type" not in self.headers:
return None
content_type = self.content_type
allowed_content_types = (allowed_content_types or
self.default_request_content_types)
if content_type not in allowed_content_types:
raise exception.InvalidContentType(content_type=content_type)
return content_type
class Resource(object):
"""
WSGI app that handles (de)serialization and controller dispatch.
Reads routing information supplied by RoutesMiddleware and calls
the requested action method upon its deserializer, controller,
and serializer. Those three objects may implement any of the basic
controller action methods (create, update, show, index, delete)
along with any that may be specified in the api router. A 'default'
method may also be implemented to be used in place of any
non-implemented actions. Deserializer methods must accept a request
argument and return a dictionary. Controller methods must accept a
request argument. Additionally, they must also accept keyword
arguments that represent the keys returned by the Deserializer. They
may raise a webob.exc exception or return a dict, which will be
serialized by requested content type.
"""
def __init__(self, controller, deserializer=None, serializer=None):
"""
:param controller: object that implement methods created by routes lib
:param deserializer: object that supports webob request deserialization
through controller-like actions
:param serializer: object that supports webob response serialization
through controller-like actions
"""
self.controller = controller
self.serializer = serializer or ResponseSerializer()
self.deserializer = deserializer or RequestDeserializer()
@webob.dec.wsgify(RequestClass=Request)
def __call__(self, request):
"""WSGI method that controls (de)serialization and method dispatch."""
try:
action, action_args, accept = self.deserialize_request(request)
except exception.InvalidContentType:
msg = _("Unsupported Content-Type")
return webob.exc.HTTPUnsupportedMediaType(explanation=msg)
except exception.MalformedRequestBody:
msg = _("Malformed request body")
return webob.exc.HTTPBadRequest(explanation=msg)
action_result = self.execute_action(action, request, **action_args)
try:
return self.serialize_response(action, action_result, accept)
# return unserializable result (typically a webob exc)
except Exception:
return action_result
def deserialize_request(self, request):
return self.deserializer.deserialize(request)
def serialize_response(self, action, action_result, accept):
return self.serializer.serialize(action_result, accept, action)
def execute_action(self, action, request, **action_args):
return self.dispatch(self.controller, action, request, **action_args)
def dispatch(self, obj, action, *args, **kwargs):
"""Find action-specific method on self and call it."""
try:
method = getattr(obj, action)
except AttributeError:
method = getattr(obj, 'default')
return method(*args, **kwargs)
def get_action_args(self, request_environment):
"""Parse dictionary created by routes library."""
try:
args = request_environment['wsgiorg.routing_args'][1].copy()
except Exception:
return {}
try:
del args['controller']
except KeyError:
pass
try:
del args['format']
except KeyError:
pass
return args
class ActionDispatcher(object):
"""Maps method name to local methods through action name."""
def dispatch(self, *args, **kwargs):
"""Find and call local method."""
action = kwargs.pop('action', 'default')
action_method = getattr(self, str(action), self.default)
return action_method(*args, **kwargs)
def default(self, data):
raise NotImplementedError()
class DictSerializer(ActionDispatcher):
"""Default request body serialization"""
def serialize(self, data, action='default'):
return self.dispatch(data, action=action)
def default(self, data):
return ""
class JSONDictSerializer(DictSerializer):
"""Default JSON request body serialization"""
def default(self, data):
def sanitizer(obj):
if isinstance(obj, datetime.datetime):
_dtime = obj - datetime.timedelta(microseconds=obj.microsecond)
return _dtime.isoformat()
return six.text_type(obj)
return jsonutils.dumps(data, default=sanitizer)
class XMLDictSerializer(DictSerializer):
def __init__(self, metadata=None, xmlns=None):
"""
:param metadata: information needed to deserialize xml into
a dictionary.
:param xmlns: XML namespace to include with serialized xml
"""
super(XMLDictSerializer, self).__init__()
self.metadata = metadata or {}
self.xmlns = xmlns
def default(self, data):
# We expect data to contain a single key which is the XML root.
root_key = data.keys()[0]
doc = minidom.Document()
node = self._to_xml_node(doc, self.metadata, root_key, data[root_key])
return self.to_xml_string(node)
def to_xml_string(self, node, has_atom=False):
self._add_xmlns(node, has_atom)
return node.toprettyxml(indent=' ', encoding='UTF-8')
#NOTE (ameade): the has_atom should be removed after all of the
# xml serializers and view builders have been updated to the current
# spec that required all responses include the xmlns:atom, the has_atom
# flag is to prevent current tests from breaking
def _add_xmlns(self, node, has_atom=False):
if self.xmlns is not None:
node.setAttribute('xmlns', self.xmlns)
if has_atom:
node.setAttribute('xmlns:atom', "http://www.w3.org/2005/Atom")
def _to_xml_node(self, doc, metadata, nodename, data):
"""Recursive method to convert data members to XML nodes."""
result = doc.createElement(nodename)
# Set the xml namespace if one is specified
# TODO(justinsb): We could also use prefixes on the keys
xmlns = metadata.get('xmlns', None)
if xmlns:
result.setAttribute('xmlns', xmlns)
#TODO(bcwaldon): accomplish this without a type-check
if type(data) is list:
collections = metadata.get('list_collections', {})
if nodename in collections:
metadata = collections[nodename]
for item in data:
node = doc.createElement(metadata['item_name'])
node.setAttribute(metadata['item_key'], str(item))
result.appendChild(node)
return result
singular = metadata.get('plurals', {}).get(nodename, None)
if singular is None:
if nodename.endswith('s'):
singular = nodename[:-1]
else:
singular = 'item'
for item in data:
node = self._to_xml_node(doc, metadata, singular, item)
result.appendChild(node)
#TODO(bcwaldon): accomplish this without a type-check
elif type(data) is dict:
collections = metadata.get('dict_collections', {})
if nodename in collections:
metadata = collections[nodename]
for k, v in data.items():
node = doc.createElement(metadata['item_name'])
node.setAttribute(metadata['item_key'], str(k))
text = doc.createTextNode(str(v))
node.appendChild(text)
result.appendChild(node)
return result
attrs = metadata.get('attributes', {}).get(nodename, {})
for k, v in data.items():
if k in attrs:
result.setAttribute(k, str(v))
else:
node = self._to_xml_node(doc, metadata, k, v)
result.appendChild(node)
else:
# Type is atom
node = doc.createTextNode(str(data))
result.appendChild(node)
return result
def _create_link_nodes(self, xml_doc, links):
link_nodes = []
for link in links:
link_node = xml_doc.createElement('atom:link')
link_node.setAttribute('rel', link['rel'])
link_node.setAttribute('href', link['href'])
if 'type' in link:
link_node.setAttribute('type', link['type'])
link_nodes.append(link_node)
return link_nodes
class ResponseHeadersSerializer(ActionDispatcher):
"""Default response headers serialization"""
def serialize(self, response, data, action):
self.dispatch(response, data, action=action)
def default(self, response, data):
response.status_int = 200
class ResponseSerializer(object):
"""Encode the necessary pieces into a response object"""
def __init__(self, body_serializers=None, headers_serializer=None):
self.body_serializers = {
'application/xml': XMLDictSerializer(),
'application/json': JSONDictSerializer(),
}
self.body_serializers.update(body_serializers or {})
self.headers_serializer = (headers_serializer or
ResponseHeadersSerializer())
def serialize(self, response_data, content_type, action='default'):
"""Serialize a dict into a string and wrap in a wsgi.Request object.
:param response_data: dict produced by the Controller
:param content_type: expected mimetype of serialized response body
"""
response = webob.Response()
self.serialize_headers(response, response_data, action)
self.serialize_body(response, response_data, content_type, action)
return response
def serialize_headers(self, response, data, action):
self.headers_serializer.serialize(response, data, action)
def serialize_body(self, response, data, content_type, action):
response.headers['Content-Type'] = content_type
if data is not None:
serializer = self.get_body_serializer(content_type)
response.body = serializer.serialize(data, action)
def get_body_serializer(self, content_type):
try:
return self.body_serializers[content_type]
except (KeyError, TypeError):
raise exception.InvalidContentType(content_type=content_type)
class RequestHeadersDeserializer(ActionDispatcher):
"""Default request headers deserializer"""
def deserialize(self, request, action):
return self.dispatch(request, action=action)
def default(self, request):
return {}
class RequestDeserializer(object):
"""Break up a Request object into more useful pieces."""
def __init__(self, body_deserializers=None, headers_deserializer=None,
supported_content_types=None):
self.supported_content_types = supported_content_types
self.body_deserializers = {
'application/xml': XMLDeserializer(),
'application/json': JSONDeserializer(),
}
self.body_deserializers.update(body_deserializers or {})
self.headers_deserializer = (headers_deserializer or
RequestHeadersDeserializer())
def deserialize(self, request):
"""Extract necessary pieces of the request.
:param request: Request object
:returns: tuple of (expected controller action name, dictionary of
keyword arguments to pass to the controller, the expected
content type of the response)
"""
action_args = self.get_action_args(request.environ)
action = action_args.pop('action', None)
action_args.update(self.deserialize_headers(request, action))
action_args.update(self.deserialize_body(request, action))
accept = self.get_expected_content_type(request)
return (action, action_args, accept)
def deserialize_headers(self, request, action):
return self.headers_deserializer.deserialize(request, action)
def deserialize_body(self, request, action):
if not request.body:
LOG.debug(_("Empty body provided in request"))
return {}
try:
content_type = request.get_content_type()
except exception.InvalidContentType:
LOG.debug(_("Unrecognized Content-Type provided in request"))
raise
if content_type is None:
LOG.debug(_("No Content-Type provided in request"))
return {}
try:
deserializer = self.get_body_deserializer(content_type)
except exception.InvalidContentType:
LOG.debug(_("Unable to deserialize body as provided Content-Type"))
raise
return deserializer.deserialize(request.body, action)
def get_body_deserializer(self, content_type):
try:
return self.body_deserializers[content_type]
except (KeyError, TypeError):
raise exception.InvalidContentType(content_type=content_type)
def get_expected_content_type(self, request):
return request.best_match_content_type(self.supported_content_types)
def get_action_args(self, request_environment):
"""Parse dictionary created by routes library."""
try:
args = request_environment['wsgiorg.routing_args'][1].copy()
except Exception:
return {}
try:
del args['controller']
except KeyError:
pass
try:
del args['format']
except KeyError:
pass
return args
class TextDeserializer(ActionDispatcher):
"""Default request body deserialization"""
def deserialize(self, datastring, action='default'):
return self.dispatch(datastring, action=action)
def default(self, datastring):
return {}
class JSONDeserializer(TextDeserializer):
def _from_json(self, datastring):
try:
return jsonutils.loads(datastring)
except ValueError:
msg = _("cannot understand JSON")
raise exception.MalformedRequestBody(reason=msg)
def default(self, datastring):
return {'body': self._from_json(datastring)}
class XMLDeserializer(TextDeserializer):
def __init__(self, metadata=None):
"""
:param metadata: information needed to deserialize xml into
a dictionary.
"""
super(XMLDeserializer, self).__init__()
self.metadata = metadata or {}
def _from_xml(self, datastring):
plurals = set(self.metadata.get('plurals', {}))
try:
node = xmlutils.safe_minidom_parse_string(datastring).childNodes[0]
return {node.nodeName: self._from_xml_node(node, plurals)}
except expat.ExpatError:
msg = _("cannot understand XML")
raise exception.MalformedRequestBody(reason=msg)
def _from_xml_node(self, node, listnames):
"""Convert a minidom node to a simple Python type.
:param listnames: list of XML node names whose subnodes should
be considered list items.
"""
if len(node.childNodes) == 1 and node.childNodes[0].nodeType == 3:
return node.childNodes[0].nodeValue
elif node.nodeName in listnames:
return [self._from_xml_node(n, listnames) for n in node.childNodes]
else:
result = dict()
for attr in node.attributes.keys():
result[attr] = node.attributes[attr].nodeValue
for child in node.childNodes:
if child.nodeType != node.TEXT_NODE:
result[child.nodeName] = self._from_xml_node(child,
listnames)
return result
def find_first_child_named(self, parent, name):
"""Search a nodes children for the first child with a given name"""
for node in parent.childNodes:
if node.nodeName == name:
return node
return None
def find_children_named(self, parent, name):
"""Return all of a nodes children who have the given name"""
for node in parent.childNodes:
if node.nodeName == name:
yield node
def extract_text(self, node):
"""Get the text field contained by the given node"""
if len(node.childNodes) == 1:
child = node.childNodes[0]
if child.nodeType == child.TEXT_NODE:
return child.nodeValue
return ""
def default(self, datastring):
return {'body': self._from_xml(datastring)}

View File

@ -0,0 +1,74 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2013 IBM Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from xml.dom import minidom
from xml.parsers import expat
from xml import sax
from xml.sax import expatreader
class ProtectedExpatParser(expatreader.ExpatParser):
"""An expat parser which disables DTD's and entities by default."""
def __init__(self, forbid_dtd=True, forbid_entities=True,
*args, **kwargs):
# Python 2.x old style class
expatreader.ExpatParser.__init__(self, *args, **kwargs)
self.forbid_dtd = forbid_dtd
self.forbid_entities = forbid_entities
def start_doctype_decl(self, name, sysid, pubid, has_internal_subset):
raise ValueError("Inline DTD forbidden")
def entity_decl(self, entityName, is_parameter_entity, value, base,
systemId, publicId, notationName):
raise ValueError("<!ENTITY> entity declaration forbidden")
def unparsed_entity_decl(self, name, base, sysid, pubid, notation_name):
# expat 1.2
raise ValueError("<!ENTITY> unparsed entity forbidden")
def external_entity_ref(self, context, base, systemId, publicId):
raise ValueError("<!ENTITY> external entity forbidden")
def notation_decl(self, name, base, sysid, pubid):
raise ValueError("<!ENTITY> notation forbidden")
def reset(self):
expatreader.ExpatParser.reset(self)
if self.forbid_dtd:
self._parser.StartDoctypeDeclHandler = self.start_doctype_decl
self._parser.EndDoctypeDeclHandler = None
if self.forbid_entities:
self._parser.EntityDeclHandler = self.entity_decl
self._parser.UnparsedEntityDeclHandler = self.unparsed_entity_decl
self._parser.ExternalEntityRefHandler = self.external_entity_ref
self._parser.NotationDeclHandler = self.notation_decl
try:
self._parser.SkippedEntityHandler = None
except AttributeError:
# some pyexpat versions do not support SkippedEntity
pass
def safe_minidom_parse_string(xml_string):
"""Parse an XML string using minidom safely.
"""
try:
return minidom.parseString(xml_string, parser=ProtectedExpatParser())
except sax.SAXParseException:
raise expat.ExpatError()

View File

@ -1,5 +1,6 @@
import filecmp
import time
import mox
import staccato.xfer.events as xfer_events
import staccato.xfer.interface as xfer_iface
@ -11,10 +12,27 @@ from staccato.tests import utils
import staccato.xfer.executor as executor
class FakeStateMachine(object):
class FakeReadProtocolProcessError(object):
pass
class FileProtocol(base.BaseProtocolInterface):
def __init__(self, service_config):
self.conf = service_config
def new_write(self, dsturl_parts, dst_opts):
return {}
def new_read(self, srcurl_parts, src_opts):
return
def get_reader(self, url_parts, writer, monitor, start=0,
end=None, **kwvals):
def get_writer(self, url_parts, checkpointer, **kwvals):
def event_occurred(self, *args, **kwvals):
pass
class TestXfer(utils.TempFileCleanupBaseTest):
@ -35,9 +53,12 @@ class TestXfer(utils.TempFileCleanupBaseTest):
self.executor = executor.SimpleThreadExecutor(self.conf)
self.sm = xfer_events.XferStateMachine(self.executor)
self.mox = mox.Mox()
def tearDown(self):
self.executor.shutdown()
super(TestXfer, self).tearDown()
self.mox.UnsetStubs()
def test_file_xfer_basic(self):
dst_file = self.get_tempfile()
@ -74,3 +95,28 @@ class TestXfer(utils.TempFileCleanupBaseTest):
xfer = db_obj.lookup_xfer_request_by_id(xfer.id)
self.assertTrue(xfer.state, constants.States.STATE_CANCELED)
def test_file_xfer_error(self):
dst_file = "/dev/null"
src_file = "/dev/zero"
src_url = "file://%s" % src_file
dst_url = "file://%s" % dst_file
z = FakeReadProtocol
fake_protocol_name = "%s.%s" % (z.__module__, z.__name__)
xfer = xfer_iface.xfer_new(self.conf, src_url, dst_url,
{}, {}, 0, None)
db_obj = db.StaccatoDB(self.conf)
xfer_iface.xfer_start(self.conf, xfer.id, self.sm)
self.mox.StubOutWithMock(file, "safe_delete_from_backend")
xfer_iface.xfer_cancel(self.conf, xfer.id, self.sm)
while not xfer_consts.is_state_done_running(xfer.state):
time.sleep(0.1)
xfer = db_obj.lookup_xfer_request_by_id(xfer.id)
self.assertTrue(xfer.state, constants.States.STATE_CANCELED)

View File

@ -207,7 +207,8 @@ class Fedora(Distro):
This can be removed when the fix is applied upstream.
Nova: https://bugs.launchpad.net/nova/+bug/884915
Upstream: https://bitbucket.org/which_linden/eventlet/issue/89
Upstream: https://bitbucket.org/eventlet/eventlet/issue/89
RHEL: https://bugzilla.redhat.com/958868
"""
# Install "patch" program if it's not there