Skip to content

Prevent DefaultAzureCredential changing authentication method #10349

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 26, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
# Release History

## 1.4.0b2 (Unreleased)
- After an instance of `DefaultAzureCredential` successfully authenticates, it
uses the same authentication method for every subsequent token request. This
makes subsequent requests more efficient, and prevents unexpected changes of
authentication method.
([#10349](https://github.com/Azure/azure-sdk-for-python/pull/10349))
- All `get_token` methods consistently require at least one scope argument,
raising an error when none is passed. Although `get_token()` may sometimes
have succeeded in prior versions, it couldn't do so consistently because its
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any
from typing import Any, Optional
from azure.core.credentials import AccessToken, TokenCredential


Expand Down Expand Up @@ -44,6 +44,8 @@ def __init__(self, *credentials):
# type: (*TokenCredential) -> None
if not credentials:
raise ValueError("at least one credential is required")

self._successful_credential = None # type: Optional[TokenCredential]
self.credentials = credentials

def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
Expand All @@ -58,7 +60,9 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
history = []
for credential in self.credentials:
try:
return credential.get_token(*scopes, **kwargs)
token = credential.get_token(*scopes, **kwargs)
self._successful_credential = credential
return token
except CredentialUnavailableError as ex:
# credential didn't attempt authentication because it lacks required data or state -> continue
history.append((credential, ex.message))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ def get_token(self, *scopes, **kwargs):
:raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The exception has a
`message` attribute listing each authentication attempt and its error message.
"""
if self._successful_credential:
return self._successful_credential.get_token(*scopes, **kwargs)

try:
return super(DefaultAzureCredential, self).get_token(*scopes, **kwargs)
except ClientAuthenticationError as e:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ..._credentials.chained import _get_error_message

if TYPE_CHECKING:
from typing import Any
from typing import Any, Optional
from azure.core.credentials import AccessToken
from azure.core.credentials_async import AsyncTokenCredential

Expand All @@ -29,6 +29,8 @@ class ChainedTokenCredential(AsyncCredentialBase):
def __init__(self, *credentials: "AsyncTokenCredential") -> None:
if not credentials:
raise ValueError("at least one credential is required")

self._successful_credential = None # type: Optional[AsyncTokenCredential]
self.credentials = credentials

async def close(self):
Expand All @@ -50,7 +52,9 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken":
history = []
for credential in self.credentials:
try:
return await credential.get_token(*scopes, **kwargs)
token = await credential.get_token(*scopes, **kwargs)
self._successful_credential = credential
return token
except CredentialUnavailableError as ex:
# credential didn't attempt authentication because it lacks required data or state -> continue
history.append((credential, ex.message))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ async def get_token(self, *scopes: str, **kwargs: "Any"):
:raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The exception has a
`message` attribute listing each authentication attempt and its error message.
"""
if self._successful_credential:
return await self._successful_credential.get_token(*scopes, **kwargs)

try:
return await super(DefaultAzureCredential, self).get_token(*scopes, **kwargs)
except ClientAuthenticationError as e:
Expand Down
21 changes: 21 additions & 0 deletions sdk/identity/azure-identity/tests/test_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
# ------------------------------------
import os

from azure.core.credentials import AccessToken
from azure.identity import (
CredentialUnavailableError,
DefaultAzureCredential,
InteractiveBrowserCredential,
KnownAuthorities,
Expand All @@ -24,6 +26,25 @@
from mock import Mock, patch # type: ignore


def test_iterates_only_once():
"""When a credential succeeds, DefaultAzureCredential should use that credential thereafter, ignoring the others"""

unavailable_credential = Mock(get_token=Mock(side_effect=CredentialUnavailableError(message="...")))
successful_credential = Mock(get_token=Mock(return_value=AccessToken("***", 42)))

credential = DefaultAzureCredential()
credential.credentials = [
unavailable_credential,
successful_credential,
Mock(get_token=Mock(side_effect=Exception("iteration didn't stop after a credential provided a token"))),
]

for n in range(3):
credential.get_token("scope")
assert unavailable_credential.get_token.call_count == 1
assert successful_credential.get_token.call_count == n + 1


def test_default_credential_authority():
expected_access_token = "***"
response = mock_response(
Expand Down
25 changes: 23 additions & 2 deletions sdk/identity/azure-identity/tests/test_default_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,39 @@
from unittest.mock import Mock, patch
from urllib.parse import urlparse

from azure.identity import KnownAuthorities
from azure.core.credentials import AccessToken
from azure.identity import CredentialUnavailableError, KnownAuthorities
from azure.identity.aio import DefaultAzureCredential, SharedTokenCacheCredential
from azure.identity.aio._credentials.azure_cli import AzureCliCredential
from azure.identity.aio._credentials.managed_identity import ManagedIdentityCredential
from azure.identity._constants import EnvironmentVariables
import pytest

from helpers import mock_response, Request
from helpers_async import async_validating_transport, wrap_in_future
from helpers_async import async_validating_transport, get_completed_future, wrap_in_future
from test_shared_cache_credential import build_aad_response, get_account_event, populated_cache


@pytest.mark.asyncio
async def test_iterates_only_once():
"""When a credential succeeds, DefaultAzureCredential should use that credential thereafter, ignoring the others"""

unavailable_credential = Mock(get_token=Mock(side_effect=CredentialUnavailableError(message="...")))
successful_credential = Mock(get_token=Mock(return_value=get_completed_future(AccessToken("***", 42))))

credential = DefaultAzureCredential()
credential.credentials = [
unavailable_credential,
successful_credential,
Mock(get_token=Mock(side_effect=Exception("iteration didn't stop after a credential provided a token"))),
]

for n in range(3):
await credential.get_token("scope")
assert unavailable_credential.get_token.call_count == 1
assert successful_credential.get_token.call_count == n + 1


@pytest.mark.asyncio
async def test_default_credential_authority():
authority = "authority.com"
Expand Down