Commit 53833ca

Merge branch 'flexible-cache' into dev
2 parents: 331c16f + 60144d5

File tree: 3 files changed (+93, −41 lines)

msal/token_cache.py (+31, −10)
@@ -43,6 +43,8 @@ def __init__(self):
         self._lock = threading.RLock()
         self._cache = {}
         self.key_makers = {
+            # Note: We have changed token key format before when ordering scopes;
+            # changing token key won't result in cache miss.
             self.CredentialType.REFRESH_TOKEN:
                 lambda home_account_id=None, environment=None, client_id=None,
                         target=None, **ignored_payload_from_a_real_token:
@@ -56,14 +58,18 @@ def __init__(self):
                     ]).lower(),
             self.CredentialType.ACCESS_TOKEN:
                 lambda home_account_id=None, environment=None, client_id=None,
-                        realm=None, target=None, **ignored_payload_from_a_real_token:
-                    "-".join([
+                        realm=None, target=None,
+                        # Note: New field(s) can be added here
+                        #key_id=None,
+                        **ignored_payload_from_a_real_token:
+                    "-".join([  # Note: Could use a hash here to shorten key length
                         home_account_id or "",
                         environment or "",
                         self.CredentialType.ACCESS_TOKEN,
                         client_id or "",
                         realm or "",
                         target or "",
+                        #key_id or "",  # So ATs of different key_id can coexist
                         ]).lower(),
             self.CredentialType.ID_TOKEN:
                 lambda home_account_id=None, environment=None, client_id=None,
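
For context, each key maker is simply a callable that maps an entry's identifying fields to a lowercase cache key. The snippet below is an illustrative sketch, not part of this diff; it assumes a plain TokenCache instance and reuses the field values from the tests in this commit:

    from msal.token_cache import TokenCache

    cache = TokenCache()
    at_key_maker = cache.key_makers[TokenCache.CredentialType.ACCESS_TOKEN]
    key = at_key_maker(
        home_account_id="uid.utid", environment="login.example.com",
        client_id="my_client_id", realm="contoso", target="s1 s2 s3")
    # key == "uid.utid-login.example.com-accesstoken-my_client_id-contoso-s1 s2 s3"
    # Unrelated entry fields (e.g. 'secret', 'cached_at') are swallowed by
    # **ignored_payload_from_a_real_token, which is why the updated tests can
    # simply call at_key_maker(**access_token_entry) instead of hard-coding the key.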
@@ -124,7 +130,7 @@ def _is_matching(entry: dict, query: dict, target_set: set = None) -> bool:
             target_set <= set(entry.get("target", "").split())
             if target_set else True)

-    def search(self, credential_type, target=None, query=None):  # O(n) generator
+    def search(self, credential_type, target=None, query=None, *, now=None):  # O(n) generator
         """Returns a generator of matching entries.

         It is O(1) for AT hits, and O(n) for other types.
@@ -150,21 +156,33 @@ def search(self, credential_type, target=None, query=None):  # O(n) generator

         target_set = set(target)
         with self._lock:
-            # Since the target inside token cache key is (per schema) unsorted,
-            # there is no point to attempt an O(1) key-value search here.
-            # So we always do an O(n) in-memory search.
+            # O(n) search. The key is NOT used in search.
+            now = int(time.time() if now is None else now)
+            expired_access_tokens = [
+                # Especially when/if we key ATs by ephemeral fields such as key_id,
+                # stale ATs keyed by an old key_id would stay forever.
+                # Here we collect them for their removal.
+            ]
             for entry in self._cache.get(credential_type, {}).values():
+                if (  # Automatically delete expired access tokens
+                        credential_type == self.CredentialType.ACCESS_TOKEN
+                        and int(entry["expires_on"]) < now
+                ):
+                    expired_access_tokens.append(entry)  # Can't delete them within current for-loop
+                    continue
                 if (entry != preferred_result  # Avoid yielding the same entry twice
                         and self._is_matching(entry, query, target_set=target_set)
                 ):
                     yield entry
+            for at in expired_access_tokens:
+                self.remove_at(at)

-    def find(self, credential_type, target=None, query=None):
+    def find(self, credential_type, target=None, query=None, *, now=None):
         """Equivalent to list(search(...))."""
         warnings.warn(
             "Use list(search(...)) instead to explicitly get a list.",
             DeprecationWarning)
-        return list(self.search(credential_type, target=target, query=query))
+        return list(self.search(credential_type, target=target, query=query, now=now))

     def add(self, event, now=None):
         """Handle a token obtaining event, and add tokens into cache."""
@@ -249,8 +267,11 @@ def __add(self, event, now=None):
             "expires_on": str(now + expires_in),  # Same here
             "extended_expires_on": str(now + ext_expires_in)  # Same here
             }
-        if data.get("key_id"):  # It happens in SSH-cert or POP scenario
-            at["key_id"] = data.get("key_id")
+        at.update({k: data[k] for k in data if k in {
+            # Also store extra data which we explicitly allow
+            # So that we won't accidentally store a user's password etc.
+            "key_id",  # It happens in SSH-cert or POP scenario
+            }})
         if "refresh_in" in response:
             refresh_in = response["refresh_in"]  # It is an integer
             at["refresh_on"] = str(now + refresh_in)  # Schema wants a string

tests/test_application.py (+5, −2)
@@ -340,6 +340,7 @@ class TestApplicationForRefreshInBehaviors(unittest.TestCase):
     account = {"home_account_id": "{}.{}".format(uid, utid)}
     rt = "this is a rt"
     client_id = "my_app"
+    soon = 60  # application.py considers tokens within 5 minutes as expired

     @classmethod
     def setUpClass(cls):  # Initialization at runtime, not interpret-time
@@ -414,7 +415,8 @@ def mock_post(url, headers=None, *args, **kwargs):

     def test_expired_token_and_unavailable_aad_should_return_error(self):
         # a.k.a. Attempt refresh expired token when AAD unavailable
-        self.populate_cache(access_token="expired at", expires_in=-1, refresh_in=-900)
+        self.populate_cache(
+            access_token="expired at", expires_in=self.soon, refresh_in=-900)
         error = "something went wrong"
         def mock_post(url, headers=None, *args, **kwargs):
             self.assertEqual("4|84,3|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY))
@@ -425,7 +427,8 @@ def mock_post(url, headers=None, *args, **kwargs):

     def test_expired_token_and_available_aad_should_return_new_token(self):
         # a.k.a. Attempt refresh expired token when AAD available
-        self.populate_cache(access_token="expired at", expires_in=-1, refresh_in=-900)
+        self.populate_cache(
+            access_token="expired at", expires_in=self.soon, refresh_in=-900)
         new_access_token = "new AT"
         new_refresh_in = 123
         def mock_post(url, headers=None, *args, **kwargs):
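
Why expires_in=self.soon (60 seconds) still exercises the "expired token" path: per the comment on soon above, the application treats tokens expiring within a 5-minute buffer as already expired. A rough sketch of that check follows; the buffer value is taken from that comment, not from anything shown in this diff, so treat it as an assumption:

    import time

    EXPIRATION_BUFFER = 5 * 60  # seconds, assumed from the comment on `soon`
    cached_expires_on = time.time() + 60  # populate_cache(..., expires_in=soon)
    is_still_usable = cached_expires_on - EXPIRATION_BUFFER > time.time()
    assert not is_still_usable  # so the app still attempts a refresh via AAD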

tests/test_token_cache.py (+57, −29)
@@ -3,7 +3,7 @@
 import json
 import time

-from msal.token_cache import *
+from msal.token_cache import TokenCache, SerializableTokenCache
 from tests import unittest


@@ -51,11 +51,14 @@ class TokenCacheTestCase(unittest.TestCase):

     def setUp(self):
         self.cache = TokenCache()
+        self.at_key_maker = self.cache.key_makers[
+            TokenCache.CredentialType.ACCESS_TOKEN]

     def testAddByAad(self):
         client_id = "my_client_id"
         id_token = build_id_token(
             oid="object1234", preferred_username="John Doe", aud=client_id)
+        now = 1000
         self.cache.add({
             "client_id": client_id,
             "scope": ["s2", "s1", "s3"],  # Not in particular order
@@ -64,7 +67,7 @@ def testAddByAad(self):
                 uid="uid", utid="utid",  # client_info
                 expires_in=3600, access_token="an access token",
                 id_token=id_token, refresh_token="a refresh token"),
-            }, now=1000)
+            }, now=now)
         access_token_entry = {
             'cached_at': "1000",
             'client_id': 'my_client_id',
@@ -78,14 +81,11 @@ def testAddByAad(self):
             'target': 's1 s2 s3',  # Sorted
             'token_type': 'some type',
             }
-        self.assertEqual(
-            access_token_entry,
-            self.cache._cache["AccessToken"].get(
-                'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s1 s2 s3')
-            )
+        self.assertEqual(access_token_entry, self.cache._cache["AccessToken"].get(
+            self.at_key_maker(**access_token_entry)))
         self.assertIn(
             access_token_entry,
-            self.cache.find(self.cache.CredentialType.ACCESS_TOKEN),
+            self.cache.find(self.cache.CredentialType.ACCESS_TOKEN, now=now),
             "find(..., query=None) should not crash, even though MSAL does not use it")
         self.assertEqual(
             {
@@ -144,8 +144,7 @@ def testAddByAdfs(self):
                 expires_in=3600, access_token="an access token",
                 id_token=id_token, refresh_token="a refresh token"),
             }, now=1000)
-        self.assertEqual(
-            {
+        access_token_entry = {
             'cached_at': "1000",
             'client_id': 'my_client_id',
             'credential_type': 'AccessToken',
@@ -157,10 +156,9 @@ def testAddByAdfs(self):
             'secret': 'an access token',
             'target': 's1 s2 s3',  # Sorted
             'token_type': 'some type',
-            },
-            self.cache._cache["AccessToken"].get(
-                'subject-fs.msidlab8.com-accesstoken-my_client_id-adfs-s1 s2 s3')
-            )
+            }
+        self.assertEqual(access_token_entry, self.cache._cache["AccessToken"].get(
+            self.at_key_maker(**access_token_entry)))
         self.assertEqual(
             {
             'client_id': 'my_client_id',
@@ -206,37 +204,67 @@ def testAddByAdfs(self):
                 "appmetadata-fs.msidlab8.com-my_client_id")
             )

-    def test_key_id_is_also_recorded(self):
-        my_key_id = "some_key_id_123"
+    def assertFoundAccessToken(self, *, scopes, query, data=None, now=None):
+        cached_at = None
+        for cached_at in self.cache.search(
+                TokenCache.CredentialType.ACCESS_TOKEN,
+                target=scopes, query=query, now=now,
+                ):
+            for k, v in (data or {}).items():  # The extra data, if any
+                self.assertEqual(cached_at.get(k), v, f"AT should contain {k}={v}")
+        self.assertTrue(cached_at, "AT should be cached and searchable")
+        return cached_at
+
+    def _test_data_should_be_saved_and_searchable_in_access_token(self, data):
+        scopes = ["s2", "s1", "s3"]  # Not in particular order
+        now = 1000
         self.cache.add({
-            "data": {"key_id": my_key_id},
+            "data": data,
             "client_id": "my_client_id",
-            "scope": ["s2", "s1", "s3"],  # Not in particular order
+            "scope": scopes,
             "token_endpoint": "https://login.example.com/contoso/v2/token",
             "response": build_response(
                 uid="uid", utid="utid",  # client_info
                 expires_in=3600, access_token="an access token",
                 refresh_token="a refresh token"),
-            }, now=1000)
-        cached_key_id = self.cache._cache["AccessToken"].get(
-            'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s1 s2 s3',
-            {}).get("key_id")
-        self.assertEqual(my_key_id, cached_key_id, "AT should be bound to the key")
+            }, now=now)
+        self.assertFoundAccessToken(scopes=scopes, data=data, now=now, query=dict(
+            data,  # Also use the extra data as a query criteria
+            client_id="my_client_id",
+            environment="login.example.com",
+            realm="contoso",
+            home_account_id="uid.utid",
+            ))
+
+    def test_extra_data_should_also_be_recorded_and_searchable_in_access_token(self):
+        self._test_data_should_be_saved_and_searchable_in_access_token({"key_id": "1"})
+
+    def test_access_tokens_with_different_key_id(self):
+        self._test_data_should_be_saved_and_searchable_in_access_token({"key_id": "1"})
+        self._test_data_should_be_saved_and_searchable_in_access_token({"key_id": "2"})
+        self.assertEqual(
+            len(self.cache._cache["AccessToken"]),
+            1, """Historically, tokens are not keyed by key_id,
+            so a new token overwrites the old one, and we would end up with 1 token in cache""")

     def test_refresh_in_should_be_recorded_as_refresh_on(self):  # Sounds weird. Yep.
+        scopes = ["s2", "s1", "s3"]  # Not in particular order
         self.cache.add({
             "client_id": "my_client_id",
-            "scope": ["s2", "s1", "s3"],  # Not in particular order
+            "scope": scopes,
             "token_endpoint": "https://login.example.com/contoso/v2/token",
             "response": build_response(
                 uid="uid", utid="utid",  # client_info
                 expires_in=3600, refresh_in=1800, access_token="an access token",
                 ),  #refresh_token="a refresh token"),
             }, now=1000)
-        refresh_on = self.cache._cache["AccessToken"].get(
-            'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s1 s2 s3',
-            {}).get("refresh_on")
-        self.assertEqual("2800", refresh_on, "Should save refresh_on")
+        at = self.assertFoundAccessToken(scopes=scopes, query=dict(
+            client_id="my_client_id",
+            environment="login.example.com",
+            realm="contoso",
+            home_account_id="uid.utid",
+            ))
+        self.assertEqual("2800", at.get("refresh_on"), "Should save refresh_on")

     def test_old_rt_data_with_wrong_key_should_still_be_salvaged_into_new_rt(self):
         sample = {
@@ -258,7 +286,7 @@ def test_old_rt_data_with_wrong_key_should_still_be_salvaged_into_new_rt(self):
             )


-class SerializableTokenCacheTestCase(TokenCacheTestCase):
+class SerializableTokenCacheTestCase(unittest.TestCase):
     # Run all inherited test methods, and have extra check in tearDown()

     def setUp(self):
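
A short illustrative follow-up to the new tests above, not part of this diff: because the cache's matching logic compares every query field against the cached entry, the extra key_id field stored via the allow-list can be used directly as a search criterion, which is exactly what the query=dict(data, ...) pattern in the tests relies on. The values below are placeholders taken from those tests:

    from msal.token_cache import TokenCache

    cache = TokenCache()
    # ... after cache.add({..., "data": {"key_id": "1"}, ...}) ...
    hits = list(cache.search(
        TokenCache.CredentialType.ACCESS_TOKEN,
        target=["s1", "s2", "s3"],
        query={"home_account_id": "uid.utid", "key_id": "1"},
    ))
    # Only access tokens carrying key_id == "1" match the query.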
