3
3
import json
4
4
import time
5
5
6
- from msal .token_cache import *
6
+ from msal .token_cache import TokenCache , SerializableTokenCache
7
7
from tests import unittest
8
8
9
9
@@ -51,11 +51,14 @@ class TokenCacheTestCase(unittest.TestCase):
51
51
52
52
def setUp (self ):
53
53
self .cache = TokenCache ()
54
+ self .at_key_maker = self .cache .key_makers [
55
+ TokenCache .CredentialType .ACCESS_TOKEN ]
54
56
55
57
def testAddByAad (self ):
56
58
client_id = "my_client_id"
57
59
id_token = build_id_token (
58
60
oid = "object1234" , preferred_username = "John Doe" , aud = client_id )
61
+ now = 1000
59
62
self .cache .add ({
60
63
"client_id" : client_id ,
61
64
"scope" : ["s2" , "s1" , "s3" ], # Not in particular order
@@ -64,7 +67,7 @@ def testAddByAad(self):
64
67
uid = "uid" , utid = "utid" , # client_info
65
68
expires_in = 3600 , access_token = "an access token" ,
66
69
id_token = id_token , refresh_token = "a refresh token" ),
67
- }, now = 1000 )
70
+ }, now = now )
68
71
access_token_entry = {
69
72
'cached_at' : "1000" ,
70
73
'client_id' : 'my_client_id' ,
@@ -78,14 +81,11 @@ def testAddByAad(self):
78
81
'target' : 's1 s2 s3' , # Sorted
79
82
'token_type' : 'some type' ,
80
83
}
81
- self .assertEqual (
82
- access_token_entry ,
83
- self .cache ._cache ["AccessToken" ].get (
84
- 'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s1 s2 s3' )
85
- )
84
+ self .assertEqual (access_token_entry , self .cache ._cache ["AccessToken" ].get (
85
+ self .at_key_maker (** access_token_entry )))
86
86
self .assertIn (
87
87
access_token_entry ,
88
- self .cache .find (self .cache .CredentialType .ACCESS_TOKEN ),
88
+ self .cache .find (self .cache .CredentialType .ACCESS_TOKEN , now = now ),
89
89
"find(..., query=None) should not crash, even though MSAL does not use it" )
90
90
self .assertEqual (
91
91
{
@@ -144,8 +144,7 @@ def testAddByAdfs(self):
144
144
expires_in = 3600 , access_token = "an access token" ,
145
145
id_token = id_token , refresh_token = "a refresh token" ),
146
146
}, now = 1000 )
147
- self .assertEqual (
148
- {
147
+ access_token_entry = {
149
148
'cached_at' : "1000" ,
150
149
'client_id' : 'my_client_id' ,
151
150
'credential_type' : 'AccessToken' ,
@@ -157,10 +156,9 @@ def testAddByAdfs(self):
157
156
'secret' : 'an access token' ,
158
157
'target' : 's1 s2 s3' , # Sorted
159
158
'token_type' : 'some type' ,
160
- },
161
- self .cache ._cache ["AccessToken" ].get (
162
- 'subject-fs.msidlab8.com-accesstoken-my_client_id-adfs-s1 s2 s3' )
163
- )
159
+ }
160
+ self .assertEqual (access_token_entry , self .cache ._cache ["AccessToken" ].get (
161
+ self .at_key_maker (** access_token_entry )))
164
162
self .assertEqual (
165
163
{
166
164
'client_id' : 'my_client_id' ,
@@ -206,37 +204,67 @@ def testAddByAdfs(self):
206
204
"appmetadata-fs.msidlab8.com-my_client_id" )
207
205
)
208
206
209
- def test_key_id_is_also_recorded (self ):
210
- my_key_id = "some_key_id_123"
207
+ def assertFoundAccessToken (self , * , scopes , query , data = None , now = None ):
208
+ cached_at = None
209
+ for cached_at in self .cache .search (
210
+ TokenCache .CredentialType .ACCESS_TOKEN ,
211
+ target = scopes , query = query , now = now ,
212
+ ):
213
+ for k , v in (data or {}).items (): # The extra data, if any
214
+ self .assertEqual (cached_at .get (k ), v , f"AT should contain { k } ={ v } " )
215
+ self .assertTrue (cached_at , "AT should be cached and searchable" )
216
+ return cached_at
217
+
218
+ def _test_data_should_be_saved_and_searchable_in_access_token (self , data ):
219
+ scopes = ["s2" , "s1" , "s3" ] # Not in particular order
220
+ now = 1000
211
221
self .cache .add ({
212
- "data" : { "key_id" : my_key_id } ,
222
+ "data" : data ,
213
223
"client_id" : "my_client_id" ,
214
- "scope" : [ "s2" , "s1" , "s3" ], # Not in particular order
224
+ "scope" : scopes ,
215
225
"token_endpoint" : "https://login.example.com/contoso/v2/token" ,
216
226
"response" : build_response (
217
227
uid = "uid" , utid = "utid" , # client_info
218
228
expires_in = 3600 , access_token = "an access token" ,
219
229
refresh_token = "a refresh token" ),
220
- }, now = 1000 )
221
- cached_key_id = self .cache ._cache ["AccessToken" ].get (
222
- 'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s1 s2 s3' ,
223
- {}).get ("key_id" )
224
- self .assertEqual (my_key_id , cached_key_id , "AT should be bound to the key" )
230
+ }, now = now )
231
+ self .assertFoundAccessToken (scopes = scopes , data = data , now = now , query = dict (
232
+ data , # Also use the extra data as a query criteria
233
+ client_id = "my_client_id" ,
234
+ environment = "login.example.com" ,
235
+ realm = "contoso" ,
236
+ home_account_id = "uid.utid" ,
237
+ ))
238
+
239
+ def test_extra_data_should_also_be_recorded_and_searchable_in_access_token (self ):
240
+ self ._test_data_should_be_saved_and_searchable_in_access_token ({"key_id" : "1" })
241
+
242
+ def test_access_tokens_with_different_key_id (self ):
243
+ self ._test_data_should_be_saved_and_searchable_in_access_token ({"key_id" : "1" })
244
+ self ._test_data_should_be_saved_and_searchable_in_access_token ({"key_id" : "2" })
245
+ self .assertEqual (
246
+ len (self .cache ._cache ["AccessToken" ]),
247
+ 1 , """Historically, tokens are not keyed by key_id,
248
+ so a new token overwrites the old one, and we would end up with 1 token in cache""" )
225
249
226
250
def test_refresh_in_should_be_recorded_as_refresh_on (self ): # Sounds weird. Yep.
251
+ scopes = ["s2" , "s1" , "s3" ] # Not in particular order
227
252
self .cache .add ({
228
253
"client_id" : "my_client_id" ,
229
- "scope" : [ "s2" , "s1" , "s3" ], # Not in particular order
254
+ "scope" : scopes ,
230
255
"token_endpoint" : "https://login.example.com/contoso/v2/token" ,
231
256
"response" : build_response (
232
257
uid = "uid" , utid = "utid" , # client_info
233
258
expires_in = 3600 , refresh_in = 1800 , access_token = "an access token" ,
234
259
), #refresh_token="a refresh token"),
235
260
}, now = 1000 )
236
- refresh_on = self .cache ._cache ["AccessToken" ].get (
237
- 'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s1 s2 s3' ,
238
- {}).get ("refresh_on" )
239
- self .assertEqual ("2800" , refresh_on , "Should save refresh_on" )
261
+ at = self .assertFoundAccessToken (scopes = scopes , query = dict (
262
+ client_id = "my_client_id" ,
263
+ environment = "login.example.com" ,
264
+ realm = "contoso" ,
265
+ home_account_id = "uid.utid" ,
266
+ ))
267
+ self .assertEqual ("2800" , at .get ("refresh_on" ), "Should save refresh_on" )
240
268
241
269
def test_old_rt_data_with_wrong_key_should_still_be_salvaged_into_new_rt (self ):
242
270
sample = {
@@ -258,7 +286,7 @@ def test_old_rt_data_with_wrong_key_should_still_be_salvaged_into_new_rt(self):
258
286
)
259
287
260
288
261
- class SerializableTokenCacheTestCase (TokenCacheTestCase ):
289
+ class SerializableTokenCacheTestCase (unittest . TestCase ):
262
290
# Run all inherited test methods, and have extra check in tearDown()
263
291
264
292
def setUp (self ):
0 commit comments