OAuth delete refresh tokens when access tokens deleted.
Created association between refresh and access tokens, and linked them so that deleting the former will remove the latter. Additional tests added for the token cleaner, as well as the OAuth endpoint. Change-Id: I6b78bda9a21a6d6aa7bf93a9e6dad2e04cff0f44
This commit is contained in:
parent
e8ba4df9b3
commit
29dca914aa
@ -58,7 +58,7 @@ class SkeletonValidator(RequestValidator):
|
||||
|
||||
"""
|
||||
|
||||
#todo(nkonovalov): check an uri based on CONF.domain
|
||||
# todo(nkonovalov): check an uri based on CONF.domain
|
||||
return True
|
||||
|
||||
def get_default_redirect_uri(self, client_id, request, *args, **kwargs):
|
||||
@ -168,7 +168,7 @@ class SkeletonValidator(RequestValidator):
|
||||
|
||||
"""
|
||||
|
||||
#todo(nkonovalov): check an uri based on CONF.domain
|
||||
# todo(nkonovalov): check an uri based on CONF.domain
|
||||
return True
|
||||
|
||||
def validate_grant_type(self, client_id, grant_type, client, request,
|
||||
@ -202,8 +202,8 @@ class SkeletonValidator(RequestValidator):
|
||||
|
||||
user_id = self._resolve_user_id(request)
|
||||
|
||||
# If a refresh_token was used to obtain a new access_token, it should
|
||||
# be removed.
|
||||
# If a refresh_token was used to obtain a new access_token, it (and
|
||||
# its access token) should be removed.
|
||||
self.invalidate_refresh_token(request)
|
||||
|
||||
access_token_values = {
|
||||
@ -213,20 +213,20 @@ class SkeletonValidator(RequestValidator):
|
||||
seconds=token["expires_in"]),
|
||||
"user_id": user_id
|
||||
}
|
||||
access_token = token_api.access_token_create(access_token_values)
|
||||
|
||||
# Oauthlib does not provide a separate expiration time for a
|
||||
# refresh_token so taking it from config directly.
|
||||
refresh_expires_in = CONF.oauth.refresh_token_ttl
|
||||
|
||||
refresh_token_values = {
|
||||
"access_token_id": access_token.id,
|
||||
"refresh_token": token["refresh_token"],
|
||||
"user_id": user_id,
|
||||
"expires_in": refresh_expires_in,
|
||||
"expires_at": datetime.datetime.utcnow() + datetime.timedelta(
|
||||
seconds=refresh_expires_in),
|
||||
}
|
||||
|
||||
token_api.access_token_create(access_token_values)
|
||||
auth_api.refresh_token_save(refresh_token_values)
|
||||
|
||||
def invalidate_authorization_code(self, client_id, code, request, *args,
|
||||
@ -284,7 +284,8 @@ class SkeletonValidator(RequestValidator):
|
||||
if not refresh_token:
|
||||
return
|
||||
|
||||
auth_api.refresh_token_delete(refresh_token)
|
||||
r_token = auth_api.refresh_token_get(refresh_token)
|
||||
token_api.access_token_delete(r_token.access_token_id) # Cascades
|
||||
|
||||
|
||||
class OpenIdConnectServer(WebApplicationServer):
|
||||
|
@ -0,0 +1,40 @@
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
#
|
||||
|
||||
"""This migration creates a reference between oauth access tokens and refresh
|
||||
tokens.
|
||||
|
||||
Revision ID: 034
|
||||
Revises: 033
|
||||
Create Date: 2015-02-06 12:00:00
|
||||
|
||||
"""
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
|
||||
revision = '034'
|
||||
down_revision = '033'
|
||||
|
||||
from alembic import op
|
||||
from sqlalchemy import Column
|
||||
from sqlalchemy import Integer
|
||||
|
||||
|
||||
def upgrade(active_plugins=None, options=None):
|
||||
op.add_column('refreshtokens', Column('access_token_id',
|
||||
Integer(),
|
||||
nullable=False))
|
||||
|
||||
|
||||
def downgrade(active_plugins=None, options=None):
|
||||
op.drop_column('refreshtokens', 'access_token_id')
|
@ -51,7 +51,7 @@ def table_args():
|
||||
'mysql_charset': "utf8"}
|
||||
return None
|
||||
|
||||
## CUSTOM TYPES
|
||||
# # CUSTOM TYPES
|
||||
|
||||
# A mysql medium text type.
|
||||
MYSQL_MEDIUM_TEXT = UnicodeText().with_variant(MEDIUMTEXT(), 'mysql')
|
||||
@ -317,10 +317,17 @@ class AccessToken(ModelBuilder, Base):
|
||||
nullable=False)
|
||||
expires_in = Column(Integer, nullable=False)
|
||||
expires_at = Column(DateTime, nullable=False)
|
||||
refresh_tokens = relationship("RefreshToken",
|
||||
cascade="save-update, merge, delete",
|
||||
passive_updates=False,
|
||||
passive_deletes=False)
|
||||
|
||||
|
||||
class RefreshToken(ModelBuilder, Base):
|
||||
user_id = Column(Integer, ForeignKey('users.id'), nullable=False)
|
||||
access_token_id = Column(Integer,
|
||||
ForeignKey('accesstokens.id'),
|
||||
nullable=False)
|
||||
refresh_token = Column(Unicode(CommonLength.top_middle_length),
|
||||
nullable=False)
|
||||
expires_in = Column(Integer, nullable=False)
|
||||
@ -335,7 +342,7 @@ def _story_build_summary_query():
|
||||
expr.case(
|
||||
[(func.sum(Task.status.in_(
|
||||
['todo', 'inprogress', 'review'])) > 0,
|
||||
'active'),
|
||||
'active'),
|
||||
((func.sum(Task.status == 'merged')) > 0, 'merged')],
|
||||
else_='invalid'
|
||||
).label('status')
|
||||
@ -347,8 +354,8 @@ def _story_build_summary_query():
|
||||
select_items.append(expr.null().label('task_statuses'))
|
||||
|
||||
result = select(select_items, None,
|
||||
expr.Join(Story, Task, onclause=Story.id == Task.story_id,
|
||||
isouter=True)) \
|
||||
expr.Join(Story, Task, onclause=Story.id == Task.story_id,
|
||||
isouter=True)) \
|
||||
.group_by(Story.id) \
|
||||
.alias('story_summary')
|
||||
|
||||
|
@ -52,10 +52,12 @@ class TokenCleaner(CronPluginBase):
|
||||
lastweek = datetime.utcnow() - timedelta(weeks=1)
|
||||
|
||||
# Build the query.
|
||||
query = api_base.model_query(AccessToken)
|
||||
query = api_base.model_query(AccessToken.id)
|
||||
|
||||
# Apply the filter.
|
||||
query = query.filter(AccessToken.expires_at < lastweek)
|
||||
|
||||
# Batch delete.
|
||||
query.delete()
|
||||
# Manually deleting each record, because batch deletes are an
|
||||
# exception to ORM Cascade markup.
|
||||
for token in query.all():
|
||||
api_base.entity_hard_delete(AccessToken, token.id)
|
||||
|
@ -642,6 +642,100 @@ class TestOAuthAccessToken(BaseOAuthTest):
|
||||
self.assertEqual(e_msg.INVALID_TOKEN_GRANT_TYPE,
|
||||
response.json['error_description'])
|
||||
|
||||
def test_valid_refresh_token(self):
|
||||
"""This test ensures that a valid refresh token can be converted into
|
||||
a valid access token, and cleans up after itself.
|
||||
"""
|
||||
|
||||
# Generate a valid access code
|
||||
authorization_code = auth_api.authorization_code_save({
|
||||
'user_id': 2,
|
||||
'state': 'test_state',
|
||||
'code': 'test_valid_code'
|
||||
})
|
||||
|
||||
# Generate an auth and a refresh token.
|
||||
resp_1 = self.app.post('/v1/openid/token',
|
||||
params={
|
||||
'code': authorization_code.code,
|
||||
'grant_type': 'authorization_code'
|
||||
},
|
||||
content_type=
|
||||
'application/x-www-form-urlencoded',
|
||||
expect_errors=True)
|
||||
|
||||
# Assert that this is a successful response
|
||||
self.assertEqual(200, resp_1.status_code)
|
||||
|
||||
# Assert that the token came back in the response
|
||||
t1 = resp_1.json
|
||||
|
||||
# Assert that both are in the database.
|
||||
access_token = \
|
||||
token_api.access_token_get_by_token(t1['access_token'])
|
||||
self.assertIsNotNone(access_token)
|
||||
refresh_token = \
|
||||
auth_api.refresh_token_get(t1['refresh_token'])
|
||||
self.assertIsNotNone(refresh_token)
|
||||
|
||||
# Issue a refresh token request.
|
||||
|
||||
resp_2 = self.app.post('/v1/openid/token',
|
||||
params={
|
||||
'refresh_token': t1['refresh_token'],
|
||||
'grant_type': 'refresh_token'
|
||||
},
|
||||
content_type=
|
||||
'application/x-www-form-urlencoded',
|
||||
expect_errors=True)
|
||||
|
||||
# Assert that the response is good.
|
||||
self.assertEqual(200, resp_2.status_code)
|
||||
|
||||
# Assert that the token came back in the response
|
||||
t2 = resp_2.json
|
||||
self.assertIsNotNone(t2['access_token'])
|
||||
self.assertIsNotNone(t2['expires_in'])
|
||||
self.assertIsNotNone(t2['id_token'])
|
||||
self.assertIsNotNone(t2['refresh_token'])
|
||||
self.assertIsNotNone(t2['token_type'])
|
||||
self.assertEqual('Bearer', t2['token_type'])
|
||||
|
||||
# Assert that the access token is in the database
|
||||
new_access_token = \
|
||||
token_api.access_token_get_by_token(t2['access_token'])
|
||||
self.assertIsNotNone(new_access_token)
|
||||
|
||||
# Assert that system configured values is owned by the correct user.
|
||||
self.assertEquals(2, new_access_token.user_id)
|
||||
self.assertEquals(t2['id_token'], new_access_token.user_id)
|
||||
self.assertEqual(t2['expires_in'], CONF.oauth.access_token_ttl)
|
||||
self.assertEqual(t2['expires_in'], new_access_token.expires_in)
|
||||
self.assertEqual(t2['access_token'],
|
||||
new_access_token.access_token)
|
||||
|
||||
# Assert that the refresh token is in the database
|
||||
new_refresh_token = \
|
||||
auth_api.refresh_token_get(t2['refresh_token'])
|
||||
self.assertIsNotNone(new_refresh_token)
|
||||
|
||||
# Assert that system configured values is owned by the correct user.
|
||||
self.assertEquals(2, new_refresh_token.user_id)
|
||||
self.assertEqual(CONF.oauth.refresh_token_ttl,
|
||||
new_refresh_token.expires_in)
|
||||
self.assertEqual(t2['refresh_token'],
|
||||
new_refresh_token.refresh_token)
|
||||
|
||||
# Assert that the old access tokens are no longer in the database and
|
||||
# have been cleaned up.
|
||||
no_access_token = \
|
||||
token_api.access_token_get_by_token(t1['access_token'])
|
||||
no_refresh_token = \
|
||||
auth_api.refresh_token_get(t1['refresh_token'])
|
||||
|
||||
self.assertIsNone(no_refresh_token)
|
||||
self.assertIsNone(no_access_token)
|
||||
|
||||
def test_invalid_refresh_token(self):
|
||||
"""This test ensures that an invalid refresh token can be converted
|
||||
into a valid access token.
|
||||
|
@ -18,6 +18,7 @@ from datetime import timedelta
|
||||
from oslo.config import cfg
|
||||
import storyboard.db.api.base as db_api
|
||||
from storyboard.db.models import AccessToken
|
||||
from storyboard.db.models import RefreshToken
|
||||
from storyboard.plugin.token_cleaner.cleaner import TokenCleaner
|
||||
import storyboard.tests.base as base
|
||||
from storyboard.tests.mock_data import load_data
|
||||
@ -61,22 +62,38 @@ class TestTokenCleaner(base.FunctionalTest, base.WorkingDirTestCase):
|
||||
# expiration dates. I subtract 5 seconds here because the time it
|
||||
# takes to execute the script may, or may not, result in an
|
||||
# 8-day-old-token to remain valid.
|
||||
new_access_tokens = []
|
||||
new_refresh_tokens = []
|
||||
for i in range(0, 100):
|
||||
created_at = datetime.utcnow() - timedelta(days=i)
|
||||
expires_in = (60 * 60 * 24) - 5 # Minus five seconds, see above.
|
||||
expires_at = created_at + timedelta(seconds=expires_in)
|
||||
|
||||
load_data([
|
||||
new_access_tokens.append(
|
||||
AccessToken(
|
||||
user_id=1,
|
||||
created_at=created_at.strftime('%Y-%m-%d %H:%M:%S'),
|
||||
expires_in=expires_in,
|
||||
expires_at=expires_at.strftime('%Y-%m-%d %H:%M:%S'),
|
||||
access_token='test_token_%s' % (i,))
|
||||
])
|
||||
)
|
||||
new_access_tokens = load_data(new_access_tokens)
|
||||
|
||||
for token in new_access_tokens:
|
||||
new_refresh_tokens.append(
|
||||
RefreshToken(
|
||||
user_id=1,
|
||||
access_token_id=token.id,
|
||||
created_at=token.created_at,
|
||||
expires_at=token.expires_at,
|
||||
expires_in=300,
|
||||
refresh_token='test_refresh_%s' % (token.id,))
|
||||
)
|
||||
new_refresh_tokens = load_data(new_refresh_tokens)
|
||||
|
||||
# Make sure we have 100 tokens.
|
||||
self.assertEqual(100, db_api.model_query(AccessToken).count())
|
||||
self.assertEqual(100, db_api.model_query(RefreshToken).count())
|
||||
|
||||
# Run the plugin.
|
||||
plugin = TokenCleaner(CONF)
|
||||
@ -84,3 +101,4 @@ class TestTokenCleaner(base.FunctionalTest, base.WorkingDirTestCase):
|
||||
|
||||
# Make sure we have 8 tokens left (since one plugin starts today).
|
||||
self.assertEqual(8, db_api.model_query(AccessToken).count())
|
||||
self.assertEqual(8, db_api.model_query(RefreshToken).count())
|
||||
|
Loading…
x
Reference in New Issue
Block a user