diff --git a/sdk/identity/azure-identity/CHANGELOG.md b/sdk/identity/azure-identity/CHANGELOG.md index 0790a1560121..d59d8322689f 100644 --- a/sdk/identity/azure-identity/CHANGELOG.md +++ b/sdk/identity/azure-identity/CHANGELOG.md @@ -1,6 +1,11 @@ # Release History ## 1.4.0b2 (Unreleased) +- After an instance of `DefaultAzureCredential` successfully authenticates, it +uses the same authentication method for every subsequent token request. This +makes subsequent requests more efficient, and prevents unexpected changes of +authentication method. +([#10349](https://github.com/Azure/azure-sdk-for-python/pull/10349)) - All `get_token` methods consistently require at least one scope argument, raising an error when none is passed. Although `get_token()` may sometimes have succeeded in prior versions, it couldn't do so consistently because its diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/chained.py b/sdk/identity/azure-identity/azure/identity/_credentials/chained.py index b74b7c9619e8..8ec61663fcfc 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/chained.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/chained.py @@ -13,7 +13,7 @@ if TYPE_CHECKING: # pylint:disable=unused-import,ungrouped-imports - from typing import Any + from typing import Any, Optional from azure.core.credentials import AccessToken, TokenCredential @@ -44,6 +44,8 @@ def __init__(self, *credentials): # type: (*TokenCredential) -> None if not credentials: raise ValueError("at least one credential is required") + + self._successful_credential = None # type: Optional[TokenCredential] self.credentials = credentials def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument @@ -58,7 +60,9 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument history = [] for credential in self.credentials: try: - return credential.get_token(*scopes, **kwargs) + token = credential.get_token(*scopes, **kwargs) + self._successful_credential = credential + return token except CredentialUnavailableError as ex: # credential didn't attempt authentication because it lacks required data or state -> continue history.append((credential, ex.message)) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/default.py b/sdk/identity/azure-identity/azure/identity/_credentials/default.py index 08657a020125..061c26c6579b 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/default.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/default.py @@ -108,6 +108,9 @@ def get_token(self, *scopes, **kwargs): :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The exception has a `message` attribute listing each authentication attempt and its error message. """ + if self._successful_credential: + return self._successful_credential.get_token(*scopes, **kwargs) + try: return super(DefaultAzureCredential, self).get_token(*scopes, **kwargs) except ClientAuthenticationError as e: diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/chained.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/chained.py index 23ba1ca2f326..72d68376f2cd 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/chained.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/chained.py @@ -11,7 +11,7 @@ from ..._credentials.chained import _get_error_message if TYPE_CHECKING: - from typing import Any + from typing import Any, Optional from azure.core.credentials import AccessToken from azure.core.credentials_async import AsyncTokenCredential @@ -29,6 +29,8 @@ class ChainedTokenCredential(AsyncCredentialBase): def __init__(self, *credentials: "AsyncTokenCredential") -> None: if not credentials: raise ValueError("at least one credential is required") + + self._successful_credential = None # type: Optional[AsyncTokenCredential] self.credentials = credentials async def close(self): @@ -50,7 +52,9 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": history = [] for credential in self.credentials: try: - return await credential.get_token(*scopes, **kwargs) + token = await credential.get_token(*scopes, **kwargs) + self._successful_credential = credential + return token except CredentialUnavailableError as ex: # credential didn't attempt authentication because it lacks required data or state -> continue history.append((credential, ex.message)) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py index 94b03aba0cb7..16b4ae608bcc 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py @@ -95,6 +95,9 @@ async def get_token(self, *scopes: str, **kwargs: "Any"): :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The exception has a `message` attribute listing each authentication attempt and its error message. """ + if self._successful_credential: + return await self._successful_credential.get_token(*scopes, **kwargs) + try: return await super(DefaultAzureCredential, self).get_token(*scopes, **kwargs) except ClientAuthenticationError as e: diff --git a/sdk/identity/azure-identity/tests/test_default.py b/sdk/identity/azure-identity/tests/test_default.py index e7b1b98353a0..b3ad7542b996 100644 --- a/sdk/identity/azure-identity/tests/test_default.py +++ b/sdk/identity/azure-identity/tests/test_default.py @@ -4,7 +4,9 @@ # ------------------------------------ import os +from azure.core.credentials import AccessToken from azure.identity import ( + CredentialUnavailableError, DefaultAzureCredential, InteractiveBrowserCredential, KnownAuthorities, @@ -24,6 +26,25 @@ from mock import Mock, patch # type: ignore +def test_iterates_only_once(): + """When a credential succeeds, DefaultAzureCredential should use that credential thereafter, ignoring the others""" + + unavailable_credential = Mock(get_token=Mock(side_effect=CredentialUnavailableError(message="..."))) + successful_credential = Mock(get_token=Mock(return_value=AccessToken("***", 42))) + + credential = DefaultAzureCredential() + credential.credentials = [ + unavailable_credential, + successful_credential, + Mock(get_token=Mock(side_effect=Exception("iteration didn't stop after a credential provided a token"))), + ] + + for n in range(3): + credential.get_token("scope") + assert unavailable_credential.get_token.call_count == 1 + assert successful_credential.get_token.call_count == n + 1 + + def test_default_credential_authority(): expected_access_token = "***" response = mock_response( diff --git a/sdk/identity/azure-identity/tests/test_default_async.py b/sdk/identity/azure-identity/tests/test_default_async.py index 6443e08c372a..27fa7744db72 100644 --- a/sdk/identity/azure-identity/tests/test_default_async.py +++ b/sdk/identity/azure-identity/tests/test_default_async.py @@ -7,7 +7,8 @@ from unittest.mock import Mock, patch from urllib.parse import urlparse -from azure.identity import KnownAuthorities +from azure.core.credentials import AccessToken +from azure.identity import CredentialUnavailableError, KnownAuthorities from azure.identity.aio import DefaultAzureCredential, SharedTokenCacheCredential from azure.identity.aio._credentials.azure_cli import AzureCliCredential from azure.identity.aio._credentials.managed_identity import ManagedIdentityCredential @@ -15,10 +16,30 @@ import pytest from helpers import mock_response, Request -from helpers_async import async_validating_transport, wrap_in_future +from helpers_async import async_validating_transport, get_completed_future, wrap_in_future from test_shared_cache_credential import build_aad_response, get_account_event, populated_cache +@pytest.mark.asyncio +async def test_iterates_only_once(): + """When a credential succeeds, DefaultAzureCredential should use that credential thereafter, ignoring the others""" + + unavailable_credential = Mock(get_token=Mock(side_effect=CredentialUnavailableError(message="..."))) + successful_credential = Mock(get_token=Mock(return_value=get_completed_future(AccessToken("***", 42)))) + + credential = DefaultAzureCredential() + credential.credentials = [ + unavailable_credential, + successful_credential, + Mock(get_token=Mock(side_effect=Exception("iteration didn't stop after a credential provided a token"))), + ] + + for n in range(3): + await credential.get_token("scope") + assert unavailable_credential.get_token.call_count == 1 + assert successful_credential.get_token.call_count == n + 1 + + @pytest.mark.asyncio async def test_default_credential_authority(): authority = "authority.com"