Skip to content

Commit 2614fa1

Browse files
authored
Refactor SharedTokenCacheCredential exception handling (Azure#11962)
1 parent dc406b4 commit 2614fa1

File tree

6 files changed

+30
-21
lines changed

6 files changed

+30
-21
lines changed

sdk/identity/azure-identity/CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
# Release History
22

33
## 1.4.0b5 (Unreleased)
4-
4+
- `SharedTokenCacheCredential.get_token` raises `ValueError` instead of
5+
`ClientAuthenticationError` when called with no scopes.
6+
([#11553](https://github.com/Azure/azure-sdk-for-python/issues/11553))
57

68
## 1.4.0b4 (2020-06-09)
79
- `ManagedIdentityCredential` can configure a user-assigned identity using any

sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# ------------------------------------
55
from .. import CredentialUnavailableError
66
from .._constants import AZURE_CLI_CLIENT_ID
7-
from .._internal import AadClient, wrap_exceptions
7+
from .._internal import AadClient
88
from .._internal.shared_token_cache import NO_TOKEN, SharedTokenCacheBase
99

1010
try:
@@ -36,7 +36,6 @@ class SharedTokenCacheCredential(SharedTokenCacheBase):
3636
is unavailable. Defaults to False.
3737
"""
3838

39-
@wrap_exceptions
4039
def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
4140
# type (*str, **Any) -> AccessToken
4241
"""Get an access token for `scopes` from the shared cache.

sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@
66
import time
77

88
from msal import TokenCache
9+
import six
910
from six.moves.urllib_parse import urlparse
1011

1112
from azure.core.credentials import AccessToken
1213
from .. import CredentialUnavailableError
1314
from .._constants import KnownAuthorities
14-
from .._internal import get_default_authority, normalize_authority
15+
from .._internal import get_default_authority, normalize_authority, wrap_exceptions
1516
from .._internal.persistent_cache import load_user_cache
1617

1718
try:
@@ -158,6 +159,7 @@ def _get_accounts_having_matching_refresh_tokens(self):
158159
accounts[account["home_account_id"]] = account
159160
return accounts.values()
160161

162+
@wrap_exceptions
161163
def _get_account(self, username=None, tenant_id=None):
162164
# type: (Optional[str], Optional[str]) -> CacheItem
163165
"""returns exactly one account which has a refresh token and matches username and/or tenant_id"""
@@ -198,26 +200,34 @@ def _get_cached_access_token(self, scopes, account):
198200
if "home_account_id" not in account:
199201
return None
200202

201-
cache_entries = self._cache.find(
202-
TokenCache.CredentialType.ACCESS_TOKEN,
203-
target=list(scopes),
204-
query={"home_account_id": account["home_account_id"]},
205-
)
203+
try:
204+
cache_entries = self._cache.find(
205+
TokenCache.CredentialType.ACCESS_TOKEN,
206+
target=list(scopes),
207+
query={"home_account_id": account["home_account_id"]},
208+
)
209+
for token in cache_entries:
210+
expires_on = int(token["expires_on"])
211+
if expires_on - 300 > int(time.time()):
212+
return AccessToken(token["secret"], expires_on)
213+
except Exception as ex: # pylint:disable=broad-except
214+
message = "Error accessing cached data: {}".format(ex)
215+
six.raise_from(CredentialUnavailableError(message=message), ex)
206216

207-
for token in cache_entries:
208-
expires_on = int(token["expires_on"])
209-
if expires_on - 300 > int(time.time()):
210-
return AccessToken(token["secret"], expires_on)
211217
return None
212218

213219
def _get_refresh_tokens(self, account):
214220
if "home_account_id" not in account:
215221
return None
216222

217-
cache_entries = self._cache.find(
218-
TokenCache.CredentialType.REFRESH_TOKEN, query={"home_account_id": account["home_account_id"]}
219-
)
220-
return (token["secret"] for token in cache_entries if "secret" in token)
223+
try:
224+
cache_entries = self._cache.find(
225+
TokenCache.CredentialType.REFRESH_TOKEN, query={"home_account_id": account["home_account_id"]}
226+
)
227+
return [token["secret"] for token in cache_entries if "secret" in token]
228+
except Exception as ex: # pylint:disable=broad-except
229+
message = "Error accessing cached data: {}".format(ex)
230+
six.raise_from(CredentialUnavailableError(message=message), ex)
221231

222232
@staticmethod
223233
def supported():

sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from ..._constants import AZURE_CLI_CLIENT_ID
99
from ..._internal.shared_token_cache import NO_TOKEN, SharedTokenCacheBase
1010
from .._internal.aad_client import AadClient
11-
from .._internal.exception_wrapper import wrap_exceptions
1211
from .base import AsyncCredentialBase
1312

1413
if TYPE_CHECKING:
@@ -46,7 +45,6 @@ async def close(self):
4645
if self._client:
4746
await self._client.__aexit__()
4847

49-
@wrap_exceptions
5048
async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # pylint:disable=unused-argument
5149
"""Get an access token for `scopes` from the shared cache.
5250

sdk/identity/azure-identity/tests/test_shared_cache_credential.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def test_no_scopes():
3535
"""The credential should raise when get_token is called with no scopes"""
3636

3737
credential = SharedTokenCacheCredential(_cache=TokenCache())
38-
with pytest.raises(ClientAuthenticationError):
38+
with pytest.raises(ValueError):
3939
credential.get_token()
4040

4141

sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ async def test_no_scopes():
3131
"""The credential should raise when get_token is called with no scopes"""
3232

3333
credential = SharedTokenCacheCredential(_cache=TokenCache())
34-
with pytest.raises(ClientAuthenticationError):
34+
with pytest.raises(ValueError):
3535
await credential.get_token()
3636

3737

0 commit comments

Comments
 (0)