Skip to content

Commit 994c77d

Browse files
authored
SharedTokenCacheCredential takes an optional AuthenticationRecord (#11637)
1 parent b15cede commit 994c77d

File tree

6 files changed

+258
-14
lines changed

6 files changed

+258
-14
lines changed

sdk/identity/azure-identity/CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@
2929
the keyword argument `interactive_browser_tenant_id`, or set the environment
3030
variable `AZURE_TENANT_ID`.
3131
([#11548](https://github.com/Azure/azure-sdk-for-python/issues/11548))
32+
- `SharedTokenCacheCredential` can be initialized with an `AuthenticationRecord`
33+
provided by a user credential.
34+
([#11448](https://github.com/Azure/azure-sdk-for-python/issues/11448))
3235
- The user authentication API added to `DeviceCodeCredential` and
3336
`InteractiveBrowserCredential` in 1.4.0b3 is available on
3437
`UsernamePasswordCredential` as well.

sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@
1414

1515
if TYPE_CHECKING:
1616
# pylint:disable=unused-import,ungrouped-imports
17-
from typing import Any, Mapping
18-
from azure.core.credentials import AccessToken
17+
from typing import Any
1918
from .._internal import AadClientBase
2019

2120

@@ -31,6 +30,8 @@ class SharedTokenCacheCredential(SharedTokenCacheBase):
3130
defines authorities for other clouds.
3231
:keyword str tenant_id: an Azure Active Directory tenant ID. Used to select an account when the cache contains
3332
tokens for multiple identities.
33+
:keyword AuthenticationRecord authentication_record: an authentication record returned by a user credential such as
34+
:class:`DeviceCodeCredential` or :class:`InteractiveBrowserCredential`
3435
"""
3536

3637
@wrap_exceptions
@@ -67,4 +68,4 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
6768

6869
def _get_auth_client(self, **kwargs):
6970
# type: (**Any) -> AadClientBase
70-
return AadClient(tenant_id="common", client_id=AZURE_CLI_CLIENT_ID, **kwargs)
71+
return AadClient(client_id=AZURE_CLI_CLIENT_ID, **kwargs)

sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py

+26-8
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
# pylint:disable=unused-import,ungrouped-imports
2828
from typing import Any, Iterable, List, Mapping, Optional
2929
from .._internal import AadClientBase
30+
from azure.identity import AuthenticationRecord
3031

3132
CacheItem = Mapping[str, str]
3233

@@ -86,13 +87,22 @@ class SharedTokenCacheBase(ABC):
8687
def __init__(self, username=None, **kwargs): # pylint:disable=unused-argument
8788
# type: (Optional[str], **Any) -> None
8889

89-
authority = kwargs.pop("authority", None)
90-
self._authority = normalize_authority(authority) if authority else get_default_authority()
91-
92-
environment = urlparse(self._authority).netloc
93-
self._environment_aliases = KNOWN_ALIASES.get(environment) or frozenset((environment,))
94-
self._username = username
95-
self._tenant_id = kwargs.pop("tenant_id", None)
90+
self._auth_record = kwargs.pop("authentication_record", None) # type: Optional[AuthenticationRecord]
91+
if self._auth_record:
92+
# authenticate in the tenant that produced the record unless 'tenant_id' specifies another
93+
authenticating_tenant = kwargs.pop("tenant_id", None) or self._auth_record.tenant_id
94+
self._tenant_id = self._auth_record.tenant_id
95+
self._authority = self._auth_record.authority
96+
self._username = self._auth_record.username
97+
self._environment_aliases = frozenset((self._authority,))
98+
else:
99+
authenticating_tenant = "organizations"
100+
authority = kwargs.pop("authority", None)
101+
self._authority = normalize_authority(authority) if authority else get_default_authority()
102+
environment = urlparse(self._authority).netloc
103+
self._environment_aliases = KNOWN_ALIASES.get(environment) or frozenset((environment,))
104+
self._username = username
105+
self._tenant_id = kwargs.pop("tenant_id", None)
96106

97107
cache = kwargs.pop("_cache", None) # for ease of testing
98108

@@ -110,7 +120,7 @@ def __init__(self, username=None, **kwargs): # pylint:disable=unused-argument
110120
if cache:
111121
self._cache = cache
112122
self._client = self._get_auth_client(
113-
authority=self._authority, cache=cache, **kwargs
123+
authority=self._authority, tenant_id=authenticating_tenant, cache=cache, **kwargs
114124
) # type: Optional[AadClientBase]
115125
else:
116126
self._client = None
@@ -161,6 +171,14 @@ def _get_account(self, username=None, tenant_id=None):
161171
# cache is empty or contains no refresh token -> user needs to sign in
162172
raise CredentialUnavailableError(message=NO_ACCOUNTS)
163173

174+
if self._auth_record:
175+
for account in accounts:
176+
if account.get("home_account_id") == self._auth_record.home_account_id:
177+
return account
178+
raise CredentialUnavailableError(
179+
message="The cache contains no account matching the given AuthenticationRecord."
180+
)
181+
164182
filtered_accounts = _filtered_accounts(accounts, username, tenant_id)
165183
if len(filtered_accounts) == 1:
166184
return filtered_accounts[0]

sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ class SharedTokenCacheCredential(SharedTokenCacheBase, AsyncCredentialBase):
2929
defines authorities for other clouds.
3030
:keyword str tenant_id: an Azure Active Directory tenant ID. Used to select an account when the cache contains
3131
tokens for multiple identities.
32+
:keyword AuthenticationRecord authentication_record: an authentication record returned by a user credential such as
33+
:class:`DeviceCodeCredential` or :class:`InteractiveBrowserCredential`
3234
"""
3335

3436
async def __aenter__(self):
@@ -74,4 +76,4 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # py
7476
raise CredentialUnavailableError(message=NO_TOKEN.format(account.get("username")))
7577

7678
def _get_auth_client(self, **kwargs: "Any") -> "AadClientBase":
77-
return AadClient(tenant_id="common", client_id=AZURE_CLI_CLIENT_ID, **kwargs)
79+
return AadClient(client_id=AZURE_CLI_CLIENT_ID, **kwargs)

sdk/identity/azure-identity/tests/test_shared_cache_credential.py

+111-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44
# ------------------------------------
55
from azure.core.exceptions import ClientAuthenticationError
66
from azure.core.pipeline.policies import SansIOHTTPPolicy
7-
from azure.identity import CredentialUnavailableError, KnownAuthorities, SharedTokenCacheCredential
7+
from azure.identity import (
8+
AuthenticationRecord,
9+
CredentialUnavailableError,
10+
SharedTokenCacheCredential,
11+
)
812
from azure.identity._constants import EnvironmentVariables
913
from azure.identity._internal.shared_token_cache import (
1014
KNOWN_ALIASES,
@@ -502,6 +506,112 @@ def test_authority_environment_variable():
502506
assert token.token == expected_access_token
503507

504508

509+
def test_authentication_record_empty_cache():
510+
record = AuthenticationRecord("tenant_id", "client_id", "authority", "home_account_id", "username")
511+
transport = Mock(side_effect=Exception("the credential shouldn't send a request"))
512+
credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=TokenCache())
513+
514+
with pytest.raises(CredentialUnavailableError):
515+
credential.get_token("scope")
516+
517+
518+
def test_authentication_record_no_match():
519+
tenant_id = "tenant-id"
520+
client_id = "client-id"
521+
authority = "localhost"
522+
object_id = "object-id"
523+
home_account_id = object_id + "." + tenant_id
524+
username = "me"
525+
record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username)
526+
527+
transport = Mock(side_effect=Exception("the credential shouldn't send a request"))
528+
cache = populated_cache(
529+
get_account_event(
530+
"not-" + username, "not-" + object_id, "different-" + tenant_id, client_id="not-" + client_id,
531+
),
532+
)
533+
credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache)
534+
535+
with pytest.raises(CredentialUnavailableError):
536+
credential.get_token("scope")
537+
538+
539+
def test_authentication_record():
540+
tenant_id = "tenant-id"
541+
client_id = "client-id"
542+
authority = "localhost"
543+
object_id = "object-id"
544+
home_account_id = object_id + "." + tenant_id
545+
username = "me"
546+
record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username)
547+
548+
expected_access_token = "****"
549+
expected_refresh_token = "**"
550+
account = get_account_event(
551+
username, object_id, tenant_id, authority=authority, client_id=client_id, refresh_token=expected_refresh_token
552+
)
553+
cache = populated_cache(account)
554+
555+
transport = validating_transport(
556+
requests=[Request(authority=authority, required_data={"refresh_token": expected_refresh_token})],
557+
responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))],
558+
)
559+
credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache)
560+
561+
token = credential.get_token("scope")
562+
assert token.token == expected_access_token
563+
564+
565+
def test_auth_record_multiple_accounts_for_username():
566+
tenant_id = "tenant-id"
567+
client_id = "client-id"
568+
authority = "localhost"
569+
object_id = "object-id"
570+
home_account_id = object_id + "." + tenant_id
571+
username = "me"
572+
record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username)
573+
574+
expected_access_token = "****"
575+
expected_refresh_token = "**"
576+
expected_account = get_account_event(
577+
username, object_id, tenant_id, authority=authority, client_id=client_id, refresh_token=expected_refresh_token
578+
)
579+
cache = populated_cache(
580+
expected_account,
581+
get_account_event( # this account matches all but the record's tenant
582+
username,
583+
object_id,
584+
"different-" + tenant_id,
585+
authority=authority,
586+
client_id=client_id,
587+
refresh_token="not-" + expected_refresh_token,
588+
),
589+
)
590+
591+
transport = validating_transport(
592+
requests=[Request(authority=authority, required_data={"refresh_token": expected_refresh_token})],
593+
responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))],
594+
)
595+
credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache)
596+
597+
token = credential.get_token("scope")
598+
assert token.token == expected_access_token
599+
600+
601+
def test_authentication_record_authenticating_tenant():
602+
"""when given a record and 'tenant_id', the credential should authenticate in the latter"""
603+
604+
expected_tenant_id = "tenant-id"
605+
record = AuthenticationRecord("not- " + expected_tenant_id, "...", "...", "...", "...")
606+
607+
with patch.object(SharedTokenCacheCredential, "_get_auth_client") as get_auth_client:
608+
SharedTokenCacheCredential(authentication_record=record, _cache=TokenCache(), tenant_id=expected_tenant_id)
609+
610+
assert get_auth_client.call_count == 1
611+
_, kwargs = get_auth_client.call_args
612+
assert kwargs["tenant_id"] == expected_tenant_id
613+
614+
505615
def get_account_event(
506616
username, uid, utid, authority=None, client_id="client-id", refresh_token="refresh-token", scopes=None
507617
):

sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py

+111-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from azure.core.exceptions import ClientAuthenticationError
99
from azure.core.pipeline.policies import SansIOHTTPPolicy
10-
from azure.identity import CredentialUnavailableError, KnownAuthorities
10+
from azure.identity import AuthenticationRecord, CredentialUnavailableError
1111
from azure.identity.aio import SharedTokenCacheCredential
1212
from azure.identity._constants import EnvironmentVariables
1313
from azure.identity._internal.shared_token_cache import (
@@ -566,3 +566,113 @@ async def test_authority_environment_variable():
566566
credential = SharedTokenCacheCredential(transport=transport, _cache=cache)
567567
token = await credential.get_token("scope")
568568
assert token.token == expected_access_token
569+
570+
571+
@pytest.mark.asyncio
572+
async def test_authentication_record_empty_cache():
573+
record = AuthenticationRecord("tenant_id", "client_id", "authority", "home_account_id", "username")
574+
transport = Mock(side_effect=Exception("the credential shouldn't send a request"))
575+
credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=TokenCache())
576+
577+
with pytest.raises(CredentialUnavailableError):
578+
await credential.get_token("scope")
579+
580+
581+
@pytest.mark.asyncio
582+
async def test_authentication_record_no_match():
583+
tenant_id = "tenant-id"
584+
client_id = "client-id"
585+
authority = "localhost"
586+
object_id = "object-id"
587+
home_account_id = object_id + "." + tenant_id
588+
username = "me"
589+
record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username)
590+
591+
transport = Mock(side_effect=Exception("the credential shouldn't send a request"))
592+
cache = populated_cache(
593+
get_account_event(
594+
"not-" + username, "not-" + object_id, "different-" + tenant_id, client_id="not-" + client_id,
595+
),
596+
)
597+
credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache)
598+
599+
with pytest.raises(CredentialUnavailableError):
600+
await credential.get_token("scope")
601+
602+
603+
@pytest.mark.asyncio
604+
async def test_authentication_record():
605+
tenant_id = "tenant-id"
606+
client_id = "client-id"
607+
authority = "localhost"
608+
object_id = "object-id"
609+
home_account_id = object_id + "." + tenant_id
610+
username = "me"
611+
record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username)
612+
613+
expected_access_token = "****"
614+
expected_refresh_token = "**"
615+
account = get_account_event(
616+
username, object_id, tenant_id, authority=authority, client_id=client_id, refresh_token=expected_refresh_token
617+
)
618+
cache = populated_cache(account)
619+
620+
transport = async_validating_transport(
621+
requests=[Request(authority=authority, required_data={"refresh_token": expected_refresh_token})],
622+
responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))],
623+
)
624+
credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache)
625+
626+
token = await credential.get_token("scope")
627+
assert token.token == expected_access_token
628+
629+
630+
@pytest.mark.asyncio
631+
async def test_auth_record_multiple_accounts_for_username():
632+
tenant_id = "tenant-id"
633+
client_id = "client-id"
634+
authority = "localhost"
635+
object_id = "object-id"
636+
home_account_id = object_id + "." + tenant_id
637+
username = "me"
638+
record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username)
639+
640+
expected_access_token = "****"
641+
expected_refresh_token = "**"
642+
expected_account = get_account_event(
643+
username, object_id, tenant_id, authority=authority, client_id=client_id, refresh_token=expected_refresh_token
644+
)
645+
cache = populated_cache(
646+
expected_account,
647+
get_account_event( # this account matches all but the record's tenant
648+
username,
649+
object_id,
650+
"different-" + tenant_id,
651+
authority=authority,
652+
client_id=client_id,
653+
refresh_token="not-" + expected_refresh_token,
654+
),
655+
)
656+
657+
transport = async_validating_transport(
658+
requests=[Request(authority=authority, required_data={"refresh_token": expected_refresh_token})],
659+
responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))],
660+
)
661+
credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache)
662+
663+
token = await credential.get_token("scope")
664+
assert token.token == expected_access_token
665+
666+
667+
def test_authentication_record_authenticating_tenant():
668+
"""when given a record and 'tenant_id', the credential should authenticate in the latter"""
669+
670+
expected_tenant_id = "tenant-id"
671+
record = AuthenticationRecord("not- " + expected_tenant_id, "...", "...", "...", "...")
672+
673+
with patch.object(SharedTokenCacheCredential, "_get_auth_client") as get_auth_client:
674+
SharedTokenCacheCredential(authentication_record=record, _cache=TokenCache(), tenant_id=expected_tenant_id)
675+
676+
assert get_auth_client.call_count == 1
677+
_, kwargs = get_auth_client.call_args
678+
assert kwargs["tenant_id"] == expected_tenant_id

0 commit comments

Comments
 (0)