OAuth delete refresh tokens when access tokens deleted.

Created association between refresh and access tokens, and linked them
so that deleting the former will remove the latter. Additional tests
added for the token cleaner, as well as the OAuth endpoint.

Change-Id: I6b78bda9a21a6d6aa7bf93a9e6dad2e04cff0f44
This commit is contained in:
Michael Krotscheck 2015-02-06 14:02:02 -08:00
parent e8ba4df9b3
commit 29dca914aa
6 changed files with 178 additions and 16 deletions

View File

@ -58,7 +58,7 @@ class SkeletonValidator(RequestValidator):
"""
#todo(nkonovalov): check an uri based on CONF.domain
# todo(nkonovalov): check an uri based on CONF.domain
return True
def get_default_redirect_uri(self, client_id, request, *args, **kwargs):
@ -168,7 +168,7 @@ class SkeletonValidator(RequestValidator):
"""
#todo(nkonovalov): check an uri based on CONF.domain
# todo(nkonovalov): check an uri based on CONF.domain
return True
def validate_grant_type(self, client_id, grant_type, client, request,
@ -202,8 +202,8 @@ class SkeletonValidator(RequestValidator):
user_id = self._resolve_user_id(request)
# If a refresh_token was used to obtain a new access_token, it should
# be removed.
# If a refresh_token was used to obtain a new access_token, it (and
# its access token) should be removed.
self.invalidate_refresh_token(request)
access_token_values = {
@ -213,20 +213,20 @@ class SkeletonValidator(RequestValidator):
seconds=token["expires_in"]),
"user_id": user_id
}
access_token = token_api.access_token_create(access_token_values)
# Oauthlib does not provide a separate expiration time for a
# refresh_token so taking it from config directly.
refresh_expires_in = CONF.oauth.refresh_token_ttl
refresh_token_values = {
"access_token_id": access_token.id,
"refresh_token": token["refresh_token"],
"user_id": user_id,
"expires_in": refresh_expires_in,
"expires_at": datetime.datetime.utcnow() + datetime.timedelta(
seconds=refresh_expires_in),
}
token_api.access_token_create(access_token_values)
auth_api.refresh_token_save(refresh_token_values)
def invalidate_authorization_code(self, client_id, code, request, *args,
@ -284,7 +284,8 @@ class SkeletonValidator(RequestValidator):
if not refresh_token:
return
auth_api.refresh_token_delete(refresh_token)
r_token = auth_api.refresh_token_get(refresh_token)
token_api.access_token_delete(r_token.access_token_id) # Cascades
class OpenIdConnectServer(WebApplicationServer):

View File

@ -0,0 +1,40 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
"""This migration creates a reference between oauth access tokens and refresh
tokens.
Revision ID: 034
Revises: 033
Create Date: 2015-02-06 12:00:00
"""
# revision identifiers, used by Alembic.
revision = '034'
down_revision = '033'
from alembic import op
from sqlalchemy import Column
from sqlalchemy import Integer
def upgrade(active_plugins=None, options=None):
op.add_column('refreshtokens', Column('access_token_id',
Integer(),
nullable=False))
def downgrade(active_plugins=None, options=None):
op.drop_column('refreshtokens', 'access_token_id')

View File

@ -51,7 +51,7 @@ def table_args():
'mysql_charset': "utf8"}
return None
## CUSTOM TYPES
# # CUSTOM TYPES
# A mysql medium text type.
MYSQL_MEDIUM_TEXT = UnicodeText().with_variant(MEDIUMTEXT(), 'mysql')
@ -317,10 +317,17 @@ class AccessToken(ModelBuilder, Base):
nullable=False)
expires_in = Column(Integer, nullable=False)
expires_at = Column(DateTime, nullable=False)
refresh_tokens = relationship("RefreshToken",
cascade="save-update, merge, delete",
passive_updates=False,
passive_deletes=False)
class RefreshToken(ModelBuilder, Base):
user_id = Column(Integer, ForeignKey('users.id'), nullable=False)
access_token_id = Column(Integer,
ForeignKey('accesstokens.id'),
nullable=False)
refresh_token = Column(Unicode(CommonLength.top_middle_length),
nullable=False)
expires_in = Column(Integer, nullable=False)
@ -335,7 +342,7 @@ def _story_build_summary_query():
expr.case(
[(func.sum(Task.status.in_(
['todo', 'inprogress', 'review'])) > 0,
'active'),
'active'),
((func.sum(Task.status == 'merged')) > 0, 'merged')],
else_='invalid'
).label('status')
@ -347,8 +354,8 @@ def _story_build_summary_query():
select_items.append(expr.null().label('task_statuses'))
result = select(select_items, None,
expr.Join(Story, Task, onclause=Story.id == Task.story_id,
isouter=True)) \
expr.Join(Story, Task, onclause=Story.id == Task.story_id,
isouter=True)) \
.group_by(Story.id) \
.alias('story_summary')

View File

@ -52,10 +52,12 @@ class TokenCleaner(CronPluginBase):
lastweek = datetime.utcnow() - timedelta(weeks=1)
# Build the query.
query = api_base.model_query(AccessToken)
query = api_base.model_query(AccessToken.id)
# Apply the filter.
query = query.filter(AccessToken.expires_at < lastweek)
# Batch delete.
query.delete()
# Manually deleting each record, because batch deletes are an
# exception to ORM Cascade markup.
for token in query.all():
api_base.entity_hard_delete(AccessToken, token.id)

View File

@ -642,6 +642,100 @@ class TestOAuthAccessToken(BaseOAuthTest):
self.assertEqual(e_msg.INVALID_TOKEN_GRANT_TYPE,
response.json['error_description'])
def test_valid_refresh_token(self):
"""This test ensures that a valid refresh token can be converted into
a valid access token, and cleans up after itself.
"""
# Generate a valid access code
authorization_code = auth_api.authorization_code_save({
'user_id': 2,
'state': 'test_state',
'code': 'test_valid_code'
})
# Generate an auth and a refresh token.
resp_1 = self.app.post('/v1/openid/token',
params={
'code': authorization_code.code,
'grant_type': 'authorization_code'
},
content_type=
'application/x-www-form-urlencoded',
expect_errors=True)
# Assert that this is a successful response
self.assertEqual(200, resp_1.status_code)
# Assert that the token came back in the response
t1 = resp_1.json
# Assert that both are in the database.
access_token = \
token_api.access_token_get_by_token(t1['access_token'])
self.assertIsNotNone(access_token)
refresh_token = \
auth_api.refresh_token_get(t1['refresh_token'])
self.assertIsNotNone(refresh_token)
# Issue a refresh token request.
resp_2 = self.app.post('/v1/openid/token',
params={
'refresh_token': t1['refresh_token'],
'grant_type': 'refresh_token'
},
content_type=
'application/x-www-form-urlencoded',
expect_errors=True)
# Assert that the response is good.
self.assertEqual(200, resp_2.status_code)
# Assert that the token came back in the response
t2 = resp_2.json
self.assertIsNotNone(t2['access_token'])
self.assertIsNotNone(t2['expires_in'])
self.assertIsNotNone(t2['id_token'])
self.assertIsNotNone(t2['refresh_token'])
self.assertIsNotNone(t2['token_type'])
self.assertEqual('Bearer', t2['token_type'])
# Assert that the access token is in the database
new_access_token = \
token_api.access_token_get_by_token(t2['access_token'])
self.assertIsNotNone(new_access_token)
# Assert that system configured values is owned by the correct user.
self.assertEquals(2, new_access_token.user_id)
self.assertEquals(t2['id_token'], new_access_token.user_id)
self.assertEqual(t2['expires_in'], CONF.oauth.access_token_ttl)
self.assertEqual(t2['expires_in'], new_access_token.expires_in)
self.assertEqual(t2['access_token'],
new_access_token.access_token)
# Assert that the refresh token is in the database
new_refresh_token = \
auth_api.refresh_token_get(t2['refresh_token'])
self.assertIsNotNone(new_refresh_token)
# Assert that system configured values is owned by the correct user.
self.assertEquals(2, new_refresh_token.user_id)
self.assertEqual(CONF.oauth.refresh_token_ttl,
new_refresh_token.expires_in)
self.assertEqual(t2['refresh_token'],
new_refresh_token.refresh_token)
# Assert that the old access tokens are no longer in the database and
# have been cleaned up.
no_access_token = \
token_api.access_token_get_by_token(t1['access_token'])
no_refresh_token = \
auth_api.refresh_token_get(t1['refresh_token'])
self.assertIsNone(no_refresh_token)
self.assertIsNone(no_access_token)
def test_invalid_refresh_token(self):
"""This test ensures that an invalid refresh token can be converted
into a valid access token.

View File

@ -18,6 +18,7 @@ from datetime import timedelta
from oslo.config import cfg
import storyboard.db.api.base as db_api
from storyboard.db.models import AccessToken
from storyboard.db.models import RefreshToken
from storyboard.plugin.token_cleaner.cleaner import TokenCleaner
import storyboard.tests.base as base
from storyboard.tests.mock_data import load_data
@ -61,22 +62,38 @@ class TestTokenCleaner(base.FunctionalTest, base.WorkingDirTestCase):
# expiration dates. I subtract 5 seconds here because the time it
# takes to execute the script may, or may not, result in an
# 8-day-old-token to remain valid.
new_access_tokens = []
new_refresh_tokens = []
for i in range(0, 100):
created_at = datetime.utcnow() - timedelta(days=i)
expires_in = (60 * 60 * 24) - 5 # Minus five seconds, see above.
expires_at = created_at + timedelta(seconds=expires_in)
load_data([
new_access_tokens.append(
AccessToken(
user_id=1,
created_at=created_at.strftime('%Y-%m-%d %H:%M:%S'),
expires_in=expires_in,
expires_at=expires_at.strftime('%Y-%m-%d %H:%M:%S'),
access_token='test_token_%s' % (i,))
])
)
new_access_tokens = load_data(new_access_tokens)
for token in new_access_tokens:
new_refresh_tokens.append(
RefreshToken(
user_id=1,
access_token_id=token.id,
created_at=token.created_at,
expires_at=token.expires_at,
expires_in=300,
refresh_token='test_refresh_%s' % (token.id,))
)
new_refresh_tokens = load_data(new_refresh_tokens)
# Make sure we have 100 tokens.
self.assertEqual(100, db_api.model_query(AccessToken).count())
self.assertEqual(100, db_api.model_query(RefreshToken).count())
# Run the plugin.
plugin = TokenCleaner(CONF)
@ -84,3 +101,4 @@ class TestTokenCleaner(base.FunctionalTest, base.WorkingDirTestCase):
# Make sure we have 8 tokens left (since one plugin starts today).
self.assertEqual(8, db_api.model_query(AccessToken).count())
self.assertEqual(8, db_api.model_query(RefreshToken).count())