Skip to content

Commit c8991f7

Browse files
authored
azure-identity architecture board feedback (#5728)
1 parent 7df999b commit c8991f7

18 files changed

+476
-186
lines changed

sdk/core/azure-core/azure/core/credentials.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,15 @@
33
# Licensed under the MIT License. See LICENSE.txt in the project root for
44
# license information.
55
# --------------------------------------------------------------------------
6-
from typing import Iterable # pylint:disable=unused-import
76
from typing_extensions import Protocol
87

98

10-
class SupportsGetToken(Protocol):
9+
class TokenCredential(Protocol):
1110
"""Protocol for classes able to provide OAuth tokens.
1211
1312
:param str scopes: Lets you specify the type of access needed.
1413
"""
1514
# pylint:disable=too-few-public-methods
16-
def get_token(self, scopes):
17-
# type: (Iterable[str]) -> str
15+
def get_token(self, *scopes):
16+
# type: (*str) -> str
1817
pass

sdk/core/azure-core/azure/core/pipeline/policies/credentials.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
if TYPE_CHECKING:
1414
# pylint:disable=unused-import
1515
from typing import Any, Dict, Iterable, Mapping
16-
from azure.core.credentials import SupportsGetToken
16+
from azure.core.credentials import TokenCredential
1717
from azure.core.pipeline import PipelineRequest, PipelineResponse
1818

1919

@@ -22,12 +22,12 @@ class _BearerTokenCredentialPolicyBase(object):
2222
"""Base class for a Bearer Token Credential Policy.
2323
2424
:param credential: The credential.
25-
:type credential: ~azure.core.SupportsGetToken
25+
:type credential: ~azure.core.credentials.TokenCredential
2626
:param str scopes: Lets you specify the type of access needed.
2727
"""
2828

2929
def __init__(self, credential, scopes, **kwargs): # pylint:disable=unused-argument
30-
# type: (SupportsGetToken, Iterable[str], Mapping[str, Any]) -> None
30+
# type: (TokenCredential, Iterable[str], Mapping[str, Any]) -> None
3131
super(_BearerTokenCredentialPolicyBase, self).__init__()
3232
self._scopes = scopes
3333
self._credential = credential
@@ -47,7 +47,7 @@ class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy):
4747
"""Adds a bearer token Authorization header to requests.
4848
4949
:param credential: The credential.
50-
:type credential: ~azure.core.SupportsGetToken
50+
:type credential: ~azure.core.TokenCredential
5151
:param str scopes: Lets you specify the type of access needed.
5252
"""
5353

sdk/core/azure-core/azure/core/pipeline/policies/credentials_async.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class AsyncBearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, AsyncHT
1313
"""Adds a bearer token Authorization header to requests.
1414
1515
:param credential: The credential.
16-
:type credential: ~azure.core.SupportsGetToken
16+
:type credential: ~azure.core.credentials.TokenCredential
1717
:param str scopes: Lets you specify the type of access needed.
1818
"""
1919

sdk/identity/azure-identity/azure/identity/__init__.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,21 @@
1212
TokenCredentialChain,
1313
)
1414

15+
16+
class DefaultAzureCredential(TokenCredentialChain):
17+
"""default credential is environment followed by MSI/IMDS"""
18+
19+
def __init__(self, **kwargs):
20+
super(DefaultAzureCredential, self).__init__(
21+
EnvironmentCredential(**kwargs), ManagedIdentityCredential(**kwargs)
22+
)
23+
24+
1525
__all__ = [
1626
"AuthenticationError",
1727
"CertificateCredential",
1828
"ClientSecretCredential",
29+
"DefaultAzureCredential",
1930
"EnvironmentCredential",
2031
"ManagedIdentityCredential",
2132
"TokenCredentialChain",
@@ -25,6 +36,7 @@
2536
from .aio import (
2637
AsyncCertificateCredential,
2738
AsyncClientSecretCredential,
39+
AsyncDefaultAzureCredential,
2840
AsyncEnvironmentCredential,
2941
AsyncManagedIdentityCredential,
3042
AsyncTokenCredentialChain,
@@ -34,10 +46,11 @@
3446
[
3547
"AsyncCertificateCredential",
3648
"AsyncClientSecretCredential",
49+
"AsyncDefaultAzureCredential",
3750
"AsyncEnvironmentCredential",
3851
"AsyncManagedIdentityCredential",
3952
"AsyncTokenCredentialChain",
4053
]
4154
)
42-
except SyntaxError:
55+
except (ImportError, SyntaxError):
4356
pass

sdk/identity/azure-identity/azure/identity/_authn_client.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# Copyright (c) Microsoft Corporation. All rights reserved.
33
# Licensed under the MIT License. See LICENSE.txt in the project root for
44
# license information.
5-
# --------------------------------------------------------------------------
5+
# -------------------------------------------------------------------------
66
from time import time
77

88
from azure.core import Configuration, HttpRequest
@@ -51,16 +51,24 @@ def _deserialize_and_cache_token(self, response, scopes):
5151
else:
5252
payload = response.http_response.text()
5353
token = payload["access_token"]
54+
55+
# these values are strings in IMDS responses but msal.TokenCache requires they be integers
56+
# https://github.com/AzureAD/microsoft-authentication-library-for-python/pull/55
57+
if payload.get("expires_in"):
58+
payload["expires_in"] = int(payload["expires_in"])
59+
if payload.get("ext_expires_in"):
60+
payload["ext_expires_in"] = int(payload["ext_expires_in"])
61+
5462
self._cache.add({"response": payload, "scope": scopes})
5563
return token
5664
except KeyError:
5765
raise AuthenticationError("Unexpected authentication response: {}".format(payload))
5866
except Exception as ex:
5967
raise AuthenticationError("Authentication failed: {}".format(str(ex)))
6068

61-
def _prepare_request(self, method="POST", form_data=None, params=None):
62-
# type: (Optional[str], Optional[Mapping[str, str]], Optional[Dict[str, str]]) -> HttpRequest
63-
request = HttpRequest(method, self._auth_url)
69+
def _prepare_request(self, method="POST", headers=None, form_data=None, params=None):
70+
# type: (Optional[str], Optional[Mapping[str, str]], Optional[Mapping[str, str]], Optional[Dict[str, str]]) -> HttpRequest
71+
request = HttpRequest(method, self._auth_url, headers=headers)
6472
if form_data:
6573
request.headers["Content-Type"] = "application/x-www-form-urlencoded"
6674
request.set_formdata_body(form_data)
@@ -81,9 +89,9 @@ def __init__(self, auth_url, config=None, policies=None, transport=None, **kwarg
8189
self._pipeline = Pipeline(transport=transport, policies=policies)
8290
super(AuthnClient, self).__init__(auth_url, **kwargs)
8391

84-
def request_token(self, scopes, method="POST", form_data=None, params=None):
85-
# type: (Iterable[str], Optional[str], Optional[Mapping[str, str]], Optional[Dict[str, str]]) -> str
86-
request = self._prepare_request(method, form_data, params)
92+
def request_token(self, scopes, method="POST", headers=None, form_data=None, params=None):
93+
# type: (Iterable[str], Optional[str], Optional[Mapping[str, str]], Optional[Mapping[str, str]], Optional[Dict[str, str]]) -> str
94+
request = self._prepare_request(method, headers=headers, form_data=form_data, params=params)
8795
response = self._pipeline.run(request, stream=False)
8896
token = self._deserialize_and_cache_token(response, scopes)
8997
return token

sdk/identity/azure-identity/azure/identity/_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# --------------------------------------------------------------------------
66
from msal.oauth2cli import JwtSigner
77

8-
from .constants import OAUTH_ENDPOINT
8+
from .constants import Endpoints
99

1010
try:
1111
from typing import TYPE_CHECKING
@@ -38,7 +38,7 @@ def __init__(self, client_id, tenant_id, certificate_path, **kwargs):
3838
raise ValueError("certificate_path must be the path to a PEM-encoded private key file")
3939

4040
super(CertificateCredentialBase, self).__init__()
41-
auth_url = OAUTH_ENDPOINT.format(tenant_id)
41+
auth_url = Endpoints.AAD_OAUTH2_V2_FORMAT.format(tenant_id)
4242

4343
with open(certificate_path) as pem:
4444
private_key = pem.read()
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# ------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License. See LICENSE.txt in the project root for
4+
# license information.
5+
# ------------------------------------------------------------------------
6+
import os
7+
8+
try:
9+
from typing import TYPE_CHECKING
10+
except ImportError:
11+
TYPE_CHECKING = False
12+
13+
if TYPE_CHECKING:
14+
# pylint:disable=unused-import
15+
from typing import Any, Dict, Optional
16+
17+
from azure.core import Configuration
18+
from azure.core.pipeline.policies import ContentDecodePolicy, HeadersPolicy, NetworkTraceLoggingPolicy, RetryPolicy
19+
20+
from ._authn_client import AuthnClient
21+
from .constants import Endpoints, MSI_ENDPOINT, MSI_SECRET
22+
from .exceptions import AuthenticationError
23+
24+
25+
class ImdsCredential:
26+
"""Authenticates with a managed identity via the IMDS endpoint"""
27+
28+
def __init__(self, config=None, **kwargs):
29+
# type: (Optional[Configuration], Dict[str, Any]) -> None
30+
config = config or self.create_config(**kwargs)
31+
policies = [config.header_policy, ContentDecodePolicy(), config.logging_policy, config.retry_policy]
32+
self._client = AuthnClient(Endpoints.IMDS, config, policies, **kwargs)
33+
34+
@staticmethod
35+
def create_config(**kwargs):
36+
# type: (Dict[str, str]) -> Configuration
37+
timeout = kwargs.pop("connection_timeout", 2)
38+
config = Configuration(connection_timeout=timeout, **kwargs)
39+
config.header_policy = HeadersPolicy(base_headers={"Metadata": "true"}, **kwargs)
40+
config.logging_policy = NetworkTraceLoggingPolicy(**kwargs)
41+
retries = kwargs.pop("retry_total", 5)
42+
config.retry_policy = RetryPolicy(
43+
retry_total=retries, retry_on_status_codes=[404, 429] + list(range(500, 600)), **kwargs
44+
)
45+
return config
46+
47+
def get_token(self, *scopes):
48+
# type: (*str) -> str
49+
if len(scopes) != 1:
50+
raise ValueError("this credential supports one scope per request")
51+
token = self._client.get_cached_token(scopes)
52+
if not token:
53+
resource = scopes[0]
54+
if resource.endswith("/.default"):
55+
resource = resource[:-len("/.default")]
56+
token = self._client.request_token(
57+
scopes, method="GET", params={"api-version": "2018-02-01", "resource": resource}
58+
)
59+
return token
60+
61+
62+
class MsiCredential:
63+
"""Authenticates via the MSI endpoint"""
64+
65+
def __init__(self, config=None, **kwargs):
66+
# type: (Optional[Configuration], Dict[str, Any]) -> None
67+
config = config or self.create_config(**kwargs)
68+
policies = [ContentDecodePolicy(), config.retry_policy, config.logging_policy]
69+
endpoint = os.environ.get(MSI_ENDPOINT)
70+
if not endpoint:
71+
raise ValueError("expected environment variable {} has no value".format(MSI_ENDPOINT))
72+
self._client = AuthnClient(endpoint, config, policies, **kwargs)
73+
74+
@staticmethod
75+
def create_config(**kwargs):
76+
# type: (Dict[str, str]) -> Configuration
77+
timeout = kwargs.pop("connection_timeout", 2)
78+
config = Configuration(connection_timeout=timeout, **kwargs)
79+
config.logging_policy = NetworkTraceLoggingPolicy(**kwargs)
80+
retries = kwargs.pop("retry_total", 5)
81+
config.retry_policy = RetryPolicy(
82+
retry_total=retries, retry_on_status_codes=[404, 429] + list(range(500, 600)), **kwargs
83+
)
84+
return config
85+
86+
def get_token(self, *scopes):
87+
# type: (*str) -> str
88+
if len(scopes) != 1:
89+
raise ValueError("this credential supports only one scope per request")
90+
token = self._client.get_cached_token(scopes)
91+
if not token:
92+
secret = os.environ.get(MSI_SECRET)
93+
if not secret:
94+
raise AuthenticationError("{} environment variable has no value".format(MSI_SECRET))
95+
resource = scopes[0]
96+
if resource.endswith("/.default"):
97+
resource = resource[:-len("/.default")]
98+
# TODO: support user-assigned client id
99+
token = self._client.request_token(
100+
scopes,
101+
method="GET",
102+
headers={"secret": secret},
103+
params={"api-version": "2017-09-01", "resource": resource},
104+
)
105+
return token

sdk/identity/azure-identity/azure/identity/aio/__init__.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
# -------------------------------------------------------------------------
1+
# ------------------------------------------------------------------------
22
# Copyright (c) Microsoft Corporation. All rights reserved.
33
# Licensed under the MIT License. See LICENSE.txt in the project root for
44
# license information.
5-
# --------------------------------------------------------------------------
5+
# ------------------------------------------------------------------------
66
from .credentials import (
77
AsyncCertificateCredential,
88
AsyncClientSecretCredential,
@@ -11,9 +11,18 @@
1111
AsyncTokenCredentialChain,
1212
)
1313

14+
15+
class AsyncDefaultAzureCredential(AsyncTokenCredentialChain):
16+
"""default credential is environment followed by MSI/IMDS"""
17+
18+
def __init__(self, **kwargs):
19+
super().__init__(AsyncEnvironmentCredential(**kwargs), AsyncManagedIdentityCredential(**kwargs))
20+
21+
1422
__all__ = [
1523
"AsyncCertificateCredential",
1624
"AsyncClientSecretCredential",
25+
"AsyncDefaultAzureCredential",
1726
"AsyncEnvironmentCredential",
1827
"AsyncManagedIdentityCredential",
1928
"AsyncTokenCredentialChain",

sdk/identity/azure-identity/azure/identity/aio/_authn_client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,11 @@ async def request_token(
3636
self,
3737
scopes: Iterable[str],
3838
method: Optional[str] = "POST",
39+
headers: Optional[Mapping[str, str]] = None,
3940
form_data: Optional[Mapping[str, str]] = None,
4041
params: Optional[Dict[str, str]] = None,
4142
) -> str:
42-
request = self._prepare_request(method, form_data, params)
43+
request = self._prepare_request(method, headers=headers, form_data=form_data, params=params)
4344
response = await self._pipeline.run(request, stream=False)
4445
token = self._deserialize_and_cache_token(response, scopes)
4546
return token

0 commit comments

Comments
 (0)