From e5974b34a120de0b95075767e24cd8f080d33768 Mon Sep 17 00:00:00 2001 From: Ray Luo Date: Fri, 13 Sep 2024 12:34:56 -0700 Subject: [PATCH] Change AccessToken key_maker algorithm --- msal/token_cache.py | 14 ++++++---- tests/test_token_cache.py | 55 ++++++++++++++++----------------------- 2 files changed, 31 insertions(+), 38 deletions(-) diff --git a/msal/token_cache.py b/msal/token_cache.py index 136e38b8..f756ed8f 100644 --- a/msal/token_cache.py +++ b/msal/token_cache.py @@ -43,6 +43,8 @@ def __init__(self): self._lock = threading.RLock() self._cache = {} self.key_makers = { + # Note: We have changed token key format before when ordering scopes; + # changing token key won't result in cache miss. self.CredentialType.REFRESH_TOKEN: lambda home_account_id=None, environment=None, client_id=None, target=None, **ignored_payload_from_a_real_token: @@ -56,14 +58,18 @@ def __init__(self): ]).lower(), self.CredentialType.ACCESS_TOKEN: lambda home_account_id=None, environment=None, client_id=None, - realm=None, target=None, **ignored_payload_from_a_real_token: - "-".join([ + realm=None, target=None, + # Note: New field(s) can be added here + key_id=None, + **ignored_payload_from_a_real_token: + "-".join([ # Note: Could use a hash here to shorten key length home_account_id or "", environment or "", self.CredentialType.ACCESS_TOKEN, client_id or "", realm or "", target or "", + key_id or "", # So ATs of different key_id can coexist ]).lower(), self.CredentialType.ID_TOKEN: lambda home_account_id=None, environment=None, client_id=None, @@ -150,9 +156,7 @@ def search(self, credential_type, target=None, query=None): # O(n) generator target_set = set(target) with self._lock: - # Since the target inside token cache key is (per schema) unsorted, - # there is no point to attempt an O(1) key-value search here. - # So we always do an O(n) in-memory search. + # O(n) search. The key is NOT used in search. for entry in self._cache.get(credential_type, {}).values(): if (entry != preferred_result # Avoid yielding the same entry twice and self._is_matching(entry, query, target_set=target_set) diff --git a/tests/test_token_cache.py b/tests/test_token_cache.py index 41547a0a..909cbcbc 100644 --- a/tests/test_token_cache.py +++ b/tests/test_token_cache.py @@ -3,7 +3,7 @@ import json import time -from msal.token_cache import * +from msal.token_cache import TokenCache, SerializableTokenCache from tests import unittest @@ -51,6 +51,8 @@ class TokenCacheTestCase(unittest.TestCase): def setUp(self): self.cache = TokenCache() + self.at_key_maker = self.cache.key_makers[ + TokenCache.CredentialType.ACCESS_TOKEN] def testAddByAad(self): client_id = "my_client_id" @@ -78,11 +80,8 @@ def testAddByAad(self): 'target': 's1 s2 s3', # Sorted 'token_type': 'some type', } - self.assertEqual( - access_token_entry, - self.cache._cache["AccessToken"].get( - 'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s1 s2 s3') - ) + self.assertEqual(access_token_entry, self.cache._cache["AccessToken"].get( + self.at_key_maker(**access_token_entry))) self.assertIn( access_token_entry, self.cache.find(self.cache.CredentialType.ACCESS_TOKEN), @@ -144,8 +143,7 @@ def testAddByAdfs(self): expires_in=3600, access_token="an access token", id_token=id_token, refresh_token="a refresh token"), }, now=1000) - self.assertEqual( - { + access_token_entry = { 'cached_at': "1000", 'client_id': 'my_client_id', 'credential_type': 'AccessToken', @@ -157,10 +155,9 @@ def testAddByAdfs(self): 'secret': 'an access token', 'target': 's1 s2 s3', # Sorted 'token_type': 'some type', - }, - self.cache._cache["AccessToken"].get( - 'subject-fs.msidlab8.com-accesstoken-my_client_id-adfs-s1 s2 s3') - ) + } + self.assertEqual(access_token_entry, self.cache._cache["AccessToken"].get( + self.at_key_maker(**access_token_entry))) self.assertEqual( { 'client_id': 'my_client_id', @@ -238,37 +235,29 @@ def _test_data_should_be_saved_and_searchable_in_access_token(self, data): def test_extra_data_should_also_be_recorded_and_searchable_in_access_token(self): self._test_data_should_be_saved_and_searchable_in_access_token({"key_id": "1"}) - def test_key_id_is_also_recorded(self): - my_key_id = "some_key_id_123" - self.cache.add({ - "data": {"key_id": my_key_id}, - "client_id": "my_client_id", - "scope": ["s2", "s1", "s3"], # Not in particular order - "token_endpoint": "https://login.example.com/contoso/v2/token", - "response": build_response( - uid="uid", utid="utid", # client_info - expires_in=3600, access_token="an access token", - refresh_token="a refresh token"), - }, now=1000) - cached_key_id = self.cache._cache["AccessToken"].get( - 'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s1 s2 s3', - {}).get("key_id") - self.assertEqual(my_key_id, cached_key_id, "AT should be bound to the key") + def test_access_tokens_with_different_key_id_should_coexist(self): + self._test_data_should_be_saved_and_searchable_in_access_token({"key_id": "1"}) + self._test_data_should_be_saved_and_searchable_in_access_token({"key_id": "2"}) + self.assertEqual(2, len(self.cache._cache["AccessToken"]), "Should have 2 ATs") def test_refresh_in_should_be_recorded_as_refresh_on(self): # Sounds weird. Yep. + scopes = ["s2", "s1", "s3"] # Not in particular order self.cache.add({ "client_id": "my_client_id", - "scope": ["s2", "s1", "s3"], # Not in particular order + "scope": scopes, "token_endpoint": "https://login.example.com/contoso/v2/token", "response": build_response( uid="uid", utid="utid", # client_info expires_in=3600, refresh_in=1800, access_token="an access token", ), #refresh_token="a refresh token"), }, now=1000) - refresh_on = self.cache._cache["AccessToken"].get( - 'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s1 s2 s3', - {}).get("refresh_on") - self.assertEqual("2800", refresh_on, "Should save refresh_on") + at = self.assertFoundAccessToken(scopes=scopes, query=dict( + client_id="my_client_id", + environment="login.example.com", + realm="contoso", + home_account_id="uid.utid", + )) + self.assertEqual("2800", at.get("refresh_on"), "Should save refresh_on") def test_old_rt_data_with_wrong_key_should_still_be_salvaged_into_new_rt(self): sample = {