Skip to content

Commit

Permalink
Change AccessToken key_maker algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
rayluo committed Sep 23, 2024
1 parent ede849e commit e5974b3
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 38 deletions.
14 changes: 9 additions & 5 deletions msal/token_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ def __init__(self):
self._lock = threading.RLock()
self._cache = {}
self.key_makers = {
# Note: We have changed token key format before when ordering scopes;
# changing token key won't result in cache miss.
self.CredentialType.REFRESH_TOKEN:
lambda home_account_id=None, environment=None, client_id=None,
target=None, **ignored_payload_from_a_real_token:
Expand All @@ -56,14 +58,18 @@ def __init__(self):
]).lower(),
self.CredentialType.ACCESS_TOKEN:
lambda home_account_id=None, environment=None, client_id=None,
realm=None, target=None, **ignored_payload_from_a_real_token:
"-".join([
realm=None, target=None,
# Note: New field(s) can be added here
key_id=None,
**ignored_payload_from_a_real_token:
"-".join([ # Note: Could use a hash here to shorten key length
home_account_id or "",
environment or "",
self.CredentialType.ACCESS_TOKEN,
client_id or "",
realm or "",
target or "",
key_id or "", # So ATs of different key_id can coexist
]).lower(),
self.CredentialType.ID_TOKEN:
lambda home_account_id=None, environment=None, client_id=None,
Expand Down Expand Up @@ -150,9 +156,7 @@ def search(self, credential_type, target=None, query=None): # O(n) generator

target_set = set(target)
with self._lock:
# Since the target inside token cache key is (per schema) unsorted,
# there is no point to attempt an O(1) key-value search here.
# So we always do an O(n) in-memory search.
# O(n) search. The key is NOT used in search.
for entry in self._cache.get(credential_type, {}).values():
if (entry != preferred_result # Avoid yielding the same entry twice
and self._is_matching(entry, query, target_set=target_set)
Expand Down
55 changes: 22 additions & 33 deletions tests/test_token_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
import time

from msal.token_cache import *
from msal.token_cache import TokenCache, SerializableTokenCache
from tests import unittest


Expand Down Expand Up @@ -51,6 +51,8 @@ class TokenCacheTestCase(unittest.TestCase):

def setUp(self):
self.cache = TokenCache()
self.at_key_maker = self.cache.key_makers[
TokenCache.CredentialType.ACCESS_TOKEN]

def testAddByAad(self):
client_id = "my_client_id"
Expand Down Expand Up @@ -78,11 +80,8 @@ def testAddByAad(self):
'target': 's1 s2 s3', # Sorted
'token_type': 'some type',
}
self.assertEqual(
access_token_entry,
self.cache._cache["AccessToken"].get(
'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s1 s2 s3')
)
self.assertEqual(access_token_entry, self.cache._cache["AccessToken"].get(
self.at_key_maker(**access_token_entry)))
self.assertIn(
access_token_entry,
self.cache.find(self.cache.CredentialType.ACCESS_TOKEN),
Expand Down Expand Up @@ -144,8 +143,7 @@ def testAddByAdfs(self):
expires_in=3600, access_token="an access token",
id_token=id_token, refresh_token="a refresh token"),
}, now=1000)
self.assertEqual(
{
access_token_entry = {
'cached_at': "1000",
'client_id': 'my_client_id',
'credential_type': 'AccessToken',
Expand All @@ -157,10 +155,9 @@ def testAddByAdfs(self):
'secret': 'an access token',
'target': 's1 s2 s3', # Sorted
'token_type': 'some type',
},
self.cache._cache["AccessToken"].get(
'subject-fs.msidlab8.com-accesstoken-my_client_id-adfs-s1 s2 s3')
)
}
self.assertEqual(access_token_entry, self.cache._cache["AccessToken"].get(
self.at_key_maker(**access_token_entry)))
self.assertEqual(
{
'client_id': 'my_client_id',
Expand Down Expand Up @@ -238,37 +235,29 @@ def _test_data_should_be_saved_and_searchable_in_access_token(self, data):
def test_extra_data_should_also_be_recorded_and_searchable_in_access_token(self):
self._test_data_should_be_saved_and_searchable_in_access_token({"key_id": "1"})

def test_key_id_is_also_recorded(self):
my_key_id = "some_key_id_123"
self.cache.add({
"data": {"key_id": my_key_id},
"client_id": "my_client_id",
"scope": ["s2", "s1", "s3"], # Not in particular order
"token_endpoint": "https://login.example.com/contoso/v2/token",
"response": build_response(
uid="uid", utid="utid", # client_info
expires_in=3600, access_token="an access token",
refresh_token="a refresh token"),
}, now=1000)
cached_key_id = self.cache._cache["AccessToken"].get(
'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s1 s2 s3',
{}).get("key_id")
self.assertEqual(my_key_id, cached_key_id, "AT should be bound to the key")
def test_access_tokens_with_different_key_id_should_coexist(self):
self._test_data_should_be_saved_and_searchable_in_access_token({"key_id": "1"})
self._test_data_should_be_saved_and_searchable_in_access_token({"key_id": "2"})
self.assertEqual(2, len(self.cache._cache["AccessToken"]), "Should have 2 ATs")

def test_refresh_in_should_be_recorded_as_refresh_on(self): # Sounds weird. Yep.
scopes = ["s2", "s1", "s3"] # Not in particular order
self.cache.add({
"client_id": "my_client_id",
"scope": ["s2", "s1", "s3"], # Not in particular order
"scope": scopes,
"token_endpoint": "https://login.example.com/contoso/v2/token",
"response": build_response(
uid="uid", utid="utid", # client_info
expires_in=3600, refresh_in=1800, access_token="an access token",
), #refresh_token="a refresh token"),
}, now=1000)
refresh_on = self.cache._cache["AccessToken"].get(
'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s1 s2 s3',
{}).get("refresh_on")
self.assertEqual("2800", refresh_on, "Should save refresh_on")
at = self.assertFoundAccessToken(scopes=scopes, query=dict(
client_id="my_client_id",
environment="login.example.com",
realm="contoso",
home_account_id="uid.utid",
))
self.assertEqual("2800", at.get("refresh_on"), "Should save refresh_on")

def test_old_rt_data_with_wrong_key_should_still_be_salvaged_into_new_rt(self):
sample = {
Expand Down

0 comments on commit e5974b3

Please sign in to comment.