OAuth Storage Abstraction Removed
The DB storage mechanism was doubly abstracted, once via the token_storage mechanism, and once via the DB API. Since the token_storage mechanism was never properly instantiated during tests, and was therefore preventing us from running tests properly, I've removed it entirely and pulled all methods from db_storage either directly into the various validators, or into the DB abstraction. Change-Id: Ia3026d5fa6bc45269beb6cfbc237fc62e98fb329
This commit is contained in:
parent
8195049588
commit
fbaa9220ce
@ -21,8 +21,6 @@ from oslo_log import log
|
||||
import pecan
|
||||
from wsgiref import simple_server
|
||||
|
||||
from storyboard.api.auth.token_storage import impls as storage_impls
|
||||
from storyboard.api.auth.token_storage import storage
|
||||
from storyboard.api import config as api_config
|
||||
from storyboard.api.middleware.cors_middleware import CORSMiddleware
|
||||
from storyboard.api.middleware import token_middleware
|
||||
@ -87,11 +85,6 @@ def setup_app(pecan_config=None):
|
||||
validation_hook.ValidationHook()
|
||||
]
|
||||
|
||||
# Setup token storage
|
||||
token_storage_type = CONF.token_storage_type
|
||||
storage_cls = storage_impls.STORAGE_IMPLS[token_storage_type]
|
||||
storage.set_storage(storage_cls())
|
||||
|
||||
# Setup search engine
|
||||
search_engine_name = CONF.search_engine
|
||||
search_engine_cls = search_engine_impls.ENGINE_IMPLS[search_engine_name]
|
||||
|
@ -4,7 +4,7 @@
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
@ -16,7 +16,7 @@
|
||||
from pecan import abort
|
||||
from pecan import request
|
||||
|
||||
from storyboard.api.auth.token_storage import storage
|
||||
from storyboard.db.api import access_tokens as token_api
|
||||
from storyboard.db.api import users as user_api
|
||||
from storyboard.openstack.common.gettextutils import _ # noqa
|
||||
|
||||
@ -29,7 +29,6 @@ def _get_token():
|
||||
|
||||
|
||||
def guest():
|
||||
token_storage = storage.get_storage()
|
||||
token = _get_token()
|
||||
|
||||
# Public resources do not require a token.
|
||||
@ -37,25 +36,23 @@ def guest():
|
||||
return True
|
||||
|
||||
# But if there is a token, it should be valid.
|
||||
return token_storage.check_access_token(token)
|
||||
return token_api.is_valid(token)
|
||||
|
||||
|
||||
def authenticated():
|
||||
token_storage = storage.get_storage()
|
||||
token = _get_token()
|
||||
|
||||
return token and token_storage.check_access_token(token)
|
||||
return token_api.is_valid(token)
|
||||
|
||||
|
||||
def superuser():
|
||||
token_storage = storage.get_storage()
|
||||
token = _get_token()
|
||||
|
||||
if not token:
|
||||
return False
|
||||
|
||||
token_info = token_storage.get_access_token_info(token)
|
||||
user = user_api.user_get(token_info.user_id)
|
||||
token = token_api.access_token_get_by_token(token)
|
||||
user = user_api.user_get(token.user_id)
|
||||
|
||||
if not user.is_superuser:
|
||||
abort(403, _("This action is limited to superusers only."))
|
||||
|
@ -4,7 +4,7 @@
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
@ -13,14 +13,15 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from datetime import datetime
|
||||
import datetime
|
||||
|
||||
from oauthlib.oauth2 import RequestValidator
|
||||
from oauthlib.oauth2 import WebApplicationServer
|
||||
from oslo.config import cfg
|
||||
from oslo_log import log
|
||||
|
||||
from storyboard.api.auth.token_storage import storage
|
||||
from storyboard.db.api import access_tokens as token_api
|
||||
from storyboard.db.api import auth as auth_api
|
||||
from storyboard.db.api import users as user_api
|
||||
|
||||
CONF = cfg.CONF
|
||||
@ -41,7 +42,6 @@ class SkeletonValidator(RequestValidator):
|
||||
|
||||
def __init__(self):
|
||||
super(SkeletonValidator, self).__init__()
|
||||
self.token_storage = storage.get_storage()
|
||||
|
||||
def validate_client_id(self, client_id, request, *args, **kwargs):
|
||||
"""Check that a valid client is connecting
|
||||
@ -111,7 +111,7 @@ class SkeletonValidator(RequestValidator):
|
||||
email = request._params["openid.sreg.email"]
|
||||
full_name = request._params["openid.sreg.fullname"]
|
||||
username = request._params["openid.sreg.nickname"]
|
||||
last_login = datetime.utcnow()
|
||||
last_login = datetime.datetime.utcnow()
|
||||
|
||||
user = user_api.user_get_by_openid(openid)
|
||||
user_dict = {"full_name": full_name,
|
||||
@ -125,7 +125,13 @@ class SkeletonValidator(RequestValidator):
|
||||
else:
|
||||
user = user_api.user_update(user.id, user_dict)
|
||||
|
||||
self.token_storage.save_authorization_code(code, user_id=user.id)
|
||||
# def save_authorization_code(self, authorization_code, user_id):
|
||||
values = {
|
||||
"code": code["code"],
|
||||
"state": code["state"],
|
||||
"user_id": user.id
|
||||
}
|
||||
auth_api.authorization_code_save(values)
|
||||
|
||||
# Token request
|
||||
|
||||
@ -142,7 +148,8 @@ class SkeletonValidator(RequestValidator):
|
||||
def validate_code(self, client_id, code, client, request, *args, **kwargs):
|
||||
"""Validate the code belongs to the client."""
|
||||
|
||||
return self.token_storage.check_authorization_code(code)
|
||||
db_code = auth_api.authorization_code_get(code)
|
||||
return not db_code is None
|
||||
|
||||
def confirm_redirect_uri(self, client_id, code, redirect_uri, client,
|
||||
*args, **kwargs):
|
||||
@ -169,13 +176,12 @@ class SkeletonValidator(RequestValidator):
|
||||
# Try authorization code
|
||||
code = request._params.get("code")
|
||||
if code:
|
||||
code_info = self.token_storage.get_authorization_code_info(code)
|
||||
code_info = auth_api.authorization_code_get(code)
|
||||
return code_info.user_id
|
||||
|
||||
# Try refresh token
|
||||
refresh_token = request._params.get("refresh_token")
|
||||
refresh_token_entry = self.token_storage.get_refresh_token_info(
|
||||
refresh_token)
|
||||
refresh_token_entry = auth_api.refresh_token_get(refresh_token)
|
||||
if refresh_token_entry:
|
||||
return refresh_token_entry.user_id
|
||||
|
||||
@ -190,10 +196,28 @@ class SkeletonValidator(RequestValidator):
|
||||
# be removed.
|
||||
self.invalidate_refresh_token(request)
|
||||
|
||||
self.token_storage.save_token(access_token=token["access_token"],
|
||||
expires_in=token["expires_in"],
|
||||
refresh_token=token["refresh_token"],
|
||||
user_id=user_id)
|
||||
access_token_values = {
|
||||
"access_token": token["access_token"],
|
||||
"expires_in": token["expires_in"],
|
||||
"expires_at": datetime.datetime.utcnow() + datetime.timedelta(
|
||||
seconds=token["expires_in"]),
|
||||
"user_id": user_id
|
||||
}
|
||||
|
||||
# Oauthlib does not provide a separate expiration time for a
|
||||
# refresh_token so taking it from config directly.
|
||||
refresh_expires_in = CONF.oauth.refresh_token_ttl
|
||||
|
||||
refresh_token_values = {
|
||||
"refresh_token": token["refresh_token"],
|
||||
"user_id": user_id,
|
||||
"expires_in": refresh_expires_in,
|
||||
"expires_at": datetime.datetime.utcnow() + datetime.timedelta(
|
||||
seconds=refresh_expires_in),
|
||||
}
|
||||
|
||||
token_api.access_token_create(access_token_values)
|
||||
auth_api.refresh_token_save(refresh_token_values)
|
||||
|
||||
def invalidate_authorization_code(self, client_id, code, request, *args,
|
||||
**kwargs):
|
||||
@ -202,7 +226,7 @@ class SkeletonValidator(RequestValidator):
|
||||
|
||||
"""
|
||||
|
||||
self.token_storage.invalidate_authorization_code(code)
|
||||
auth_api.authorization_code_delete(code)
|
||||
|
||||
# Protected resource request
|
||||
|
||||
@ -229,7 +253,16 @@ class SkeletonValidator(RequestValidator):
|
||||
**kwargs):
|
||||
"""Check that the refresh token exists in the db."""
|
||||
|
||||
return self.token_storage.check_refresh_token(refresh_token)
|
||||
refresh_token_entry = auth_api.refresh_token_get(refresh_token)
|
||||
|
||||
if not refresh_token_entry:
|
||||
return False
|
||||
|
||||
if datetime.datetime.utcnow() > refresh_token_entry.expires_at:
|
||||
auth_api.refresh_token_delete(refresh_token)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def invalidate_refresh_token(self, request):
|
||||
"""Remove a used token from the storage."""
|
||||
@ -241,7 +274,7 @@ class SkeletonValidator(RequestValidator):
|
||||
if not refresh_token:
|
||||
return
|
||||
|
||||
self.token_storage.invalidate_refresh_token(refresh_token)
|
||||
auth_api.refresh_token_delete(refresh_token)
|
||||
|
||||
|
||||
class OpenIdConnectServer(WebApplicationServer):
|
||||
|
@ -1,105 +0,0 @@
|
||||
# Copyright (c) 2014 Mirantis Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import datetime
|
||||
|
||||
from oslo.config import cfg
|
||||
|
||||
from storyboard.api.auth.token_storage import storage
|
||||
from storyboard.db.api import access_tokens as token_api
|
||||
from storyboard.db.api import auth as auth_api
|
||||
|
||||
|
||||
CONF = cfg.CONF
|
||||
|
||||
|
||||
class DBTokenStorage(storage.StorageBase):
|
||||
def save_authorization_code(self, authorization_code, user_id):
|
||||
values = {
|
||||
"code": authorization_code["code"],
|
||||
"state": authorization_code["state"],
|
||||
"user_id": user_id
|
||||
}
|
||||
auth_api.authorization_code_save(values)
|
||||
|
||||
def get_authorization_code_info(self, code):
|
||||
return auth_api.authorization_code_get(code)
|
||||
|
||||
def check_authorization_code(self, code):
|
||||
db_code = auth_api.authorization_code_get(code)
|
||||
return not db_code is None
|
||||
|
||||
def invalidate_authorization_code(self, code):
|
||||
auth_api.authorization_code_delete(code)
|
||||
|
||||
def save_token(self, access_token, expires_in, refresh_token, user_id):
|
||||
access_token_values = {
|
||||
"access_token": access_token,
|
||||
"expires_in": expires_in,
|
||||
"expires_at": datetime.datetime.utcnow() + datetime.timedelta(
|
||||
seconds=expires_in),
|
||||
"user_id": user_id
|
||||
}
|
||||
|
||||
# Oauthlib does not provide a separate expiration time for a
|
||||
# refresh_token so taking it from config directly.
|
||||
refresh_expires_in = CONF.oauth.refresh_token_ttl
|
||||
|
||||
refresh_token_values = {
|
||||
"refresh_token": refresh_token,
|
||||
"user_id": user_id,
|
||||
"expires_in": refresh_expires_in,
|
||||
"expires_at": datetime.datetime.utcnow() + datetime.timedelta(
|
||||
seconds=refresh_expires_in),
|
||||
}
|
||||
|
||||
token_api.access_token_create(access_token_values)
|
||||
auth_api.refresh_token_save(refresh_token_values)
|
||||
|
||||
def get_access_token_info(self, access_token):
|
||||
return token_api.access_token_get_by_token(access_token)
|
||||
|
||||
def check_access_token(self, access_token):
|
||||
token_info = token_api.access_token_get_by_token(access_token)
|
||||
|
||||
if not token_info:
|
||||
return False
|
||||
|
||||
if datetime.datetime.utcnow() > token_info.expires_at:
|
||||
token_api.access_token_delete(access_token)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def remove_token(self, access_token):
|
||||
token_api.access_token_delete(access_token)
|
||||
|
||||
def check_refresh_token(self, refresh_token):
|
||||
refresh_token_entry = auth_api.refresh_token_get(refresh_token)
|
||||
|
||||
if not refresh_token_entry:
|
||||
return False
|
||||
|
||||
if datetime.datetime.utcnow() > refresh_token_entry.expires_at:
|
||||
auth_api.refresh_token_delete(refresh_token)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get_refresh_token_info(self, refresh_token):
|
||||
return auth_api.refresh_token_get(refresh_token)
|
||||
|
||||
def invalidate_refresh_token(self, refresh_token):
|
||||
auth_api.refresh_token_delete(refresh_token)
|
@ -1,22 +0,0 @@
|
||||
# Copyright (c) 2014 Mirantis Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from storyboard.api.auth.token_storage import db_storage
|
||||
from storyboard.api.auth.token_storage import memory_storage
|
||||
|
||||
STORAGE_IMPLS = {
|
||||
"mem": memory_storage.MemoryTokenStorage,
|
||||
"db": db_storage.DBTokenStorage
|
||||
}
|
@ -1,130 +0,0 @@
|
||||
# Copyright (c) 2014 Mirantis Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import datetime
|
||||
|
||||
from storyboard.api.auth.token_storage import storage
|
||||
|
||||
|
||||
class Token(object):
|
||||
def __init__(self, access_token, refresh_token, expires_in, user_id,
|
||||
is_valid=True):
|
||||
self.access_token = access_token
|
||||
self.refresh_token = refresh_token
|
||||
self.expires_in = expires_in
|
||||
self.expires_at = datetime.datetime.utcnow() + \
|
||||
datetime.timedelta(seconds=expires_in)
|
||||
self.user_id = user_id
|
||||
self.is_valid = is_valid
|
||||
|
||||
|
||||
class AuthorizationCode(object):
|
||||
def __init__(self, code, user_id):
|
||||
self.code = code
|
||||
self.user_id = user_id
|
||||
|
||||
|
||||
class MemoryTokenStorage(storage.StorageBase):
|
||||
|
||||
def __init__(self):
|
||||
self.token_set = set([])
|
||||
self.auth_code_set = set([])
|
||||
|
||||
def save_token(self, access_token, expires_in, refresh_token, user_id):
|
||||
token_info = Token(access_token=access_token,
|
||||
expires_in=expires_in,
|
||||
refresh_token=refresh_token,
|
||||
user_id=user_id)
|
||||
|
||||
self.token_set.add(token_info)
|
||||
|
||||
def check_access_token(self, access_token):
|
||||
token_entry = None
|
||||
for token_info in self.token_set:
|
||||
if token_info.access_token == access_token:
|
||||
token_entry = token_info
|
||||
|
||||
if not token_entry:
|
||||
return False
|
||||
|
||||
now = datetime.datetime.utcnow()
|
||||
if now > token_entry.expires_at:
|
||||
token_entry.is_valid = False
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get_access_token_info(self, access_token):
|
||||
for token_info in self.token_set:
|
||||
if token_info.access_token == access_token:
|
||||
return token_info
|
||||
return None
|
||||
|
||||
def remove_token(self, token):
|
||||
pass
|
||||
|
||||
def check_refresh_token(self, refresh_token):
|
||||
for token_info in self.token_set:
|
||||
if token_info.refresh_token == refresh_token:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get_refresh_token_info(self, refresh_token):
|
||||
for token_info in self.token_set:
|
||||
if token_info.refresh_token == refresh_token:
|
||||
return token_info
|
||||
|
||||
return None
|
||||
|
||||
def invalidate_refresh_token(self, refresh_token):
|
||||
token_entry = None
|
||||
for entry in self.token_set:
|
||||
if entry.refresh_token == refresh_token:
|
||||
token_entry = entry
|
||||
break
|
||||
|
||||
self.token_set.remove(token_entry)
|
||||
|
||||
def save_authorization_code(self, authorization_code, user_id):
|
||||
self.auth_code_set.add(AuthorizationCode(authorization_code, user_id))
|
||||
|
||||
def check_authorization_code(self, code):
|
||||
code_entry = None
|
||||
for entry in self.auth_code_set:
|
||||
if entry.code["code"] == code:
|
||||
code_entry = entry
|
||||
break
|
||||
|
||||
if not code_entry:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get_authorization_code_info(self, code):
|
||||
for entry in self.auth_code_set:
|
||||
if entry.code["code"] == code:
|
||||
return entry
|
||||
|
||||
return None
|
||||
|
||||
def invalidate_authorization_code(self, code):
|
||||
code_entry = None
|
||||
for entry in self.auth_code_set:
|
||||
if entry.code["code"] == code:
|
||||
code_entry = entry
|
||||
break
|
||||
|
||||
self.auth_code_set.remove(code_entry)
|
@ -1,153 +0,0 @@
|
||||
# Copyright (c) 2014 Mirantis Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
|
||||
from oslo.config import cfg
|
||||
|
||||
CONF = cfg.CONF
|
||||
|
||||
STORAGE_OPTS = [
|
||||
cfg.StrOpt('token_storage_type',
|
||||
default='db',
|
||||
help='Authorization token storage type.'
|
||||
' Supported types are "mem" and "db".'
|
||||
' Memory storage is not persistent between api launches')
|
||||
]
|
||||
|
||||
CONF.register_opts(STORAGE_OPTS)
|
||||
|
||||
|
||||
class StorageBase(object):
|
||||
|
||||
@abc.abstractmethod
|
||||
def save_authorization_code(self, authorization_code, user_id):
|
||||
"""This method should save an Authorization Code to the storage and
|
||||
associate it with a user_id.
|
||||
|
||||
@param authorization_code: An object, containing state and a the code
|
||||
itself.
|
||||
@param user_id: The id of a User to associate the code with.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def check_authorization_code(self, code):
|
||||
"""Check that the given token exists in the storage.
|
||||
|
||||
@param code: The code to be checked.
|
||||
@return bool
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_authorization_code_info(self, code):
|
||||
"""Get the code info from the storage.
|
||||
|
||||
@param code: An authorization Code
|
||||
|
||||
@return object: The returned object should contain the state and the
|
||||
user_id, which the given code is associated with.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def invalidate_authorization_code(self, code):
|
||||
"""Remove a code from the storage.
|
||||
|
||||
@param code: An authorization Code
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def save_token(self, access_token, expires_in, refresh_token, user_id):
|
||||
"""Save a Bearer token to the storage with all associated fields
|
||||
|
||||
@param access_token: A token that will be used in authorized requests.
|
||||
@param expires_in: The time in seconds while the access_token is valid.
|
||||
@param refresh_token: A token that will be used in a refresh request
|
||||
after an access_token gets expired.
|
||||
@param user_id: The id of a User which owns a token.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def check_access_token(self, access_token):
|
||||
"""This method should say if a given token exists in the storage and
|
||||
that it has not expired yet.
|
||||
|
||||
@param access_token: The token to be checked.
|
||||
@return bool
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_access_token_info(self, access_token):
|
||||
"""Get the Bearer token from the storage.
|
||||
|
||||
@param access_token: The token to get the information about.
|
||||
@return object: The object should contain all fields associated with
|
||||
the token (refresh_token, expires_in, user_id).
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def remove_token(self, token):
|
||||
"""Invalidate a given token and remove it from the storage.
|
||||
|
||||
@param token: The token to be removed.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def check_refresh_token(self, refresh_token):
|
||||
"""This method should say if a given token exists in the storage and
|
||||
that it has not expired yet.
|
||||
|
||||
@param refresh_token: The token to be checked.
|
||||
@return bool
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_refresh_token_info(self, refresh_token):
|
||||
"""Get the Bearer token from the storage.
|
||||
|
||||
@param refresh_token: The token to get the information about.
|
||||
@return object: The object should contain all fields associated with
|
||||
the token (refresh_token, expires_in, user_id).
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def invalidate_refresh_token(self, refresh_token):
|
||||
"""Remove a token from the storage.
|
||||
|
||||
@param refresh_token: A refresh token
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
STORAGE = None
|
||||
|
||||
|
||||
def get_storage():
|
||||
global STORAGE
|
||||
return STORAGE
|
||||
|
||||
|
||||
def set_storage(impl):
|
||||
global STORAGE
|
||||
STORAGE = impl
|
@ -4,7 +4,7 @@
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
@ -15,7 +15,7 @@
|
||||
|
||||
from pecan import hooks
|
||||
|
||||
from storyboard.api.auth.token_storage import storage
|
||||
from storyboard.db.api import access_tokens as token_api
|
||||
|
||||
|
||||
class UserIdHook(hooks.PecanHook):
|
||||
@ -24,11 +24,11 @@ class UserIdHook(hooks.PecanHook):
|
||||
request = state.request
|
||||
|
||||
if request.authorization and len(request.authorization) == 2:
|
||||
token = request.authorization[1]
|
||||
token_info = storage.get_storage().get_access_token_info(token)
|
||||
access_token = request.authorization[1]
|
||||
token = token_api.access_token_get_by_token(access_token)
|
||||
|
||||
if token_info:
|
||||
request.current_user_id = token_info.user_id
|
||||
if token:
|
||||
request.current_user_id = token.user_id
|
||||
return
|
||||
|
||||
request.current_user_id = None
|
||||
|
@ -24,7 +24,7 @@ import six
|
||||
|
||||
from storyboard.api.auth.oauth_validator import SERVER
|
||||
from storyboard.api.auth.openid_client import client as openid_client
|
||||
from storyboard.api.auth.token_storage import storage
|
||||
from storyboard.db.api import auth as auth_api
|
||||
|
||||
LOG = log.getLogger(__name__)
|
||||
|
||||
@ -73,8 +73,7 @@ class AuthController(rest.RestController):
|
||||
|
||||
def _access_token_by_code(self):
|
||||
auth_code = request.params.get("code")
|
||||
code_info = storage.get_storage() \
|
||||
.get_authorization_code_info(auth_code)
|
||||
code_info = auth_api.authorization_code_get(auth_code)
|
||||
headers, body, code = SERVER.create_token_response(
|
||||
uri=request.url,
|
||||
http_method=request.method,
|
||||
@ -96,8 +95,7 @@ class AuthController(rest.RestController):
|
||||
|
||||
def _access_token_by_refresh_token(self):
|
||||
refresh_token = request.params.get("refresh_token")
|
||||
refresh_token_info = storage.get_storage().get_refresh_token_info(
|
||||
refresh_token)
|
||||
refresh_token_info = auth_api.refresh_token_get(refresh_token)
|
||||
|
||||
headers, body, code = SERVER.create_token_response(
|
||||
uri=request.url,
|
||||
|
@ -29,10 +29,26 @@ def access_token_get(access_token_id):
|
||||
|
||||
|
||||
def access_token_get_by_token(access_token):
|
||||
return api_base.model_query(models.AccessToken)\
|
||||
return api_base.model_query(models.AccessToken) \
|
||||
.filter_by(access_token=access_token).first()
|
||||
|
||||
|
||||
def is_valid(access_token):
|
||||
if not access_token:
|
||||
return False
|
||||
|
||||
token = access_token_get_by_token(access_token)
|
||||
|
||||
if not token:
|
||||
return False
|
||||
|
||||
if datetime.datetime.utcnow() > token.expires_at:
|
||||
token.access_token_delete(token)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def access_token_get_all(marker=None, limit=None, sort_field=None,
|
||||
sort_dir=None, **kwargs):
|
||||
# Sanity checks, in case someone accidentally explicitly passes in 'None'
|
||||
|
Loading…
x
Reference in New Issue
Block a user