Skip to content

Commit 102be79

Browse files
authored
Move InteractiveCredential to a new module (#12706)
1 parent 612d7bc commit 102be79

File tree

13 files changed

+214
-195
lines changed

13 files changed

+214
-195
lines changed

sdk/identity/azure-identity/azure/identity/_credentials/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from .azure_cli import AzureCliCredential
1515
from .device_code import DeviceCodeCredential
1616
from .user_password import UsernamePasswordCredential
17-
from .vscode_credential import VSCodeCredential
17+
from .vscode import VSCodeCredential
1818

1919

2020
__all__ = [

sdk/identity/azure-identity/azure/identity/_credentials/default.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from .managed_identity import ManagedIdentityCredential
1414
from .shared_cache import SharedTokenCacheCredential
1515
from .azure_cli import AzureCliCredential
16-
from .vscode_credential import VSCodeCredential
16+
from .vscode import VSCodeCredential
1717

1818

1919
try:

sdk/identity/azure-identity/azure/identity/_internal/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def get_default_authority():
3737
from .certificate_credential_base import CertificateCredentialBase
3838
from .client_secret_credential_base import ClientSecretCredentialBase
3939
from .decorators import wrap_exceptions
40-
from .msal_credentials import InteractiveCredential
40+
from .interactive import InteractiveCredential
4141

4242

4343
def _scopes_to_resource(*scopes):
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
# ------------------------------------
2+
# Copyright (c) Microsoft Corporation.
3+
# Licensed under the MIT License.
4+
# ------------------------------------
5+
"""Base class for credentials using MSAL for interactive user authentication"""
6+
7+
import abc
8+
import base64
9+
import json
10+
import logging
11+
import time
12+
from typing import TYPE_CHECKING
13+
14+
import msal
15+
from six.moves.urllib_parse import urlparse
16+
from azure.core.credentials import AccessToken
17+
from azure.core.exceptions import ClientAuthenticationError
18+
19+
from .msal_credentials import MsalCredential
20+
from .._auth_record import AuthenticationRecord
21+
from .._constants import KnownAuthorities
22+
from .._exceptions import AuthenticationRequiredError, CredentialUnavailableError
23+
from .._internal import wrap_exceptions
24+
25+
if TYPE_CHECKING:
26+
# pylint:disable=ungrouped-imports,unused-import
27+
from typing import Any, Optional
28+
29+
_LOGGER = logging.getLogger(__name__)
30+
31+
_DEFAULT_AUTHENTICATE_SCOPES = {
32+
"https://" + KnownAuthorities.AZURE_CHINA: ("https://management.core.chinacloudapi.cn//.default",),
33+
"https://" + KnownAuthorities.AZURE_GERMANY: ("https://management.core.cloudapi.de//.default",),
34+
"https://" + KnownAuthorities.AZURE_GOVERNMENT: ("https://management.core.usgovcloudapi.net//.default",),
35+
"https://" + KnownAuthorities.AZURE_PUBLIC_CLOUD: ("https://management.core.windows.net//.default",),
36+
}
37+
38+
39+
def _decode_client_info(raw):
40+
"""Taken from msal.oauth2cli.oidc"""
41+
42+
raw += "=" * (-len(raw) % 4)
43+
raw = str(raw) # On Python 2.7, argument of urlsafe_b64decode must be str, not unicode.
44+
return base64.urlsafe_b64decode(raw).decode("utf-8")
45+
46+
47+
def _build_auth_record(response):
48+
"""Build an AuthenticationRecord from the result of an MSAL ClientApplication token request"""
49+
50+
try:
51+
id_token = response["id_token_claims"]
52+
53+
if "client_info" in response:
54+
client_info = json.loads(_decode_client_info(response["client_info"]))
55+
home_account_id = "{uid}.{utid}".format(**client_info)
56+
else:
57+
# MSAL uses the subject claim as home_account_id when the STS doesn't provide client_info
58+
home_account_id = id_token["sub"]
59+
60+
return AuthenticationRecord(
61+
authority=urlparse(id_token["iss"]).netloc, # "iss" is the URL of the issuing tenant
62+
client_id=id_token["aud"],
63+
home_account_id=home_account_id,
64+
tenant_id=id_token["tid"], # tenant which issued the token, not necessarily user's home tenant
65+
username=id_token["preferred_username"],
66+
)
67+
except (KeyError, ValueError):
68+
# surprising: msal.ClientApplication always requests an id token, whose shape shouldn't change
69+
return None
70+
71+
72+
class InteractiveCredential(MsalCredential):
73+
def __init__(self, **kwargs):
74+
self._disable_automatic_authentication = kwargs.pop("disable_automatic_authentication", False)
75+
self._auth_record = kwargs.pop("authentication_record", None) # type: Optional[AuthenticationRecord]
76+
if self._auth_record:
77+
kwargs.pop("client_id", None) # authentication_record overrides client_id argument
78+
tenant_id = kwargs.pop("tenant_id", None) or self._auth_record.tenant_id
79+
super(InteractiveCredential, self).__init__(
80+
client_id=self._auth_record.client_id,
81+
authority=self._auth_record.authority,
82+
tenant_id=tenant_id,
83+
**kwargs
84+
)
85+
else:
86+
super(InteractiveCredential, self).__init__(**kwargs)
87+
88+
def get_token(self, *scopes, **kwargs):
89+
# type: (*str, **Any) -> AccessToken
90+
"""Request an access token for `scopes`.
91+
92+
.. note:: This method is called by Azure SDK clients. It isn't intended for use in application code.
93+
94+
:param str scopes: desired scopes for the access token. This method requires at least one scope.
95+
:rtype: :class:`azure.core.credentials.AccessToken`
96+
:raises CredentialUnavailableError: the credential is unable to attempt authentication because it lacks
97+
required data, state, or platform support
98+
:raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message``
99+
attribute gives a reason.
100+
:raises AuthenticationRequiredError: user interaction is necessary to acquire a token, and the credential is
101+
configured not to begin this automatically. Call :func:`authenticate` to begin interactive authentication.
102+
"""
103+
if not scopes:
104+
message = "'get_token' requires at least one scope"
105+
_LOGGER.warning("%s.get_token failed: %s", self.__class__.__name__, message)
106+
raise ValueError(message)
107+
108+
allow_prompt = kwargs.pop("_allow_prompt", not self._disable_automatic_authentication)
109+
try:
110+
token = self._acquire_token_silent(*scopes, **kwargs)
111+
_LOGGER.info("%s.get_token succeeded", self.__class__.__name__)
112+
return token
113+
except Exception as ex: # pylint:disable=broad-except
114+
if not (isinstance(ex, AuthenticationRequiredError) and allow_prompt):
115+
_LOGGER.warning(
116+
"%s.get_token failed: %s",
117+
self.__class__.__name__,
118+
ex,
119+
exc_info=_LOGGER.isEnabledFor(logging.DEBUG),
120+
)
121+
raise
122+
123+
# silent authentication failed -> authenticate interactively
124+
now = int(time.time())
125+
126+
try:
127+
result = self._request_token(*scopes, **kwargs)
128+
if "access_token" not in result:
129+
message = "Authentication failed: {}".format(result.get("error_description") or result.get("error"))
130+
raise ClientAuthenticationError(message=message)
131+
132+
# this may be the first authentication, or the user may have authenticated a different identity
133+
self._auth_record = _build_auth_record(result)
134+
except Exception as ex: # pylint:disable=broad-except
135+
_LOGGER.warning(
136+
"%s.get_token failed: %s", self.__class__.__name__, ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG),
137+
)
138+
raise
139+
140+
_LOGGER.info("%s.get_token succeeded", self.__class__.__name__)
141+
return AccessToken(result["access_token"], now + int(result["expires_in"]))
142+
143+
def authenticate(self, **kwargs):
144+
# type: (**Any) -> AuthenticationRecord
145+
"""Interactively authenticate a user.
146+
147+
:keyword Iterable[str] scopes: scopes to request during authentication, such as those provided by
148+
:func:`AuthenticationRequiredError.scopes`. If provided, successful authentication will cache an access token
149+
for these scopes.
150+
:rtype: ~azure.identity.AuthenticationRecord
151+
:raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message``
152+
attribute gives a reason.
153+
"""
154+
155+
scopes = kwargs.pop("scopes", None)
156+
if not scopes:
157+
if self._authority not in _DEFAULT_AUTHENTICATE_SCOPES:
158+
# the credential is configured to use a cloud whose ARM scope we can't determine
159+
raise CredentialUnavailableError(
160+
message="Authenticating in this environment requires a value for the 'scopes' keyword argument."
161+
)
162+
163+
scopes = _DEFAULT_AUTHENTICATE_SCOPES[self._authority]
164+
165+
_ = self.get_token(*scopes, _allow_prompt=True, **kwargs)
166+
return self._auth_record # type: ignore
167+
168+
@wrap_exceptions
169+
def _acquire_token_silent(self, *scopes, **kwargs):
170+
# type: (*str, **Any) -> AccessToken
171+
result = None
172+
if self._auth_record:
173+
app = self._get_app()
174+
for account in app.get_accounts(username=self._auth_record.username):
175+
if account.get("home_account_id") != self._auth_record.home_account_id:
176+
continue
177+
178+
now = int(time.time())
179+
result = app.acquire_token_silent_with_error(list(scopes), account=account, **kwargs)
180+
if result and "access_token" in result and "expires_in" in result:
181+
return AccessToken(result["access_token"], now + int(result["expires_in"]))
182+
183+
# if we get this far, result is either None or the content of an AAD error response
184+
if result:
185+
details = result.get("error_description") or result.get("error")
186+
raise AuthenticationRequiredError(scopes, error_details=details)
187+
raise AuthenticationRequiredError(scopes)
188+
189+
def _get_app(self):
190+
# type: () -> msal.PublicClientApplication
191+
if not self._msal_app:
192+
self._msal_app = self._create_app(msal.PublicClientApplication)
193+
return self._msal_app
194+
195+
@abc.abstractmethod
196+
def _request_token(self, *scopes, **kwargs):
197+
pass

0 commit comments

Comments
 (0)