Skip to content

Commit b15cede

Browse files
authored
Persistent caching for service principal credentials (#11824)
1 parent 6a0e027 commit b15cede

15 files changed

+459
-23
lines changed

sdk/identity/azure-identity/CHANGELOG.md

+9
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
11
# Release History
22

33
## 1.4.0b4 (Unreleased)
4+
- `CertificateCredential` and `ClientSecretCredential` can optionally store
5+
access tokens they acquire in a persistent cache. To enable this, construct
6+
the credential with `enable_persistent_cache=True`. On Linux, the persistent
7+
cache requires libsecret and `pygobject`. If these are unavailable or
8+
unusable (e.g. in an SSH session), loading the persistent cache will raise an
9+
error. You may optionally configure the credential to fall back to an
10+
unencrypted cache by constructing it with keyword argument
11+
`allow_unencrypted_cache=True`.
12+
([#11347](https://github.com/Azure/azure-sdk-for-python/issues/11347))
413
- `AzureCliCredential` raises `CredentialUnavailableError` when no user is
514
logged in to the Azure CLI.
615
([#11819](https://github.com/Azure/azure-sdk-for-python/issues/11819))

sdk/identity/azure-identity/azure/identity/_credentials/certificate.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ class CertificateCredential(CertificateCredentialBase):
2424
:keyword password: The certificate's password. If a unicode string, it will be encoded as UTF-8. If the certificate
2525
requires a different encoding, pass appropriately encoded bytes instead.
2626
:paramtype password: str or bytes
27+
:keyword bool enable_persistent_cache: if True, the credential will store tokens in a persistent cache. Defaults to
28+
False.
29+
:keyword bool allow_unencrypted_cache: if True, the credential will fall back to a plaintext cache when encryption
30+
is unavailable. Defaults to False. Has no effect when `enable_persistent_cache` is False.
2731
"""
2832

2933
def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
@@ -41,7 +45,7 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
4145
if not scopes:
4246
raise ValueError("'get_token' requires at least one scope")
4347

44-
token = self._client.get_cached_access_token(scopes)
48+
token = self._client.get_cached_access_token(scopes, query={"client_id": self._client_id})
4549
if not token:
4650
token = self._client.obtain_token_by_client_certificate(scopes, self._certificate, **kwargs)
4751
return token

sdk/identity/azure-identity/azure/identity/_credentials/client_secret.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ class ClientSecretCredential(ClientSecretCredentialBase):
2525
:keyword str authority: Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com',
2626
the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.KnownAuthorities`
2727
defines authorities for other clouds.
28+
:keyword bool enable_persistent_cache: if True, the credential will store tokens in a persistent cache. Defaults to
29+
False.
30+
:keyword bool allow_unencrypted_cache: if True, the credential will fall back to a plaintext cache when encryption
31+
is unavailable. Defaults to False. Has no effect when `enable_persistent_cache` is False.
2832
"""
2933

3034
def get_token(self, *scopes, **kwargs):
@@ -42,7 +46,7 @@ def get_token(self, *scopes, **kwargs):
4246
if not scopes:
4347
raise ValueError("'get_token' requires at least one scope")
4448

45-
token = self._client.get_cached_access_token(scopes)
49+
token = self._client.get_cached_access_token(scopes, query={"client_id": self._client_id})
4650
if not token:
4751
token = self._client.obtain_token_by_client_secret(scopes, self._secret, **kwargs)
4852
return token

sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,9 @@ def __init__(self, tenant_id, client_id, authority=None, cache=None, **kwargs):
4949
self._client_id = client_id
5050
self._pipeline = self._build_pipeline(**kwargs)
5151

52-
def get_cached_access_token(self, scopes):
53-
# type: (Sequence[str]) -> Optional[AccessToken]
54-
tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN, target=list(scopes))
52+
def get_cached_access_token(self, scopes, query=None):
53+
# type: (Sequence[str], Optional[dict]) -> Optional[AccessToken]
54+
tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN, target=list(scopes), query=query)
5555
for token in tokens:
5656
expires_on = int(token["expires_on"])
5757
if expires_on - 300 > int(time.time()):

sdk/identity/azure-identity/azure/identity/_internal/certificate_credential_base.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,11 @@
44
# ------------------------------------
55
import abc
66

7+
from msal import TokenCache
78
import six
8-
from azure.identity._internal import AadClientCertificate
9+
10+
from . import AadClientCertificate
11+
from .persistent_cache import load_service_principal_cache
912

1013
try:
1114
ABC = abc.ABC
@@ -40,7 +43,16 @@ def __init__(self, tenant_id, client_id, certificate_path, **kwargs):
4043
pem_bytes = f.read()
4144

4245
self._certificate = AadClientCertificate(pem_bytes, password=password)
43-
self._client = self._get_auth_client(tenant_id, client_id, **kwargs)
46+
47+
enable_persistent_cache = kwargs.pop("enable_persistent_cache", False)
48+
if enable_persistent_cache:
49+
allow_unencrypted = kwargs.pop("allow_unencrypted_cache", False)
50+
cache = load_service_principal_cache(allow_unencrypted)
51+
else:
52+
cache = TokenCache()
53+
54+
self._client = self._get_auth_client(tenant_id, client_id, cache=cache, **kwargs)
55+
self._client_id = client_id
4456

4557
@abc.abstractmethod
4658
def _get_auth_client(self, tenant_id, client_id, **kwargs):

sdk/identity/azure-identity/azure/identity/_internal/client_secret_credential_base.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55
import abc
66
from typing import TYPE_CHECKING
77

8+
from msal import TokenCache
9+
10+
from .persistent_cache import load_service_principal_cache
11+
812
try:
913
ABC = abc.ABC
1014
except AttributeError: # Python 2.7
@@ -27,7 +31,15 @@ def __init__(self, tenant_id, client_id, client_secret, **kwargs):
2731
"tenant_id should be an Azure Active Directory tenant's id (also called its 'directory id')"
2832
)
2933

30-
self._client = self._get_auth_client(tenant_id, client_id, **kwargs)
34+
enable_persistent_cache = kwargs.pop("enable_persistent_cache", False)
35+
if enable_persistent_cache:
36+
allow_unencrypted = kwargs.pop("allow_unencrypted_cache", False)
37+
cache = load_service_principal_cache(allow_unencrypted)
38+
else:
39+
cache = TokenCache()
40+
41+
self._client = self._get_auth_client(tenant_id, client_id, cache=cache, **kwargs)
42+
self._client_id = client_id
3143
self._secret = client_secret
3244

3345
@abc.abstractmethod

sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from .exception_wrapper import wrap_exceptions
2020
from .msal_transport_adapter import MsalTransportAdapter
21-
from .persistent_cache import load_persistent_cache
21+
from .persistent_cache import load_user_cache
2222
from .._constants import KnownAuthorities
2323
from .._exceptions import AuthenticationRequiredError, CredentialUnavailableError
2424
from .._internal import get_default_authority, normalize_authority
@@ -98,7 +98,7 @@ def __init__(self, client_id, client_credential=None, **kwargs):
9898
if not self._cache:
9999
if kwargs.pop("enable_persistent_cache", False):
100100
allow_unencrypted = kwargs.pop("allow_unencrypted_cache", False)
101-
self._cache = load_persistent_cache(allow_unencrypted)
101+
self._cache = load_user_cache(allow_unencrypted)
102102
else:
103103
self._cache = msal.TokenCache()
104104

sdk/identity/azure-identity/azure/identity/_internal/persistent_cache.py

+18-6
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,18 @@
1313
import msal
1414

1515

16-
def load_persistent_cache(allow_unencrypted):
16+
def load_service_principal_cache(allow_unencrypted):
    # type: (Optional[bool]) -> msal.TokenCache
    """Load the persistent token cache used for service principal credentials.

    Thin wrapper over ``_load_persistent_cache`` that supplies the account and cache names
    reserved for this cache.

    :param allow_unencrypted: forwarded unchanged to ``_load_persistent_cache``; whether an
        unencrypted cache is acceptable when encryption is unavailable.
    :returns: whatever ``_load_persistent_cache`` returns for these names.
    """
    account_name = "MSALConfidentialCache"
    cache_name = "msal.confidential.cache"
    return _load_persistent_cache(allow_unencrypted, account_name, cache_name)
19+
20+
21+
def load_user_cache(allow_unencrypted):
    # type: (Optional[bool]) -> msal.TokenCache
    """Load the persistent token cache used for user credentials.

    Thin wrapper over ``_load_persistent_cache`` that supplies the account and cache names
    reserved for this cache.

    :param allow_unencrypted: forwarded unchanged to ``_load_persistent_cache``; whether an
        unencrypted cache is acceptable when encryption is unavailable.
    :returns: whatever ``_load_persistent_cache`` returns for these names.
    """
    account_name = "MSALCache"
    cache_name = "msal.cache"
    return _load_persistent_cache(allow_unencrypted, account_name, cache_name)
24+
25+
26+
def _load_persistent_cache(allow_unencrypted, account_name, cache_name):
27+
# type: (Optional[bool], str, str) -> msal.TokenCache
1828
"""Load the persistent cache using msal_extensions.
1929
2030
On Windows the cache is a file protected by the Data Protection API. On Linux and macOS the cache is stored by
@@ -26,19 +36,21 @@ def load_persistent_cache(allow_unencrypted):
2636
"""
2737

2838
if sys.platform.startswith("win") and "LOCALAPPDATA" in os.environ:
29-
cache_location = os.path.join(os.environ["LOCALAPPDATA"], ".IdentityService", "msal.cache")
39+
cache_location = os.path.join(os.environ["LOCALAPPDATA"], ".IdentityService", cache_name)
3040
persistence = msal_extensions.FilePersistenceWithDataProtection(cache_location)
3141
elif sys.platform.startswith("darwin"):
3242
# the cache uses this file's modified timestamp to decide whether to reload
33-
file_path = os.path.expanduser(os.path.join("~", ".IdentityService", "msal.cache"))
34-
persistence = msal_extensions.KeychainPersistence(file_path, "Microsoft.Developer.IdentityService", "MSALCache")
43+
file_path = os.path.expanduser(os.path.join("~", ".IdentityService", cache_name))
44+
persistence = msal_extensions.KeychainPersistence(
45+
file_path, "Microsoft.Developer.IdentityService", account_name
46+
)
3547
elif sys.platform.startswith("linux"):
3648
# The cache uses this file's modified timestamp to decide whether to reload. Note this path is the same
3749
# as that of the plaintext fallback: a new encrypted cache will stomp an unencrypted cache.
38-
file_path = os.path.expanduser(os.path.join("~", ".IdentityService", "msal.cache"))
50+
file_path = os.path.expanduser(os.path.join("~", ".IdentityService", cache_name))
3951
try:
4052
persistence = msal_extensions.LibsecretPersistence(
41-
file_path, "msal.cache", {"MsalClientID": "Microsoft.Developer.IdentityService"}, label="MSALCache"
53+
file_path, cache_name, {"MsalClientID": "Microsoft.Developer.IdentityService"}, label=account_name
4254
)
4355
except ImportError:
4456
if not allow_unencrypted:

sdk/identity/azure-identity/azure/identity/aio/_credentials/certificate.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # py
5151
if not scopes:
5252
raise ValueError("'get_token' requires at least one scope")
5353

54-
token = self._client.get_cached_access_token(scopes)
54+
token = self._client.get_cached_access_token(scopes, query={"client_id": self._client_id})
5555
if not token:
5656
token = await self._client.obtain_token_by_client_certificate(scopes, self._certificate, **kwargs)
5757
return token

sdk/identity/azure-identity/azure/identity/aio/_credentials/client_secret.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ class ClientSecretCredential(AsyncCredentialBase, ClientSecretCredentialBase):
2323
:keyword str authority: Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com',
2424
the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.KnownAuthorities`
2525
defines authorities for other clouds.
26+
:keyword bool enable_persistent_cache: if True, the credential will store tokens in a persistent cache. Defaults to
27+
False.
28+
:keyword bool allow_unencrypted_cache: if True, the credential will fall back to a plaintext cache when encryption
29+
is unavailable. Defaults to False. Has no effect when `enable_persistent_cache` is False.
2630
"""
2731

2832
async def __aenter__(self):
@@ -48,7 +52,7 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken":
4852
if not scopes:
4953
raise ValueError("'get_token' requires at least one scope")
5054

51-
token = self._client.get_cached_access_token(scopes)
55+
token = self._client.get_cached_access_token(scopes, query={"client_id": self._client_id})
5256
if not token:
5357
token = await self._client.obtain_token_by_client_secret(scopes, self._secret, **kwargs)
5458
return token

sdk/identity/azure-identity/tests/test_certificate_credential.py

+100-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from cryptography.hazmat.backends import default_backend
1414
from cryptography.hazmat.primitives import hashes
1515
from cryptography.hazmat.primitives.asymmetric import padding
16+
from msal import TokenCache
1617
import pytest
1718
from six.moves.urllib_parse import urlparse
1819

@@ -135,9 +136,107 @@ def validate_jwt(request, client_id, pem_bytes):
135136
deserialized_header = json.loads(header.decode("utf-8"))
136137
assert deserialized_header["alg"] == "RS256"
137138
assert deserialized_header["typ"] == "JWT"
138-
assert urlsafeb64_decode(deserialized_header["x5t"]) == cert.fingerprint(hashes.SHA1()) #nosec
139+
assert urlsafeb64_decode(deserialized_header["x5t"]) == cert.fingerprint(hashes.SHA1()) # nosec
139140

140141
assert claims["aud"] == request.url
141142
assert claims["iss"] == claims["sub"] == client_id
142143

143144
cert.public_key().verify(signature, signed_part.encode("utf-8"), padding.PKCS1v15(), hashes.SHA256())
145+
146+
147+
@pytest.mark.parametrize("cert_path,cert_password", BOTH_CERTS)
def test_enable_persistent_cache(cert_path, cert_password):
    """The persistent cache is loaded only when the caller passes enable_persistent_cache=True."""

    module = "azure.identity._internal.persistent_cache"
    positional = ("tenant-id", "client-id", cert_path)

    def make_credential(**options):
        # every case constructs the same credential, varying only the cache-related keywords
        return CertificateCredential(*positional, password=cert_password, **options)

    # without opting in, construction must never reach the persistent cache loader
    forbidden_loader = Mock(side_effect=Exception("credential shouldn't attempt to load a persistent cache"))
    with patch(module + "._load_persistent_cache", forbidden_loader):
        make_credential()

        # allow_unencrypted_cache on its own is not an opt in
        make_credential(allow_unencrypted_cache=True)

    # enable_persistent_cache=True opts in
    with patch(module + ".msal_extensions") as mock_extensions:
        make_credential(enable_persistent_cache=True)
        assert mock_extensions.PersistedTokenCache.call_count == 1

    # on an unsupported platform opting in raises, whether or not an unencrypted cache is allowed
    with patch(module + ".sys.platform", "commodore64"):
        for extra in ({}, {"allow_unencrypted_cache": True}):
            with pytest.raises(NotImplementedError):
                make_credential(enable_persistent_cache=True, **extra)
175+
176+
177+
@patch("azure.identity._internal.persistent_cache.sys.platform", "linux2")
@patch("azure.identity._internal.persistent_cache.msal_extensions")
@pytest.mark.parametrize("cert_path,cert_password", BOTH_CERTS)
def test_persistent_cache_linux(mock_extensions, cert_path, cert_password):
    """The credential should use an unencrypted cache when encryption is unavailable and the user explicitly opts in.

    This test was written when Linux was the only platform on which encryption may not be available.

    The checks below use call counts on the patched ``msal_extensions`` members. ``Mock`` has no
    ``called_with`` assertion: ``assert mock.called_with(...)`` merely evaluates an auto-created child
    mock, which is always truthy (and newer ``unittest.mock`` releases reject the attribute outright),
    so such a statement verifies nothing.
    """

    required_arguments = ("tenant-id", "client-id", cert_path)

    # the credential should prefer an encrypted cache even when the user allows an unencrypted one
    CertificateCredential(
        *required_arguments, password=cert_password, enable_persistent_cache=True, allow_unencrypted_cache=True
    )
    assert mock_extensions.LibsecretPersistence.call_count == 1, "credential should try the encrypted cache first"
    assert mock_extensions.FilePersistence.call_count == 0, "credential shouldn't fall back when encryption works"
    assert mock_extensions.PersistedTokenCache.call_count == 1
    mock_extensions.PersistedTokenCache.reset_mock()

    # (when LibsecretPersistence's dependencies aren't available, constructing it raises ImportError)
    mock_extensions.LibsecretPersistence = Mock(side_effect=ImportError)

    # encryption unavailable, no opt in to unencrypted cache -> credential should raise
    with pytest.raises(ValueError):
        CertificateCredential(*required_arguments, password=cert_password, enable_persistent_cache=True)
    assert mock_extensions.FilePersistence.call_count == 0, "credential shouldn't fall back without opt in"
    assert mock_extensions.PersistedTokenCache.call_count == 0

    # encryption unavailable, user opted in to an unencrypted cache -> credential should fall back
    CertificateCredential(
        *required_arguments, password=cert_password, enable_persistent_cache=True, allow_unencrypted_cache=True
    )
    assert mock_extensions.FilePersistence.call_count == 1, "credential should fall back to the unencrypted cache"
    assert mock_extensions.PersistedTokenCache.call_count == 1
206+
207+
208+
@pytest.mark.parametrize("cert_path,cert_password", BOTH_CERTS)
def test_persistent_cache_multiple_clients(cert_path, cert_password):
    """A credential must ignore cached tokens that were issued to a different service principal."""

    access_token_a = "token a"
    access_token_b = "not " + access_token_a

    def single_response_transport(access_token):
        # a transport expecting exactly one request, answered with the given access token
        payload = build_aad_response(access_token=access_token)
        return validating_transport(requests=[Request()], responses=[mock_response(json_payload=payload)])

    transport_a = single_response_transport(access_token_a)
    transport_b = single_response_transport(access_token_b)

    # both credentials share one underlying cache, as they would when using the persistent cache
    shared_cache = TokenCache()
    with patch("azure.identity._internal.persistent_cache._load_persistent_cache") as mock_cache_loader:
        mock_cache_loader.return_value = Mock(wraps=shared_cache)

        credential_a = CertificateCredential(
            "tenant", "client-a", cert_path, password=cert_password, enable_persistent_cache=True, transport=transport_a
        )
        assert mock_cache_loader.call_count == 1, "credential should load the persistent cache"

        credential_b = CertificateCredential(
            "tenant", "client-b", cert_path, password=cert_password, enable_persistent_cache=True, transport=transport_b
        )
        assert mock_cache_loader.call_count == 2, "credential should load the persistent cache"

    scope = "scope"

    # A requests a token, which lands in the shared cache
    token_a = credential_a.get_token(scope)
    assert token_a.token == access_token_a
    assert transport_a.send.call_count == 1

    # B asks for the same scope but must send its own request rather than reuse A's token
    token_b = credential_b.get_token(scope)
    assert token_b.token == access_token_b
    assert transport_b.send.call_count == 1

0 commit comments

Comments
 (0)