Skip to content

Commit c33242f

Browse files
committed
CertificateCredential uses AadClient
1 parent 442f164 commit c33242f

File tree

6 files changed

+62
-69
lines changed

6 files changed

+62
-69
lines changed

sdk/identity/azure-identity/azure/identity/_base.py

-52
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,6 @@
33
# Licensed under the MIT License.
44
# ------------------------------------
55
import abc
6-
import binascii
7-
8-
from cryptography import x509
9-
from cryptography.hazmat.primitives import hashes, serialization
10-
from cryptography.hazmat.backends import default_backend
11-
from msal.oauth2cli import JwtSigner
12-
import six
136

147
try:
158
ABC = abc.ABC
@@ -41,48 +34,3 @@ def __init__(self, tenant_id, client_id, secret, **kwargs): # pylint:disable=un
4134
)
4235
self._form_data = {"client_id": client_id, "client_secret": secret, "grant_type": "client_credentials"}
4336
super(ClientSecretCredentialBase, self).__init__()
44-
45-
46-
class CertificateCredentialBase(ABC):
47-
"""Sans I/O base for certificate credentials"""
48-
49-
def __init__(self, tenant_id, client_id, certificate_path, **kwargs): # pylint:disable=unused-argument
50-
# type: (str, str, str, **Any) -> None
51-
if not certificate_path:
52-
raise ValueError(
53-
"'certificate_path' must be the path to a PEM file containing an x509 certificate and its private key"
54-
)
55-
56-
super(CertificateCredentialBase, self).__init__()
57-
58-
password = kwargs.pop("password", None)
59-
if isinstance(password, six.text_type):
60-
password = password.encode(encoding="utf-8")
61-
62-
with open(certificate_path, "rb") as f:
63-
pem_bytes = f.read()
64-
65-
private_key = serialization.load_pem_private_key(pem_bytes, password=password, backend=default_backend())
66-
cert = x509.load_pem_x509_certificate(pem_bytes, default_backend())
67-
fingerprint = cert.fingerprint(hashes.SHA1()) #nosec
68-
69-
self._client = self._get_auth_client(tenant_id, **kwargs)
70-
self._client_id = client_id
71-
self._signer = JwtSigner(private_key, "RS256", sha1_thumbprint=binascii.hexlify(fingerprint))
72-
73-
def _get_request_data(self, *scopes):
74-
assertion = self._signer.sign_assertion(audience=self._client.auth_url, issuer=self._client_id)
75-
if isinstance(assertion, six.binary_type):
76-
assertion = assertion.decode("utf-8")
77-
78-
return {
79-
"client_assertion": assertion,
80-
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
81-
"client_id": self._client_id,
82-
"grant_type": "client_credentials",
83-
"scope": " ".join(scopes),
84-
}
85-
86-
@abc.abstractmethod
87-
def _get_auth_client(self, tenant_id, **kwargs):
88-
pass

sdk/identity/azure-identity/azure/identity/_credentials/certificate.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
# ------------------------------------
55
from typing import TYPE_CHECKING
66

7-
from .._authn_client import AuthnClient
8-
from .._base import CertificateCredentialBase
7+
from .._internal import AadClient, CertificateCredentialBase
98

109
if TYPE_CHECKING:
1110
from azure.core.credentials import AccessToken
@@ -42,11 +41,10 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
4241
if not scopes:
4342
raise ValueError("'get_token' requires at least one scope")
4443

45-
token = self._client.get_cached_token(scopes)
44+
token = self._client.get_cached_access_token(scopes)
4645
if not token:
47-
data = self._get_request_data(*scopes)
48-
token = self._client.request_token(scopes, form_data=data)
46+
token = self._client.obtain_token_by_client_certificate(scopes, self._certificate, **kwargs)
4947
return token
5048

51-
def _get_auth_client(self, tenant_id, **kwargs):
52-
return AuthnClient(tenant=tenant_id, **kwargs)
49+
def _get_auth_client(self, tenant_id, client_id, **kwargs):
50+
return AadClient(tenant_id, client_id, **kwargs)

sdk/identity/azure-identity/azure/identity/_internal/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def get_default_authority():
3434
from .aad_client_base import AadClientBase
3535
from .auth_code_redirect_handler import AuthCodeRedirectServer
3636
from .aadclient_certificate import AadClientCertificate
37+
from .certificate_credential_base import CertificateCredentialBase
3738
from .exception_wrapper import wrap_exceptions
3839
from .msal_credentials import ConfidentialClientCredential, InteractiveCredential, PublicClientCredential
3940
from .msal_transport_adapter import MsalTransportAdapter, MsalTransportResponse
@@ -58,6 +59,7 @@ def _scopes_to_resource(*scopes):
5859
"AadClientBase",
5960
"AuthCodeRedirectServer",
6061
"AadClientCertificate",
62+
"CertificateCredentialBase",
6163
"ConfidentialClientCredential",
6264
"get_default_authority",
6365
"InteractiveCredential",

sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,8 @@ def _get_client_certificate_request(self, scopes, certificate):
127127
}
128128

129129
request = HttpRequest(
130-
"POST", self._token_endpoint, headers={"Content-Type": "application/x-www-form-urlencoded"}
130+
"POST", self._token_endpoint, headers={"Content-Type": "application/x-www-form-urlencoded"}, data=data
131131
)
132-
request.set_formdata_body(data)
133132
return request
134133

135134
def _get_jwt_assertion(self, certificate):
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# ------------------------------------
2+
# Copyright (c) Microsoft Corporation.
3+
# Licensed under the MIT License.
4+
# ------------------------------------
5+
import abc
6+
7+
import six
8+
from azure.identity._internal import AadClientCertificate
9+
10+
try:
11+
ABC = abc.ABC
12+
except AttributeError: # Python 2.7, abc exists, but not ABC
13+
ABC = abc.ABCMeta("ABC", (object,), {"__slots__": ()}) # type: ignore
14+
15+
try:
16+
from typing import TYPE_CHECKING
17+
except ImportError:
18+
TYPE_CHECKING = False
19+
20+
if TYPE_CHECKING:
21+
# pylint:disable=unused-import
22+
from typing import Any
23+
24+
25+
class CertificateCredentialBase(ABC):
26+
def __init__(self, tenant_id, client_id, certificate_path, **kwargs):
27+
# type: (str, str, str, **Any) -> None
28+
if not certificate_path:
29+
raise ValueError(
30+
"'certificate_path' must be the path to a PEM file containing an x509 certificate and its private key"
31+
)
32+
33+
super(CertificateCredentialBase, self).__init__()
34+
35+
password = kwargs.pop("password", None)
36+
if isinstance(password, six.text_type):
37+
password = password.encode(encoding="utf-8")
38+
39+
with open(certificate_path, "rb") as f:
40+
pem_bytes = f.read()
41+
42+
self._certificate = AadClientCertificate(pem_bytes, password=password)
43+
self._client = self._get_auth_client(tenant_id, client_id, **kwargs)
44+
45+
@abc.abstractmethod
46+
def _get_auth_client(self, tenant_id, client_id, **kwargs):
47+
pass

sdk/identity/azure-identity/azure/identity/aio/_credentials/certificate.py

+7-8
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
from typing import TYPE_CHECKING
66

77
from .base import AsyncCredentialBase
8-
from .._authn_client import AsyncAuthnClient
9-
from ..._base import CertificateCredentialBase
8+
from .._internal import AadClient
9+
from ..._internal import CertificateCredentialBase
1010

1111
if TYPE_CHECKING:
1212
from typing import Any
@@ -51,11 +51,10 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # py
5151
if not scopes:
5252
raise ValueError("'get_token' requires at least one scope")
5353

54-
token = self._client.get_cached_token(scopes)
54+
token = self._client.get_cached_access_token(scopes)
5555
if not token:
56-
data = self._get_request_data(*scopes)
57-
token = await self._client.request_token(scopes, form_data=data)
58-
return token # type: ignore
56+
token = await self._client.obtain_token_by_client_certificate(scopes, self._certificate, **kwargs)
57+
return token
5958

60-
def _get_auth_client(self, tenant_id, **kwargs):
61-
return AsyncAuthnClient(tenant=tenant_id, **kwargs)
59+
def _get_auth_client(self, tenant_id, client_id, **kwargs):
60+
return AadClient(tenant_id, client_id, **kwargs)

0 commit comments

Comments
 (0)