Skip to content

Commit b17b5ec

Browse files
committed
Merge branch 'master' of https://github.com/Azure/azure-sdk-for-python into fix_annotation_initial_response
* 'master' of https://github.com/Azure/azure-sdk-for-python: Adding digital twins CI configuration. (Azure#11730) Sync eng/common directory with azure-sdk-tools repository (Azure#11692) Reimplement AadClient without msal.oauth2cli (Azure#11466)
2 parents 152f8a8 + 7bdf96e commit b17b5ec

20 files changed

+568
-291
lines changed

eng/common/scripts/git-branch-push.ps1

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,16 @@ do
121121
continue
122122
}
123123

124-
Write-Host "git -c user.name=`"azure-sdk`" -c user.email=`"[email protected]`" commit -am `"$($CommitMsg)`""
125-
git -c user.name="azure-sdk" -c user.email="[email protected]" commit -am "$($CommitMsg)"
124+
Write-Host "git add -A"
125+
git add -A
126+
if ($LASTEXITCODE -ne 0)
127+
{
128+
Write-Error "Unable to git add LASTEXITCODE=$($LASTEXITCODE), see command output above."
129+
continue
130+
}
131+
132+
Write-Host "git -c user.name=`"azure-sdk`" -c user.email=`"[email protected]`" commit -m `"$($CommitMsg)`""
133+
git -c user.name="azure-sdk" -c user.email="[email protected]" commit -m "$($CommitMsg)"
126134
if ($LASTEXITCODE -ne 0)
127135
{
128136
Write-Error "Unable to commit LASTEXITCODE=$($LASTEXITCODE), see command output above."

sdk/digitaltwins/ci.yml

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# DO NOT EDIT THIS FILE
2+
# This file is generated automatically and any changes will be lost.
3+
4+
resources:
5+
repositories:
6+
- repository: azure-sdk-tools
7+
type: github
8+
name: Azure/azure-sdk-tools
9+
endpoint: azure
10+
- repository: azure-sdk-build-tools
11+
type: git
12+
name: internal/azure-sdk-build-tools
13+
14+
trigger:
15+
branches:
16+
include:
17+
- master
18+
- hotfix/*
19+
- release/*
20+
- restapi*
21+
paths:
22+
include:
23+
- sdk/digitaltwins/
24+
25+
pr:
26+
branches:
27+
include:
28+
- master
29+
- feature/*
30+
- hotfix/*
31+
- release/*
32+
- restapi*
33+
paths:
34+
include:
35+
- sdk/digitaltwins/
36+
37+
stages:
38+
- template: ../../eng/pipelines/templates/stages/archetype-sdk-client.yml
39+
parameters:
40+
ServiceDirectory: digitaltwins
41+
Artifacts:
42+
- name: azure_mgmt_digitaltwins
43+
safeName: azuremgmtdigitaltwins

sdk/identity/azure-identity/CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
# Release History
22

33
## 1.4.0b4 (Unreleased)
4+
- `azure.identity.aio.AuthorizationCodeCredential.get_token()` no longer accepts
5+
optional keyword arguments `executor` or `loop`. Prior versions of the method
6+
didn't use these correctly, provoking exceptions, and internal changes in this
7+
version have made them obsolete.
48
- `InteractiveBrowserCredential` raises `CredentialUnavailableError` when it
59
can't start an HTTP server on `localhost`.
610
([#11665](https://github.com/Azure/azure-sdk-for-python/pull/11665))

sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
if TYPE_CHECKING:
1111
# pylint:disable=unused-import,ungrouped-imports
12-
from typing import Any, Iterable, Optional
12+
from typing import Any, Optional, Sequence
1313
from azure.core.credentials import AccessToken
1414

1515

@@ -59,7 +59,7 @@ def get_token(self, *scopes, **kwargs):
5959

6060
if self._authorization_code:
6161
token = self._client.obtain_token_by_authorization_code(
62-
code=self._authorization_code, redirect_uri=self._redirect_uri, scopes=scopes, **kwargs
62+
scopes=scopes, code=self._authorization_code, redirect_uri=self._redirect_uri, **kwargs
6363
)
6464
self._authorization_code = None # auth codes are single-use
6565
return token
@@ -73,9 +73,11 @@ def get_token(self, *scopes, **kwargs):
7373
return token
7474

7575
def _redeem_refresh_token(self, scopes, **kwargs):
76-
# type: (Iterable[str], **Any) -> Optional[AccessToken]
76+
# type: (Sequence[str], **Any) -> Optional[AccessToken]
7777
for refresh_token in self._client.get_cached_refresh_tokens(scopes):
78-
token = self._client.obtain_token_by_refresh_token(refresh_token, scopes, **kwargs)
78+
if "secret" not in refresh_token:
79+
continue
80+
token = self._client.obtain_token_by_refresh_token(scopes, refresh_token["secret"], **kwargs)
7981
if token:
8082
return token
8183
return None

sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
6060

6161
# try each refresh token, returning the first access token acquired
6262
for refresh_token in self._get_refresh_tokens(account):
63-
token = self._client.obtain_token_by_refresh_token(refresh_token, scopes)
63+
token = self._client.obtain_token_by_refresh_token(scopes, refresh_token)
6464
return token
6565

6666
raise CredentialUnavailableError(message=NO_TOKEN.format(account.get("username")))

sdk/identity/azure-identity/azure/identity/_credentials/vscode_credential.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,5 +56,5 @@ def get_token(self, *scopes, **kwargs):
5656
if not self._refresh_token:
5757
raise CredentialUnavailableError(message="No Azure user is logged in to Visual Studio Code.")
5858

59-
token = self._client.obtain_token_by_refresh_token(self._refresh_token, scopes, **kwargs)
59+
token = self._client.obtain_token_by_refresh_token(scopes, self._refresh_token, **kwargs)
6060
return token

sdk/identity/azure-identity/azure/identity/_internal/aad_client.py

Lines changed: 62 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,29 +2,78 @@
22
# Copyright (c) Microsoft Corporation.
33
# Licensed under the MIT License.
44
# ------------------------------------
5-
"""A thin wrapper around MSAL's token cache and OAuth 2 client"""
6-
75
import time
86
from typing import TYPE_CHECKING
97

10-
from azure.core.credentials import AccessToken
8+
from azure.core.configuration import Configuration
9+
from azure.core.pipeline import Pipeline
10+
from azure.core.pipeline.policies import (
11+
NetworkTraceLoggingPolicy,
12+
RetryPolicy,
13+
ProxyPolicy,
14+
UserAgentPolicy,
15+
ContentDecodePolicy,
16+
DistributedTracingPolicy,
17+
HttpLoggingPolicy,
18+
)
1119

1220
from .aad_client_base import AadClientBase
13-
from .msal_transport_adapter import MsalTransportAdapter
14-
from .exception_wrapper import wrap_exceptions
21+
from .user_agent import USER_AGENT
1522

1623
if TYPE_CHECKING:
1724
# pylint:disable=unused-import,ungrouped-imports
18-
from typing import Any, Callable, Iterable
25+
from typing import Any, List, Optional, Sequence, Union
26+
from azure.core.credentials import AccessToken
27+
from azure.core.pipeline.policies import HTTPPolicy, SansIOHTTPPolicy
28+
from azure.core.pipeline.transport import HttpTransport
29+
30+
Policy = Union[HTTPPolicy, SansIOHTTPPolicy]
1931

2032

2133
class AadClient(AadClientBase):
22-
def _get_client_session(self, **kwargs):
23-
return MsalTransportAdapter(**kwargs)
34+
def obtain_token_by_authorization_code(self, scopes, code, redirect_uri, client_secret=None, **kwargs):
35+
# type: (str, str, Sequence[str], Optional[str], **Any) -> AccessToken
36+
request = self._get_auth_code_request(
37+
scopes=scopes, code=code, redirect_uri=redirect_uri, client_secret=client_secret
38+
)
39+
now = int(time.time())
40+
response = self._pipeline.run(request, stream=False, **kwargs)
41+
content = ContentDecodePolicy.deserialize_from_http_generics(response.http_response)
42+
return self._process_response(response=content, scopes=scopes, now=now)
2443

25-
@wrap_exceptions
26-
def _obtain_token(self, scopes, fn, **kwargs): # pylint:disable=unused-argument
27-
# type: (Iterable[str], Callable, **Any) -> AccessToken
44+
def obtain_token_by_refresh_token(self, scopes, refresh_token, **kwargs):
45+
# type: (str, Sequence[str], **Any) -> AccessToken
46+
request = self._get_refresh_token_request(scopes, refresh_token)
2847
now = int(time.time())
29-
response = fn()
30-
return self._process_response(response=response, scopes=scopes, now=now)
48+
response = self._pipeline.run(request, stream=False, **kwargs)
49+
content = ContentDecodePolicy.deserialize_from_http_generics(response.http_response)
50+
return self._process_response(response=content, scopes=scopes, now=now)
51+
52+
# pylint:disable=no-self-use
53+
def _build_pipeline(self, config=None, policies=None, transport=None, **kwargs):
54+
# type: (Optional[Configuration], Optional[List[Policy]], Optional[HttpTransport], **Any) -> Pipeline
55+
config = config or _create_config(**kwargs)
56+
policies = policies or [
57+
config.user_agent_policy,
58+
config.proxy_policy,
59+
config.retry_policy,
60+
config.logging_policy,
61+
DistributedTracingPolicy(**kwargs),
62+
HttpLoggingPolicy(**kwargs),
63+
]
64+
if not transport:
65+
from azure.core.pipeline.transport import RequestsTransport
66+
67+
transport = RequestsTransport(**kwargs)
68+
69+
return Pipeline(transport=transport, policies=policies)
70+
71+
72+
def _create_config(**kwargs):
73+
# type: (**Any) -> Configuration
74+
config = Configuration(**kwargs)
75+
config.logging_policy = NetworkTraceLoggingPolicy(**kwargs)
76+
config.retry_policy = RetryPolicy(**kwargs)
77+
config.proxy_policy = ProxyPolicy(**kwargs)
78+
config.user_agent_policy = UserAgentPolicy(base_user_agent=USER_AGENT, **kwargs)
79+
return config

sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py

Lines changed: 66 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -4,50 +4,48 @@
44
# ------------------------------------
55
import abc
66
import copy
7-
import functools
87
import time
98

10-
try:
11-
from typing import TYPE_CHECKING
12-
except ImportError:
13-
TYPE_CHECKING = False
14-
159
from msal import TokenCache
16-
from msal.oauth2cli.oauth2 import Client
1710

11+
from azure.core.pipeline.transport import HttpRequest
1812
from azure.core.credentials import AccessToken
1913
from azure.core.exceptions import ClientAuthenticationError
2014
from . import get_default_authority, normalize_authority
2115

16+
try:
17+
from typing import TYPE_CHECKING
18+
except ImportError:
19+
TYPE_CHECKING = False
20+
2221
try:
2322
ABC = abc.ABC
2423
except AttributeError: # Python 2.7, abc exists, but not ABC
2524
ABC = abc.ABCMeta("ABC", (object,), {"__slots__": ()}) # type: ignore
2625

2726
if TYPE_CHECKING:
2827
# pylint:disable=unused-import,ungrouped-imports
29-
from typing import Any, Callable, Iterable, Optional
28+
from typing import Any, Optional, Sequence, Union
29+
from azure.core.pipeline import AsyncPipeline, Pipeline
30+
from azure.core.pipeline.policies import AsyncHTTPPolicy, HTTPPolicy, SansIOHTTPPolicy
31+
from azure.core.pipeline.transport import AsyncHttpTransport, HttpTransport
3032

33+
PipelineType = Union[AsyncPipeline, Pipeline]
34+
PolicyType = Union[AsyncHTTPPolicy, HTTPPolicy, SansIOHTTPPolicy]
35+
TransportType = Union[AsyncHttpTransport, HttpTransport]
3136

32-
class AadClientBase(ABC):
33-
"""Sans I/O methods for AAD clients wrapping MSAL's OAuth client"""
3437

35-
def __init__(self, tenant_id, client_id, cache=None, **kwargs):
36-
# type: (str, str, Optional[TokenCache], **Any) -> None
37-
authority = kwargs.pop("authority", None)
38+
class AadClientBase(ABC):
39+
def __init__(self, tenant_id, client_id, authority=None, cache=None, **kwargs):
40+
# type: (str, str, Optional[str], Optional[TokenCache], **Any) -> None
3841
authority = normalize_authority(authority) if authority else get_default_authority()
39-
40-
token_endpoint = "/".join((authority, tenant_id, "oauth2/v2.0/token"))
41-
config = {"token_endpoint": token_endpoint}
42-
42+
self._token_endpoint = "/".join((authority, tenant_id, "oauth2/v2.0/token"))
4343
self._cache = cache or TokenCache()
44-
45-
self._client = Client(server_configuration=config, client_id=client_id)
46-
self._client.session.close()
47-
self._client.session = self._get_client_session(**kwargs)
44+
self._client_id = client_id
45+
self._pipeline = self._build_pipeline(**kwargs)
4846

4947
def get_cached_access_token(self, scopes):
50-
# type: (Iterable[str]) -> Optional[AccessToken]
48+
# type: (Sequence[str]) -> Optional[AccessToken]
5149
tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN, target=list(scopes))
5250
for token in tokens:
5351
expires_on = int(token["expires_on"])
@@ -56,35 +54,30 @@ def get_cached_access_token(self, scopes):
5654
return None
5755

5856
def get_cached_refresh_tokens(self, scopes):
57+
# type: (Sequence[str]) -> Sequence[dict]
5958
"""Assumes all cached refresh tokens belong to the same user"""
6059
return self._cache.find(TokenCache.CredentialType.REFRESH_TOKEN, target=list(scopes))
6160

62-
def obtain_token_by_authorization_code(self, code, redirect_uri, scopes, **kwargs):
63-
# type: (str, str, Iterable[str], **Any) -> AccessToken
64-
fn = functools.partial(
65-
self._client.obtain_token_by_authorization_code, code=code, redirect_uri=redirect_uri, **kwargs
66-
)
67-
return self._obtain_token(scopes, fn, **kwargs)
68-
69-
def obtain_token_by_refresh_token(self, refresh_token, scopes, **kwargs):
70-
# type: (str, Iterable[str], **Any) -> AccessToken
71-
fn = functools.partial(
72-
self._client.obtain_token_by_refresh_token,
73-
token_item=refresh_token,
74-
scope=scopes,
75-
rt_getter=lambda token: token["secret"],
76-
**kwargs
77-
)
78-
return self._obtain_token(scopes, fn, **kwargs)
61+
@abc.abstractmethod
62+
def obtain_token_by_authorization_code(self, scopes, code, redirect_uri, client_secret=None, **kwargs):
63+
pass
64+
65+
@abc.abstractmethod
66+
def obtain_token_by_refresh_token(self, scopes, refresh_token, **kwargs):
67+
pass
68+
69+
@abc.abstractmethod
70+
def _build_pipeline(self, config=None, policies=None, transport=None, **kwargs):
71+
pass
7972

8073
def _process_response(self, response, scopes, now):
81-
# type: (dict, Iterable[str], int) -> AccessToken
74+
# type: (dict, Sequence[str], int) -> AccessToken
8275
_raise_for_error(response)
8376

8477
# TokenCache.add mutates the response. In particular, it removes tokens.
8578
response_copy = copy.deepcopy(response)
8679

87-
self._cache.add(event={"response": response, "scope": scopes}, now=now)
80+
self._cache.add(event={"response": response, "scope": scopes, "client_id": self._client_id}, now=now)
8881
if "expires_on" in response_copy:
8982
expires_on = int(response_copy["expires_on"])
9083
elif "expires_in" in response_copy:
@@ -96,17 +89,41 @@ def _process_response(self, response, scopes, now):
9689
)
9790
return AccessToken(response_copy["access_token"], expires_on)
9891

99-
@abc.abstractmethod
100-
def _get_client_session(self, **kwargs):
101-
pass
102-
103-
@abc.abstractmethod
104-
def _obtain_token(self, scopes, fn, **kwargs):
105-
# type: (Iterable[str], Callable, **Any) -> AccessToken
106-
pass
92+
def _get_auth_code_request(self, scopes, code, redirect_uri, client_secret=None):
93+
# type: (str, str, Sequence[str], Optional[str]) -> HttpRequest
94+
95+
data = {
96+
"client_id": self._client_id,
97+
"code": code,
98+
"grant_type": "authorization_code",
99+
"redirect_uri": redirect_uri,
100+
"scope": " ".join(scopes),
101+
}
102+
if client_secret:
103+
data["client_secret"] = client_secret
104+
105+
request = HttpRequest(
106+
"POST", self._token_endpoint, headers={"Content-Type": "application/x-www-form-urlencoded"}, data=data
107+
)
108+
return request
109+
110+
def _get_refresh_token_request(self, scopes, refresh_token):
111+
# type: (str, Sequence[str]) -> HttpRequest
112+
113+
data = {
114+
"grant_type": "refresh_token",
115+
"refresh_token": refresh_token,
116+
"scope": " ".join(scopes),
117+
"client_id": self._client_id,
118+
}
119+
request = HttpRequest(
120+
"POST", self._token_endpoint, headers={"Content-Type": "application/x-www-form-urlencoded"}, data=data
121+
)
122+
return request
107123

108124

109125
def _scrub_secrets(response):
126+
# type: (dict) -> None
110127
for secret in ("access_token", "refresh_token"):
111128
if secret in response:
112129
response[secret] = "***"

0 commit comments

Comments
 (0)