diff --git a/adal/cache_driver.py b/adal/cache_driver.py index 9683dca4..788cbe99 100644 --- a/adal/cache_driver.py +++ b/adal/cache_driver.py @@ -132,6 +132,12 @@ def _create_entry_from_refresh(self, entry, refresh_response): new_entry = copy.deepcopy(entry) new_entry.update(refresh_response) + # It is possible the response payload has no 'resource' field, like in ADFS, so we manually + # fill it here. Note, 'resource' is part of the token cache key, so we have to set it to avoid + # corrupting the cache. + if 'resource' not in refresh_response: + new_entry['resource'] = self._resource + if entry[TokenResponseFields.IS_MRRT] and self._authority != entry[TokenResponseFields._AUTHORITY]: new_entry[TokenResponseFields._AUTHORITY] = self._authority diff --git a/tests/test_refresh_token.py b/tests/test_refresh_token.py index 3e19e069..4c958f62 100644 --- a/tests/test_refresh_token.py +++ b/tests/test_refresh_token.py @@ -41,31 +41,61 @@ import mock import adal -from adal.authentication_context import AuthenticationContext +from adal.authentication_context import AuthenticationContext, TokenCache from tests import util from tests.util import parameters as cp class TestRefreshToken(unittest.TestCase): - def setUp(self): - self.response_options = { 'refreshedRefresh' : True } - self.response = util.create_response(self.response_options) - self.wire_response = self.response['wireResponse'] @httpretty.activate def test_happy_path_with_resource_client_secret(self): - tokenRequest = util.setup_expected_refresh_token_request_response(200, self.wire_response, self.response['authority'], self.response['resource'], cp['clientSecret']) + response_options = { 'refreshedRefresh' : True } + response = util.create_response(response_options) + wire_response = response['wireResponse'] + tokenRequest = util.setup_expected_refresh_token_request_response(200, wire_response, response['authority'], response['resource'], cp['clientSecret']) context = adal.AuthenticationContext(cp['authorityTenant']) def side_effect (tokenfunc): - return self.response['decodedResponse'] + return response['decodedResponse'] context._acquire_token = mock.MagicMock(side_effect=side_effect) token_response = context.acquire_token_with_refresh_token(cp['refreshToken'], cp['clientId'], cp['clientSecret'], cp['resource']) self.assertTrue( - util.is_match_token_response(self.response['decodedResponse'], token_response), + util.is_match_token_response(response['decodedResponse'], token_response), 'The response did not match what was expected: ' + str(token_response) ) + @httpretty.activate + def test_happy_path_with_resource_adfs(self): + # arrange + # set up token refresh result + wire_response = util.create_response({ + 'refreshedRefresh' : True, + 'mrrt': False + })['wireResponse'] + new_resource = 'https://graph.local.azurestack.external/' + tokenRequest = util.setup_expected_refresh_token_request_response(200, wire_response, cp['authority'], new_resource) + + # set up an existing token to be used for refreshing + existing_token = util.create_response({ + 'refreshedRefresh': True, + 'mrrt': True + })['decodedResponse'] + existing_token['_clientId'] = existing_token.get('_clientId') or cp['clientId'] + existing_token['isMRRT'] = existing_token.get('isMRRT') or True + existing_token['_authority'] = existing_token.get('_authority') or cp['authorizeUrl'] + token_cache = TokenCache(json.dumps([existing_token])) + + # act + user_id = existing_token['userId'] + context = adal.AuthenticationContext(cp['authorityTenant'], cache=token_cache) + token_response = context.acquire_token(new_resource, user_id, cp['clientId']) + + # assert + tokens = [value for key, value in token_cache.read_items()] + self.assertEqual(2, len(tokens)) + self.assertEqual({cp['resource'], new_resource}, set([x['resource'] for x in tokens])) + if __name__ == '__main__': unittest.main()