Skip to content

Commit cf1c244

Browse files
authored
Prevent DefaultAzureCredential changing authentication method (#10349)
1 parent 84ecb1b commit cf1c244

File tree

7 files changed

+67
-6
lines changed

7 files changed

+67
-6
lines changed

sdk/identity/azure-identity/CHANGELOG.md

+5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
# Release History
22

33
## 1.4.0b2 (Unreleased)
4+
- After an instance of `DefaultAzureCredential` successfully authenticates, it
5+
uses the same authentication method for every subsequent token request. This
6+
makes subsequent requests more efficient, and prevents unexpected changes of
7+
authentication method.
8+
([#10349](https://github.com/Azure/azure-sdk-for-python/pull/10349))
49
- All `get_token` methods consistently require at least one scope argument,
510
raising an error when none is passed. Although `get_token()` may sometimes
611
have succeeded in prior versions, it couldn't do so consistently because its

sdk/identity/azure-identity/azure/identity/_credentials/chained.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
if TYPE_CHECKING:
1515
# pylint:disable=unused-import,ungrouped-imports
16-
from typing import Any
16+
from typing import Any, Optional
1717
from azure.core.credentials import AccessToken, TokenCredential
1818

1919

@@ -44,6 +44,8 @@ def __init__(self, *credentials):
4444
# type: (*TokenCredential) -> None
4545
if not credentials:
4646
raise ValueError("at least one credential is required")
47+
48+
self._successful_credential = None # type: Optional[TokenCredential]
4749
self.credentials = credentials
4850

4951
def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
@@ -58,7 +60,9 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
5860
history = []
5961
for credential in self.credentials:
6062
try:
61-
return credential.get_token(*scopes, **kwargs)
63+
token = credential.get_token(*scopes, **kwargs)
64+
self._successful_credential = credential
65+
return token
6266
except CredentialUnavailableError as ex:
6367
# credential didn't attempt authentication because it lacks required data or state -> continue
6468
history.append((credential, ex.message))

sdk/identity/azure-identity/azure/identity/_credentials/default.py

+3
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,9 @@ def get_token(self, *scopes, **kwargs):
108108
:raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The exception has a
109109
`message` attribute listing each authentication attempt and its error message.
110110
"""
111+
if self._successful_credential:
112+
return self._successful_credential.get_token(*scopes, **kwargs)
113+
111114
try:
112115
return super(DefaultAzureCredential, self).get_token(*scopes, **kwargs)
113116
except ClientAuthenticationError as e:

sdk/identity/azure-identity/azure/identity/aio/_credentials/chained.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from ..._credentials.chained import _get_error_message
1212

1313
if TYPE_CHECKING:
14-
from typing import Any
14+
from typing import Any, Optional
1515
from azure.core.credentials import AccessToken
1616
from azure.core.credentials_async import AsyncTokenCredential
1717

@@ -29,6 +29,8 @@ class ChainedTokenCredential(AsyncCredentialBase):
2929
def __init__(self, *credentials: "AsyncTokenCredential") -> None:
3030
if not credentials:
3131
raise ValueError("at least one credential is required")
32+
33+
self._successful_credential = None # type: Optional[AsyncTokenCredential]
3234
self.credentials = credentials
3335

3436
async def close(self):
@@ -50,7 +52,9 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken":
5052
history = []
5153
for credential in self.credentials:
5254
try:
53-
return await credential.get_token(*scopes, **kwargs)
55+
token = await credential.get_token(*scopes, **kwargs)
56+
self._successful_credential = credential
57+
return token
5458
except CredentialUnavailableError as ex:
5559
# credential didn't attempt authentication because it lacks required data or state -> continue
5660
history.append((credential, ex.message))

sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py

+3
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ async def get_token(self, *scopes: str, **kwargs: "Any"):
9595
:raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The exception has a
9696
`message` attribute listing each authentication attempt and its error message.
9797
"""
98+
if self._successful_credential:
99+
return await self._successful_credential.get_token(*scopes, **kwargs)
100+
98101
try:
99102
return await super(DefaultAzureCredential, self).get_token(*scopes, **kwargs)
100103
except ClientAuthenticationError as e:

sdk/identity/azure-identity/tests/test_default.py

+21
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
# ------------------------------------
55
import os
66

7+
from azure.core.credentials import AccessToken
78
from azure.identity import (
9+
CredentialUnavailableError,
810
DefaultAzureCredential,
911
InteractiveBrowserCredential,
1012
KnownAuthorities,
@@ -24,6 +26,25 @@
2426
from mock import Mock, patch # type: ignore
2527

2628

29+
def test_iterates_only_once():
30+
"""When a credential succeeds, DefaultAzureCredential should use that credential thereafter, ignoring the others"""
31+
32+
unavailable_credential = Mock(get_token=Mock(side_effect=CredentialUnavailableError(message="...")))
33+
successful_credential = Mock(get_token=Mock(return_value=AccessToken("***", 42)))
34+
35+
credential = DefaultAzureCredential()
36+
credential.credentials = [
37+
unavailable_credential,
38+
successful_credential,
39+
Mock(get_token=Mock(side_effect=Exception("iteration didn't stop after a credential provided a token"))),
40+
]
41+
42+
for n in range(3):
43+
credential.get_token("scope")
44+
assert unavailable_credential.get_token.call_count == 1
45+
assert successful_credential.get_token.call_count == n + 1
46+
47+
2748
def test_default_credential_authority():
2849
expected_access_token = "***"
2950
response = mock_response(

sdk/identity/azure-identity/tests/test_default_async.py

+23-2
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,39 @@
77
from unittest.mock import Mock, patch
88
from urllib.parse import urlparse
99

10-
from azure.identity import KnownAuthorities
10+
from azure.core.credentials import AccessToken
11+
from azure.identity import CredentialUnavailableError, KnownAuthorities
1112
from azure.identity.aio import DefaultAzureCredential, SharedTokenCacheCredential
1213
from azure.identity.aio._credentials.azure_cli import AzureCliCredential
1314
from azure.identity.aio._credentials.managed_identity import ManagedIdentityCredential
1415
from azure.identity._constants import EnvironmentVariables
1516
import pytest
1617

1718
from helpers import mock_response, Request
18-
from helpers_async import async_validating_transport, wrap_in_future
19+
from helpers_async import async_validating_transport, get_completed_future, wrap_in_future
1920
from test_shared_cache_credential import build_aad_response, get_account_event, populated_cache
2021

2122

23+
@pytest.mark.asyncio
24+
async def test_iterates_only_once():
25+
"""When a credential succeeds, DefaultAzureCredential should use that credential thereafter, ignoring the others"""
26+
27+
unavailable_credential = Mock(get_token=Mock(side_effect=CredentialUnavailableError(message="...")))
28+
successful_credential = Mock(get_token=Mock(return_value=get_completed_future(AccessToken("***", 42))))
29+
30+
credential = DefaultAzureCredential()
31+
credential.credentials = [
32+
unavailable_credential,
33+
successful_credential,
34+
Mock(get_token=Mock(side_effect=Exception("iteration didn't stop after a credential provided a token"))),
35+
]
36+
37+
for n in range(3):
38+
await credential.get_token("scope")
39+
assert unavailable_credential.get_token.call_count == 1
40+
assert successful_credential.get_token.call_count == n + 1
41+
42+
2243
@pytest.mark.asyncio
2344
async def test_default_credential_authority():
2445
authority = "authority.com"

0 commit comments

Comments
 (0)