Skip to content

Commit 70f0d42

Browse files
authored
SharedTokenCacheCredential lazily loads the cache (#12172)
1 parent b33b8ec commit 70f0d42

File tree

5 files changed

+99
-27
lines changed

5 files changed

+99
-27
lines changed

sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py

+3
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
5555
if not scopes:
5656
raise ValueError("'get_token' requires at least one scope")
5757

58+
if not self._initialized:
59+
self._initialize()
60+
5861
if not self._client:
5962
raise CredentialUnavailableError(message="Shared token cache unavailable")
6063

sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py

+16-17
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# Licensed under the MIT License.
44
# ------------------------------------
55
import abc
6+
import platform
67
import time
78

89
from msal import TokenCache
@@ -107,20 +108,26 @@ def __init__(self, username=None, **kwargs): # pylint:disable=unused-argument
107108
self._tenant_id = kwargs.pop("tenant_id", None)
108109

109110
self._cache = kwargs.pop("_cache", None)
110-
if not self._cache:
111-
allow_unencrypted = kwargs.pop("allow_unencrypted_cache", False)
111+
self._client = None # type: Optional[AadClientBase]
112+
self._client_kwargs = kwargs
113+
self._client_kwargs["tenant_id"] = authenticating_tenant
114+
self._initialized = False
115+
116+
def _initialize(self):
117+
if self._initialized:
118+
return
119+
120+
if not self._cache and self.supported():
121+
allow_unencrypted = self._client_kwargs.get("allow_unencrypted_cache", False)
112122
try:
113123
self._cache = load_user_cache(allow_unencrypted)
114124
except Exception: # pylint:disable=broad-except
115125
pass
116126

117127
if self._cache:
118-
self._client = self._get_auth_client(
119-
authority=self._authority, cache=self._cache, tenant_id=authenticating_tenant, **kwargs
120-
) # type: Optional[AadClientBase]
121-
else:
122-
# couldn't load the cache -> credential will be unavailable
123-
self._client = None
128+
self._client = self._get_auth_client(authority=self._authority, cache=self._cache, **self._client_kwargs)
129+
130+
self._initialized = True
124131

125132
@abc.abstractmethod
126133
def _get_auth_client(self, **kwargs):
@@ -236,12 +243,4 @@ def supported():
236243
237244
:rtype: bool
238245
"""
239-
try:
240-
load_user_cache(allow_unencrypted=False)
241-
except NotImplementedError:
242-
return False
243-
except ValueError:
244-
# cache is supported but can't be encrypted
245-
pass
246-
247-
return True
246+
return platform.system() in {"Darwin", "Linux", "Windows"}

sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py

+3
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # py
6363
if not scopes:
6464
raise ValueError("'get_token' requires at least one scope")
6565

66+
if not self._initialized:
67+
self._initialize()
68+
6669
if not self._client:
6770
raise CredentialUnavailableError(message="Shared token cache unavailable")
6871

sdk/identity/azure-identity/tests/test_shared_cache_credential.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@
3131
from helpers import build_aad_response, build_id_token, mock_response, Request, validating_transport
3232

3333

34+
def test_supported():
35+
"""the cache is supported on Linux, macOS, Windows, so this should pass unless you're developing on e.g. FreeBSD"""
36+
assert SharedTokenCacheCredential.supported()
37+
38+
3439
def test_no_scopes():
3540
"""The credential should raise when get_token is called with no scopes"""
3641

@@ -717,14 +722,34 @@ def test_access_token_caching():
717722
)
718723

719724

725+
def test_initialization():
726+
"""the credential should attempt to load the cache only once, when it's first needed"""
727+
728+
with patch("azure.identity._internal.persistent_cache._load_persistent_cache") as mock_cache_loader:
729+
mock_cache_loader.side_effect = Exception("it didn't work")
730+
731+
credential = SharedTokenCacheCredential()
732+
assert mock_cache_loader.call_count == 0
733+
734+
for _ in range(2):
735+
with pytest.raises(CredentialUnavailableError):
736+
credential.get_token("scope")
737+
assert mock_cache_loader.call_count == 1
738+
739+
720740
def test_authentication_record_authenticating_tenant():
721741
"""when given a record and 'tenant_id', the credential should authenticate in the latter"""
722742

723743
expected_tenant_id = "tenant-id"
724744
record = AuthenticationRecord("not- " + expected_tenant_id, "...", "...", "...", "...")
725745

726746
with patch.object(SharedTokenCacheCredential, "_get_auth_client") as get_auth_client:
727-
SharedTokenCacheCredential(authentication_record=record, _cache=TokenCache(), tenant_id=expected_tenant_id)
747+
credential = SharedTokenCacheCredential(
748+
authentication_record=record, _cache=TokenCache(), tenant_id=expected_tenant_id
749+
)
750+
with pytest.raises(CredentialUnavailableError):
751+
# this raises because the cache is empty
752+
credential.get_token("scope")
728753

729754
assert get_auth_client.call_count == 1
730755
_, kwargs = get_auth_client.call_args

sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py

+51-9
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@
2626
from test_shared_cache_credential import get_account_event, populated_cache
2727

2828

29+
def test_supported():
30+
"""the cache is supported on Linux, macOS, Windows, so this should pass unless you're developing on e.g. FreeBSD"""
31+
assert SharedTokenCacheCredential.supported()
32+
33+
2934
@pytest.mark.asyncio
3035
async def test_no_scopes():
3136
"""The credential should raise when get_token is called with no scopes"""
@@ -37,39 +42,53 @@ async def test_no_scopes():
3742

3843
@pytest.mark.asyncio
3944
async def test_close():
40-
transport = AsyncMockTransport()
45+
async def send(*_, **__):
46+
return mock_response(json_payload=build_aad_response(access_token="**"))
47+
48+
transport = AsyncMockTransport(send=send)
4149
credential = SharedTokenCacheCredential(
4250
_cache=populated_cache(get_account_event("test@user", "uid", "utid")), transport=transport
4351
)
4452

53+
# the credential doesn't open a transport session before one is needed, so we send a request
54+
await credential.get_token("scope")
55+
4556
await credential.close()
4657

4758
assert transport.__aexit__.call_count == 1
4859

4960

5061
@pytest.mark.asyncio
5162
async def test_context_manager():
52-
transport = AsyncMockTransport()
63+
async def send(*_, **__):
64+
return mock_response(json_payload=build_aad_response(access_token="**"))
65+
66+
transport = AsyncMockTransport(send=send)
5367
credential = SharedTokenCacheCredential(
5468
_cache=populated_cache(get_account_event("test@user", "uid", "utid")), transport=transport
5569
)
5670

71+
# async with before initialization: credential should call aexit but not aenter
5772
async with credential:
58-
assert transport.__aenter__.call_count == 1
73+
await credential.get_token("scope")
5974

60-
assert transport.__aenter__.call_count == 1
75+
assert transport.__aenter__.call_count == 0
6176
assert transport.__aexit__.call_count == 1
6277

78+
# async with after initialization: credential should call aenter and aexit
79+
async with credential:
80+
await credential.get_token("scope")
81+
assert transport.__aenter__.call_count == 1
82+
assert transport.__aexit__.call_count == 2
83+
6384

6485
@pytest.mark.asyncio
6586
async def test_context_manager_no_cache():
6687
"""the credential shouldn't open/close sessions when instantiated in an environment with no cache"""
6788

6889
transport = AsyncMockTransport()
6990

70-
with patch(
71-
"azure.identity._internal.shared_token_cache.load_user_cache", Mock(side_effect=NotImplementedError)
72-
):
91+
with patch("azure.identity._internal.shared_token_cache.load_user_cache", Mock(side_effect=NotImplementedError)):
7392
credential = SharedTokenCacheCredential(transport=transport)
7493

7594
async with credential:
@@ -666,14 +685,20 @@ async def test_auth_record_multiple_accounts_for_username():
666685
assert token.token == expected_access_token
667686

668687

669-
def test_authentication_record_authenticating_tenant():
688+
@pytest.mark.asyncio
689+
async def test_authentication_record_authenticating_tenant():
670690
"""when given a record and 'tenant_id', the credential should authenticate in the latter"""
671691

672692
expected_tenant_id = "tenant-id"
673693
record = AuthenticationRecord("not- " + expected_tenant_id, "...", "...", "...", "...")
674694

675695
with patch.object(SharedTokenCacheCredential, "_get_auth_client") as get_auth_client:
676-
SharedTokenCacheCredential(authentication_record=record, _cache=TokenCache(), tenant_id=expected_tenant_id)
696+
credential = SharedTokenCacheCredential(
697+
authentication_record=record, _cache=TokenCache(), tenant_id=expected_tenant_id
698+
)
699+
with pytest.raises(CredentialUnavailableError):
700+
# this raises because the cache is empty
701+
await credential.get_token("scope")
677702

678703
assert get_auth_client.call_count == 1
679704
_, kwargs = get_auth_client.call_args
@@ -713,3 +738,20 @@ async def test_allow_unencrypted_cache():
713738

714739
msal_extensions_patch.stop()
715740
platform_patch.stop()
741+
742+
743+
@pytest.mark.asyncio
744+
async def test_initialization():
745+
"""the credential should attempt to load the cache only once, when it's first needed"""
746+
747+
with patch("azure.identity._internal.persistent_cache._load_persistent_cache") as mock_cache_loader:
748+
mock_cache_loader.side_effect = Exception("it didn't work")
749+
750+
credential = SharedTokenCacheCredential()
751+
assert mock_cache_loader.call_count == 0
752+
753+
for _ in range(2):
754+
with pytest.raises(CredentialUnavailableError):
755+
await credential.get_token("scope")
756+
assert mock_cache_loader.call_count == 1
757+

0 commit comments

Comments
 (0)