|
7 | 7 | from unittest.mock import Mock, patch
|
8 | 8 | from urllib.parse import urlparse
|
9 | 9 |
|
10 |
| -from azure.identity import KnownAuthorities |
| 10 | +from azure.core.credentials import AccessToken |
| 11 | +from azure.identity import CredentialUnavailableError, KnownAuthorities |
11 | 12 | from azure.identity.aio import DefaultAzureCredential, SharedTokenCacheCredential
|
12 | 13 | from azure.identity.aio._credentials.azure_cli import AzureCliCredential
|
13 | 14 | from azure.identity.aio._credentials.managed_identity import ManagedIdentityCredential
|
14 | 15 | from azure.identity._constants import EnvironmentVariables
|
15 | 16 | import pytest
|
16 | 17 |
|
17 | 18 | from helpers import mock_response, Request
|
18 |
| -from helpers_async import async_validating_transport, wrap_in_future |
| 19 | +from helpers_async import async_validating_transport, get_completed_future, wrap_in_future |
19 | 20 | from test_shared_cache_credential import build_aad_response, get_account_event, populated_cache
|
20 | 21 |
|
21 | 22 |
|
| 23 | +@pytest.mark.asyncio |
| 24 | +async def test_iterates_only_once(): |
| 25 | + """When a credential succeeds, DefaultAzureCredential should use that credential thereafter, ignoring the others""" |
| 26 | + |
| 27 | + unavailable_credential = Mock(get_token=Mock(side_effect=CredentialUnavailableError(message="..."))) |
| 28 | + successful_credential = Mock(get_token=Mock(return_value=get_completed_future(AccessToken("***", 42)))) |
| 29 | + |
| 30 | + credential = DefaultAzureCredential() |
| 31 | + credential.credentials = [ |
| 32 | + unavailable_credential, |
| 33 | + successful_credential, |
| 34 | + Mock(get_token=Mock(side_effect=Exception("iteration didn't stop after a credential provided a token"))), |
| 35 | + ] |
| 36 | + |
| 37 | + for n in range(3): |
| 38 | + await credential.get_token("scope") |
| 39 | + assert unavailable_credential.get_token.call_count == 1 |
| 40 | + assert successful_credential.get_token.call_count == n + 1 |
| 41 | + |
| 42 | + |
22 | 43 | @pytest.mark.asyncio
|
23 | 44 | async def test_default_credential_authority():
|
24 | 45 | authority = "authority.com"
|
|
0 commit comments