OAuth Storage Abstraction Removed

The DB storage mechanism was doubly abstracted, once via the
token_storage mechanism, and once via the DB API. Since the
token_storage mechanism was never properly instantiated during tests,
and was therefore preventing us from running tests properly,
I've removed it entirely and pulled all methods from db_storage
either directly into the various validators, or into the DB
abstraction.

Change-Id: Ia3026d5fa6bc45269beb6cfbc237fc62e98fb329
This commit is contained in:
Michael Krotscheck 2015-02-04 17:11:15 -08:00
parent 8195049588
commit fbaa9220ce
11 changed files with 82 additions and 455 deletions

View File

@ -21,8 +21,6 @@ from oslo_log import log
import pecan
from wsgiref import simple_server
from storyboard.api.auth.token_storage import impls as storage_impls
from storyboard.api.auth.token_storage import storage
from storyboard.api import config as api_config
from storyboard.api.middleware.cors_middleware import CORSMiddleware
from storyboard.api.middleware import token_middleware
@ -87,11 +85,6 @@ def setup_app(pecan_config=None):
validation_hook.ValidationHook()
]
# Setup token storage
token_storage_type = CONF.token_storage_type
storage_cls = storage_impls.STORAGE_IMPLS[token_storage_type]
storage.set_storage(storage_cls())
# Setup search engine
search_engine_name = CONF.search_engine
search_engine_cls = search_engine_impls.ENGINE_IMPLS[search_engine_name]

View File

@ -4,7 +4,7 @@
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
@ -16,7 +16,7 @@
from pecan import abort
from pecan import request
from storyboard.api.auth.token_storage import storage
from storyboard.db.api import access_tokens as token_api
from storyboard.db.api import users as user_api
from storyboard.openstack.common.gettextutils import _ # noqa
@ -29,7 +29,6 @@ def _get_token():
def guest():
token_storage = storage.get_storage()
token = _get_token()
# Public resources do not require a token.
@ -37,25 +36,23 @@ def guest():
return True
# But if there is a token, it should be valid.
return token_storage.check_access_token(token)
return token_api.is_valid(token)
def authenticated():
token_storage = storage.get_storage()
token = _get_token()
return token and token_storage.check_access_token(token)
return token_api.is_valid(token)
def superuser():
token_storage = storage.get_storage()
token = _get_token()
if not token:
return False
token_info = token_storage.get_access_token_info(token)
user = user_api.user_get(token_info.user_id)
token = token_api.access_token_get_by_token(token)
user = user_api.user_get(token.user_id)
if not user.is_superuser:
abort(403, _("This action is limited to superusers only."))

View File

@ -4,7 +4,7 @@
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
@ -13,14 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from datetime import datetime
import datetime
from oauthlib.oauth2 import RequestValidator
from oauthlib.oauth2 import WebApplicationServer
from oslo.config import cfg
from oslo_log import log
from storyboard.api.auth.token_storage import storage
from storyboard.db.api import access_tokens as token_api
from storyboard.db.api import auth as auth_api
from storyboard.db.api import users as user_api
CONF = cfg.CONF
@ -41,7 +42,6 @@ class SkeletonValidator(RequestValidator):
def __init__(self):
super(SkeletonValidator, self).__init__()
self.token_storage = storage.get_storage()
def validate_client_id(self, client_id, request, *args, **kwargs):
"""Check that a valid client is connecting
@ -111,7 +111,7 @@ class SkeletonValidator(RequestValidator):
email = request._params["openid.sreg.email"]
full_name = request._params["openid.sreg.fullname"]
username = request._params["openid.sreg.nickname"]
last_login = datetime.utcnow()
last_login = datetime.datetime.utcnow()
user = user_api.user_get_by_openid(openid)
user_dict = {"full_name": full_name,
@ -125,7 +125,13 @@ class SkeletonValidator(RequestValidator):
else:
user = user_api.user_update(user.id, user_dict)
self.token_storage.save_authorization_code(code, user_id=user.id)
# def save_authorization_code(self, authorization_code, user_id):
values = {
"code": code["code"],
"state": code["state"],
"user_id": user.id
}
auth_api.authorization_code_save(values)
# Token request
@ -142,7 +148,8 @@ class SkeletonValidator(RequestValidator):
def validate_code(self, client_id, code, client, request, *args, **kwargs):
"""Validate the code belongs to the client."""
return self.token_storage.check_authorization_code(code)
db_code = auth_api.authorization_code_get(code)
return not db_code is None
def confirm_redirect_uri(self, client_id, code, redirect_uri, client,
*args, **kwargs):
@ -169,13 +176,12 @@ class SkeletonValidator(RequestValidator):
# Try authorization code
code = request._params.get("code")
if code:
code_info = self.token_storage.get_authorization_code_info(code)
code_info = auth_api.authorization_code_get(code)
return code_info.user_id
# Try refresh token
refresh_token = request._params.get("refresh_token")
refresh_token_entry = self.token_storage.get_refresh_token_info(
refresh_token)
refresh_token_entry = auth_api.refresh_token_get(refresh_token)
if refresh_token_entry:
return refresh_token_entry.user_id
@ -190,10 +196,28 @@ class SkeletonValidator(RequestValidator):
# be removed.
self.invalidate_refresh_token(request)
self.token_storage.save_token(access_token=token["access_token"],
expires_in=token["expires_in"],
refresh_token=token["refresh_token"],
user_id=user_id)
access_token_values = {
"access_token": token["access_token"],
"expires_in": token["expires_in"],
"expires_at": datetime.datetime.utcnow() + datetime.timedelta(
seconds=token["expires_in"]),
"user_id": user_id
}
# Oauthlib does not provide a separate expiration time for a
# refresh_token so taking it from config directly.
refresh_expires_in = CONF.oauth.refresh_token_ttl
refresh_token_values = {
"refresh_token": token["refresh_token"],
"user_id": user_id,
"expires_in": refresh_expires_in,
"expires_at": datetime.datetime.utcnow() + datetime.timedelta(
seconds=refresh_expires_in),
}
token_api.access_token_create(access_token_values)
auth_api.refresh_token_save(refresh_token_values)
def invalidate_authorization_code(self, client_id, code, request, *args,
**kwargs):
@ -202,7 +226,7 @@ class SkeletonValidator(RequestValidator):
"""
self.token_storage.invalidate_authorization_code(code)
auth_api.authorization_code_delete(code)
# Protected resource request
@ -229,7 +253,16 @@ class SkeletonValidator(RequestValidator):
**kwargs):
"""Check that the refresh token exists in the db."""
return self.token_storage.check_refresh_token(refresh_token)
refresh_token_entry = auth_api.refresh_token_get(refresh_token)
if not refresh_token_entry:
return False
if datetime.datetime.utcnow() > refresh_token_entry.expires_at:
auth_api.refresh_token_delete(refresh_token)
return False
return True
def invalidate_refresh_token(self, request):
"""Remove a used token from the storage."""
@ -241,7 +274,7 @@ class SkeletonValidator(RequestValidator):
if not refresh_token:
return
self.token_storage.invalidate_refresh_token(refresh_token)
auth_api.refresh_token_delete(refresh_token)
class OpenIdConnectServer(WebApplicationServer):

View File

@ -1,105 +0,0 @@
# Copyright (c) 2014 Mirantis Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
from oslo.config import cfg
from storyboard.api.auth.token_storage import storage
from storyboard.db.api import access_tokens as token_api
from storyboard.db.api import auth as auth_api
CONF = cfg.CONF
class DBTokenStorage(storage.StorageBase):
def save_authorization_code(self, authorization_code, user_id):
values = {
"code": authorization_code["code"],
"state": authorization_code["state"],
"user_id": user_id
}
auth_api.authorization_code_save(values)
def get_authorization_code_info(self, code):
return auth_api.authorization_code_get(code)
def check_authorization_code(self, code):
db_code = auth_api.authorization_code_get(code)
return not db_code is None
def invalidate_authorization_code(self, code):
auth_api.authorization_code_delete(code)
def save_token(self, access_token, expires_in, refresh_token, user_id):
access_token_values = {
"access_token": access_token,
"expires_in": expires_in,
"expires_at": datetime.datetime.utcnow() + datetime.timedelta(
seconds=expires_in),
"user_id": user_id
}
# Oauthlib does not provide a separate expiration time for a
# refresh_token so taking it from config directly.
refresh_expires_in = CONF.oauth.refresh_token_ttl
refresh_token_values = {
"refresh_token": refresh_token,
"user_id": user_id,
"expires_in": refresh_expires_in,
"expires_at": datetime.datetime.utcnow() + datetime.timedelta(
seconds=refresh_expires_in),
}
token_api.access_token_create(access_token_values)
auth_api.refresh_token_save(refresh_token_values)
def get_access_token_info(self, access_token):
return token_api.access_token_get_by_token(access_token)
def check_access_token(self, access_token):
token_info = token_api.access_token_get_by_token(access_token)
if not token_info:
return False
if datetime.datetime.utcnow() > token_info.expires_at:
token_api.access_token_delete(access_token)
return False
return True
def remove_token(self, access_token):
token_api.access_token_delete(access_token)
def check_refresh_token(self, refresh_token):
refresh_token_entry = auth_api.refresh_token_get(refresh_token)
if not refresh_token_entry:
return False
if datetime.datetime.utcnow() > refresh_token_entry.expires_at:
auth_api.refresh_token_delete(refresh_token)
return False
return True
def get_refresh_token_info(self, refresh_token):
return auth_api.refresh_token_get(refresh_token)
def invalidate_refresh_token(self, refresh_token):
auth_api.refresh_token_delete(refresh_token)

View File

@ -1,22 +0,0 @@
# Copyright (c) 2014 Mirantis Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from storyboard.api.auth.token_storage import db_storage
from storyboard.api.auth.token_storage import memory_storage
STORAGE_IMPLS = {
"mem": memory_storage.MemoryTokenStorage,
"db": db_storage.DBTokenStorage
}

View File

@ -1,130 +0,0 @@
# Copyright (c) 2014 Mirantis Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
from storyboard.api.auth.token_storage import storage
class Token(object):
def __init__(self, access_token, refresh_token, expires_in, user_id,
is_valid=True):
self.access_token = access_token
self.refresh_token = refresh_token
self.expires_in = expires_in
self.expires_at = datetime.datetime.utcnow() + \
datetime.timedelta(seconds=expires_in)
self.user_id = user_id
self.is_valid = is_valid
class AuthorizationCode(object):
def __init__(self, code, user_id):
self.code = code
self.user_id = user_id
class MemoryTokenStorage(storage.StorageBase):
def __init__(self):
self.token_set = set([])
self.auth_code_set = set([])
def save_token(self, access_token, expires_in, refresh_token, user_id):
token_info = Token(access_token=access_token,
expires_in=expires_in,
refresh_token=refresh_token,
user_id=user_id)
self.token_set.add(token_info)
def check_access_token(self, access_token):
token_entry = None
for token_info in self.token_set:
if token_info.access_token == access_token:
token_entry = token_info
if not token_entry:
return False
now = datetime.datetime.utcnow()
if now > token_entry.expires_at:
token_entry.is_valid = False
return False
return True
def get_access_token_info(self, access_token):
for token_info in self.token_set:
if token_info.access_token == access_token:
return token_info
return None
def remove_token(self, token):
pass
def check_refresh_token(self, refresh_token):
for token_info in self.token_set:
if token_info.refresh_token == refresh_token:
return True
return False
def get_refresh_token_info(self, refresh_token):
for token_info in self.token_set:
if token_info.refresh_token == refresh_token:
return token_info
return None
def invalidate_refresh_token(self, refresh_token):
token_entry = None
for entry in self.token_set:
if entry.refresh_token == refresh_token:
token_entry = entry
break
self.token_set.remove(token_entry)
def save_authorization_code(self, authorization_code, user_id):
self.auth_code_set.add(AuthorizationCode(authorization_code, user_id))
def check_authorization_code(self, code):
code_entry = None
for entry in self.auth_code_set:
if entry.code["code"] == code:
code_entry = entry
break
if not code_entry:
return False
return True
def get_authorization_code_info(self, code):
for entry in self.auth_code_set:
if entry.code["code"] == code:
return entry
return None
def invalidate_authorization_code(self, code):
code_entry = None
for entry in self.auth_code_set:
if entry.code["code"] == code:
code_entry = entry
break
self.auth_code_set.remove(code_entry)

View File

@ -1,153 +0,0 @@
# Copyright (c) 2014 Mirantis Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from oslo.config import cfg
CONF = cfg.CONF
STORAGE_OPTS = [
cfg.StrOpt('token_storage_type',
default='db',
help='Authorization token storage type.'
' Supported types are "mem" and "db".'
' Memory storage is not persistent between api launches')
]
CONF.register_opts(STORAGE_OPTS)
class StorageBase(object):
@abc.abstractmethod
def save_authorization_code(self, authorization_code, user_id):
"""This method should save an Authorization Code to the storage and
associate it with a user_id.
@param authorization_code: An object, containing state and a the code
itself.
@param user_id: The id of a User to associate the code with.
"""
pass
@abc.abstractmethod
def check_authorization_code(self, code):
"""Check that the given token exists in the storage.
@param code: The code to be checked.
@return bool
"""
pass
@abc.abstractmethod
def get_authorization_code_info(self, code):
"""Get the code info from the storage.
@param code: An authorization Code
@return object: The returned object should contain the state and the
user_id, which the given code is associated with.
"""
pass
@abc.abstractmethod
def invalidate_authorization_code(self, code):
"""Remove a code from the storage.
@param code: An authorization Code
"""
pass
@abc.abstractmethod
def save_token(self, access_token, expires_in, refresh_token, user_id):
"""Save a Bearer token to the storage with all associated fields
@param access_token: A token that will be used in authorized requests.
@param expires_in: The time in seconds while the access_token is valid.
@param refresh_token: A token that will be used in a refresh request
after an access_token gets expired.
@param user_id: The id of a User which owns a token.
"""
pass
@abc.abstractmethod
def check_access_token(self, access_token):
"""This method should say if a given token exists in the storage and
that it has not expired yet.
@param access_token: The token to be checked.
@return bool
"""
pass
@abc.abstractmethod
def get_access_token_info(self, access_token):
"""Get the Bearer token from the storage.
@param access_token: The token to get the information about.
@return object: The object should contain all fields associated with
the token (refresh_token, expires_in, user_id).
"""
pass
@abc.abstractmethod
def remove_token(self, token):
"""Invalidate a given token and remove it from the storage.
@param token: The token to be removed.
"""
pass
@abc.abstractmethod
def check_refresh_token(self, refresh_token):
"""This method should say if a given token exists in the storage and
that it has not expired yet.
@param refresh_token: The token to be checked.
@return bool
"""
pass
@abc.abstractmethod
def get_refresh_token_info(self, refresh_token):
"""Get the Bearer token from the storage.
@param refresh_token: The token to get the information about.
@return object: The object should contain all fields associated with
the token (refresh_token, expires_in, user_id).
"""
pass
@abc.abstractmethod
def invalidate_refresh_token(self, refresh_token):
"""Remove a token from the storage.
@param refresh_token: A refresh token
"""
pass
STORAGE = None
def get_storage():
global STORAGE
return STORAGE
def set_storage(impl):
global STORAGE
STORAGE = impl

View File

@ -4,7 +4,7 @@
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
@ -15,7 +15,7 @@
from pecan import hooks
from storyboard.api.auth.token_storage import storage
from storyboard.db.api import access_tokens as token_api
class UserIdHook(hooks.PecanHook):
@ -24,11 +24,11 @@ class UserIdHook(hooks.PecanHook):
request = state.request
if request.authorization and len(request.authorization) == 2:
token = request.authorization[1]
token_info = storage.get_storage().get_access_token_info(token)
access_token = request.authorization[1]
token = token_api.access_token_get_by_token(access_token)
if token_info:
request.current_user_id = token_info.user_id
if token:
request.current_user_id = token.user_id
return
request.current_user_id = None

View File

@ -24,7 +24,7 @@ import six
from storyboard.api.auth.oauth_validator import SERVER
from storyboard.api.auth.openid_client import client as openid_client
from storyboard.api.auth.token_storage import storage
from storyboard.db.api import auth as auth_api
LOG = log.getLogger(__name__)
@ -73,8 +73,7 @@ class AuthController(rest.RestController):
def _access_token_by_code(self):
auth_code = request.params.get("code")
code_info = storage.get_storage() \
.get_authorization_code_info(auth_code)
code_info = auth_api.authorization_code_get(auth_code)
headers, body, code = SERVER.create_token_response(
uri=request.url,
http_method=request.method,
@ -96,8 +95,7 @@ class AuthController(rest.RestController):
def _access_token_by_refresh_token(self):
refresh_token = request.params.get("refresh_token")
refresh_token_info = storage.get_storage().get_refresh_token_info(
refresh_token)
refresh_token_info = auth_api.refresh_token_get(refresh_token)
headers, body, code = SERVER.create_token_response(
uri=request.url,

View File

@ -29,10 +29,26 @@ def access_token_get(access_token_id):
def access_token_get_by_token(access_token):
return api_base.model_query(models.AccessToken)\
return api_base.model_query(models.AccessToken) \
.filter_by(access_token=access_token).first()
def is_valid(access_token):
if not access_token:
return False
token = access_token_get_by_token(access_token)
if not token:
return False
if datetime.datetime.utcnow() > token.expires_at:
token.access_token_delete(token)
return False
return True
def access_token_get_all(marker=None, limit=None, sort_field=None,
sort_dir=None, **kwargs):
# Sanity checks, in case someone accidentally explicitly passes in 'None'