|
| 1 | +# ------------------------------------ |
| 2 | +# Copyright (c) Microsoft Corporation. |
| 3 | +# Licensed under the MIT License. |
| 4 | +# ------------------------------------ |
| 5 | +"""Base class for credentials using MSAL for interactive user authentication""" |
| 6 | + |
| 7 | +import abc |
| 8 | +import base64 |
| 9 | +import json |
| 10 | +import logging |
| 11 | +import time |
| 12 | +from typing import TYPE_CHECKING |
| 13 | + |
| 14 | +import msal |
| 15 | +from six.moves.urllib_parse import urlparse |
| 16 | +from azure.core.credentials import AccessToken |
| 17 | +from azure.core.exceptions import ClientAuthenticationError |
| 18 | + |
| 19 | +from .msal_credentials import MsalCredential |
| 20 | +from .._auth_record import AuthenticationRecord |
| 21 | +from .._constants import KnownAuthorities |
| 22 | +from .._exceptions import AuthenticationRequiredError, CredentialUnavailableError |
| 23 | +from .._internal import wrap_exceptions |
| 24 | + |
| 25 | +if TYPE_CHECKING: |
| 26 | + # pylint:disable=ungrouped-imports,unused-import |
| 27 | + from typing import Any, Optional |
| 28 | + |
| 29 | +_LOGGER = logging.getLogger(__name__) |
| 30 | + |
| 31 | +_DEFAULT_AUTHENTICATE_SCOPES = { |
| 32 | + "https://" + KnownAuthorities.AZURE_CHINA: ("https://management.core.chinacloudapi.cn//.default",), |
| 33 | + "https://" + KnownAuthorities.AZURE_GERMANY: ("https://management.core.cloudapi.de//.default",), |
| 34 | + "https://" + KnownAuthorities.AZURE_GOVERNMENT: ("https://management.core.usgovcloudapi.net//.default",), |
| 35 | + "https://" + KnownAuthorities.AZURE_PUBLIC_CLOUD: ("https://management.core.windows.net//.default",), |
| 36 | +} |
| 37 | + |
| 38 | + |
| 39 | +def _decode_client_info(raw): |
| 40 | + """Taken from msal.oauth2cli.oidc""" |
| 41 | + |
| 42 | + raw += "=" * (-len(raw) % 4) |
| 43 | + raw = str(raw) # On Python 2.7, argument of urlsafe_b64decode must be str, not unicode. |
| 44 | + return base64.urlsafe_b64decode(raw).decode("utf-8") |
| 45 | + |
| 46 | + |
| 47 | +def _build_auth_record(response): |
| 48 | + """Build an AuthenticationRecord from the result of an MSAL ClientApplication token request""" |
| 49 | + |
| 50 | + try: |
| 51 | + id_token = response["id_token_claims"] |
| 52 | + |
| 53 | + if "client_info" in response: |
| 54 | + client_info = json.loads(_decode_client_info(response["client_info"])) |
| 55 | + home_account_id = "{uid}.{utid}".format(**client_info) |
| 56 | + else: |
| 57 | + # MSAL uses the subject claim as home_account_id when the STS doesn't provide client_info |
| 58 | + home_account_id = id_token["sub"] |
| 59 | + |
| 60 | + return AuthenticationRecord( |
| 61 | + authority=urlparse(id_token["iss"]).netloc, # "iss" is the URL of the issuing tenant |
| 62 | + client_id=id_token["aud"], |
| 63 | + home_account_id=home_account_id, |
| 64 | + tenant_id=id_token["tid"], # tenant which issued the token, not necessarily user's home tenant |
| 65 | + username=id_token["preferred_username"], |
| 66 | + ) |
| 67 | + except (KeyError, ValueError): |
| 68 | + # surprising: msal.ClientApplication always requests an id token, whose shape shouldn't change |
| 69 | + return None |
| 70 | + |
| 71 | + |
| 72 | +class InteractiveCredential(MsalCredential): |
| 73 | + def __init__(self, **kwargs): |
| 74 | + self._disable_automatic_authentication = kwargs.pop("disable_automatic_authentication", False) |
| 75 | + self._auth_record = kwargs.pop("authentication_record", None) # type: Optional[AuthenticationRecord] |
| 76 | + if self._auth_record: |
| 77 | + kwargs.pop("client_id", None) # authentication_record overrides client_id argument |
| 78 | + tenant_id = kwargs.pop("tenant_id", None) or self._auth_record.tenant_id |
| 79 | + super(InteractiveCredential, self).__init__( |
| 80 | + client_id=self._auth_record.client_id, |
| 81 | + authority=self._auth_record.authority, |
| 82 | + tenant_id=tenant_id, |
| 83 | + **kwargs |
| 84 | + ) |
| 85 | + else: |
| 86 | + super(InteractiveCredential, self).__init__(**kwargs) |
| 87 | + |
| 88 | + def get_token(self, *scopes, **kwargs): |
| 89 | + # type: (*str, **Any) -> AccessToken |
| 90 | + """Request an access token for `scopes`. |
| 91 | +
|
| 92 | + .. note:: This method is called by Azure SDK clients. It isn't intended for use in application code. |
| 93 | +
|
| 94 | + :param str scopes: desired scopes for the access token. This method requires at least one scope. |
| 95 | + :rtype: :class:`azure.core.credentials.AccessToken` |
| 96 | + :raises CredentialUnavailableError: the credential is unable to attempt authentication because it lacks |
| 97 | + required data, state, or platform support |
| 98 | + :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` |
| 99 | + attribute gives a reason. |
| 100 | + :raises AuthenticationRequiredError: user interaction is necessary to acquire a token, and the credential is |
| 101 | + configured not to begin this automatically. Call :func:`authenticate` to begin interactive authentication. |
| 102 | + """ |
| 103 | + if not scopes: |
| 104 | + message = "'get_token' requires at least one scope" |
| 105 | + _LOGGER.warning("%s.get_token failed: %s", self.__class__.__name__, message) |
| 106 | + raise ValueError(message) |
| 107 | + |
| 108 | + allow_prompt = kwargs.pop("_allow_prompt", not self._disable_automatic_authentication) |
| 109 | + try: |
| 110 | + token = self._acquire_token_silent(*scopes, **kwargs) |
| 111 | + _LOGGER.info("%s.get_token succeeded", self.__class__.__name__) |
| 112 | + return token |
| 113 | + except Exception as ex: # pylint:disable=broad-except |
| 114 | + if not (isinstance(ex, AuthenticationRequiredError) and allow_prompt): |
| 115 | + _LOGGER.warning( |
| 116 | + "%s.get_token failed: %s", |
| 117 | + self.__class__.__name__, |
| 118 | + ex, |
| 119 | + exc_info=_LOGGER.isEnabledFor(logging.DEBUG), |
| 120 | + ) |
| 121 | + raise |
| 122 | + |
| 123 | + # silent authentication failed -> authenticate interactively |
| 124 | + now = int(time.time()) |
| 125 | + |
| 126 | + try: |
| 127 | + result = self._request_token(*scopes, **kwargs) |
| 128 | + if "access_token" not in result: |
| 129 | + message = "Authentication failed: {}".format(result.get("error_description") or result.get("error")) |
| 130 | + raise ClientAuthenticationError(message=message) |
| 131 | + |
| 132 | + # this may be the first authentication, or the user may have authenticated a different identity |
| 133 | + self._auth_record = _build_auth_record(result) |
| 134 | + except Exception as ex: # pylint:disable=broad-except |
| 135 | + _LOGGER.warning( |
| 136 | + "%s.get_token failed: %s", self.__class__.__name__, ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG), |
| 137 | + ) |
| 138 | + raise |
| 139 | + |
| 140 | + _LOGGER.info("%s.get_token succeeded", self.__class__.__name__) |
| 141 | + return AccessToken(result["access_token"], now + int(result["expires_in"])) |
| 142 | + |
| 143 | + def authenticate(self, **kwargs): |
| 144 | + # type: (**Any) -> AuthenticationRecord |
| 145 | + """Interactively authenticate a user. |
| 146 | +
|
| 147 | + :keyword Iterable[str] scopes: scopes to request during authentication, such as those provided by |
| 148 | + :func:`AuthenticationRequiredError.scopes`. If provided, successful authentication will cache an access token |
| 149 | + for these scopes. |
| 150 | + :rtype: ~azure.identity.AuthenticationRecord |
| 151 | + :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` |
| 152 | + attribute gives a reason. |
| 153 | + """ |
| 154 | + |
| 155 | + scopes = kwargs.pop("scopes", None) |
| 156 | + if not scopes: |
| 157 | + if self._authority not in _DEFAULT_AUTHENTICATE_SCOPES: |
| 158 | + # the credential is configured to use a cloud whose ARM scope we can't determine |
| 159 | + raise CredentialUnavailableError( |
| 160 | + message="Authenticating in this environment requires a value for the 'scopes' keyword argument." |
| 161 | + ) |
| 162 | + |
| 163 | + scopes = _DEFAULT_AUTHENTICATE_SCOPES[self._authority] |
| 164 | + |
| 165 | + _ = self.get_token(*scopes, _allow_prompt=True, **kwargs) |
| 166 | + return self._auth_record # type: ignore |
| 167 | + |
| 168 | + @wrap_exceptions |
| 169 | + def _acquire_token_silent(self, *scopes, **kwargs): |
| 170 | + # type: (*str, **Any) -> AccessToken |
| 171 | + result = None |
| 172 | + if self._auth_record: |
| 173 | + app = self._get_app() |
| 174 | + for account in app.get_accounts(username=self._auth_record.username): |
| 175 | + if account.get("home_account_id") != self._auth_record.home_account_id: |
| 176 | + continue |
| 177 | + |
| 178 | + now = int(time.time()) |
| 179 | + result = app.acquire_token_silent_with_error(list(scopes), account=account, **kwargs) |
| 180 | + if result and "access_token" in result and "expires_in" in result: |
| 181 | + return AccessToken(result["access_token"], now + int(result["expires_in"])) |
| 182 | + |
| 183 | + # if we get this far, result is either None or the content of an AAD error response |
| 184 | + if result: |
| 185 | + details = result.get("error_description") or result.get("error") |
| 186 | + raise AuthenticationRequiredError(scopes, error_details=details) |
| 187 | + raise AuthenticationRequiredError(scopes) |
| 188 | + |
| 189 | + def _get_app(self): |
| 190 | + # type: () -> msal.PublicClientApplication |
| 191 | + if not self._msal_app: |
| 192 | + self._msal_app = self._create_app(msal.PublicClientApplication) |
| 193 | + return self._msal_app |
| 194 | + |
| 195 | + @abc.abstractmethod |
| 196 | + def _request_token(self, *scopes, **kwargs): |
| 197 | + pass |
0 commit comments