4
4
# ------------------------------------
5
5
import abc
6
6
import copy
7
- import functools
8
7
import time
9
8
10
- try :
11
- from typing import TYPE_CHECKING
12
- except ImportError :
13
- TYPE_CHECKING = False
14
-
15
9
from msal import TokenCache
16
- from msal .oauth2cli .oauth2 import Client
17
10
11
+ from azure .core .pipeline .transport import HttpRequest
18
12
from azure .core .credentials import AccessToken
19
13
from azure .core .exceptions import ClientAuthenticationError
20
14
from . import get_default_authority , normalize_authority
21
15
16
+ try :
17
+ from typing import TYPE_CHECKING
18
+ except ImportError :
19
+ TYPE_CHECKING = False
20
+
22
21
try :
23
22
ABC = abc .ABC
24
23
except AttributeError : # Python 2.7, abc exists, but not ABC
25
24
ABC = abc .ABCMeta ("ABC" , (object ,), {"__slots__" : ()}) # type: ignore
26
25
27
26
if TYPE_CHECKING :
28
27
# pylint:disable=unused-import,ungrouped-imports
29
- from typing import Any , Callable , Iterable , Optional
28
+ from typing import Any , Optional , Sequence , Union
29
+ from azure .core .pipeline import AsyncPipeline , Pipeline
30
+ from azure .core .pipeline .policies import AsyncHTTPPolicy , HTTPPolicy , SansIOHTTPPolicy
31
+ from azure .core .pipeline .transport import AsyncHttpTransport , HttpTransport
30
32
33
+ PipelineType = Union [AsyncPipeline , Pipeline ]
34
+ PolicyType = Union [AsyncHTTPPolicy , HTTPPolicy , SansIOHTTPPolicy ]
35
+ TransportType = Union [AsyncHttpTransport , HttpTransport ]
31
36
32
- class AadClientBase (ABC ):
33
- """Sans I/O methods for AAD clients wrapping MSAL's OAuth client"""
34
37
35
- def __init__ ( self , tenant_id , client_id , cache = None , ** kwargs ):
36
- # type: (str, str, Optional[TokenCache] , **Any) -> None
37
- authority = kwargs . pop ( "authority" , None )
38
+ class AadClientBase ( ABC ):
39
+ def __init__ ( self , tenant_id , client_id , authority = None , cache = None , ** kwargs ):
40
+ # type: (str, str, Optional[str], Optional[TokenCache], **Any) -> None
38
41
authority = normalize_authority (authority ) if authority else get_default_authority ()
39
-
40
- token_endpoint = "/" .join ((authority , tenant_id , "oauth2/v2.0/token" ))
41
- config = {"token_endpoint" : token_endpoint }
42
-
42
+ self ._token_endpoint = "/" .join ((authority , tenant_id , "oauth2/v2.0/token" ))
43
43
self ._cache = cache or TokenCache ()
44
-
45
- self ._client = Client (server_configuration = config , client_id = client_id )
46
- self ._client .session .close ()
47
- self ._client .session = self ._get_client_session (** kwargs )
44
+ self ._client_id = client_id
45
+ self ._pipeline = self ._build_pipeline (** kwargs )
48
46
49
47
def get_cached_access_token (self , scopes ):
50
- # type: (Iterable [str]) -> Optional[AccessToken]
48
+ # type: (Sequence [str]) -> Optional[AccessToken]
51
49
tokens = self ._cache .find (TokenCache .CredentialType .ACCESS_TOKEN , target = list (scopes ))
52
50
for token in tokens :
53
51
expires_on = int (token ["expires_on" ])
@@ -56,35 +54,30 @@ def get_cached_access_token(self, scopes):
56
54
return None
57
55
58
56
def get_cached_refresh_tokens (self , scopes ):
57
+ # type: (Sequence[str]) -> Sequence[dict]
59
58
"""Assumes all cached refresh tokens belong to the same user"""
60
59
return self ._cache .find (TokenCache .CredentialType .REFRESH_TOKEN , target = list (scopes ))
61
60
62
- def obtain_token_by_authorization_code (self , code , redirect_uri , scopes , ** kwargs ):
63
- # type: (str, str, Iterable[str], **Any) -> AccessToken
64
- fn = functools .partial (
65
- self ._client .obtain_token_by_authorization_code , code = code , redirect_uri = redirect_uri , ** kwargs
66
- )
67
- return self ._obtain_token (scopes , fn , ** kwargs )
68
-
69
- def obtain_token_by_refresh_token (self , refresh_token , scopes , ** kwargs ):
70
- # type: (str, Iterable[str], **Any) -> AccessToken
71
- fn = functools .partial (
72
- self ._client .obtain_token_by_refresh_token ,
73
- token_item = refresh_token ,
74
- scope = scopes ,
75
- rt_getter = lambda token : token ["secret" ],
76
- ** kwargs
77
- )
78
- return self ._obtain_token (scopes , fn , ** kwargs )
61
+ @abc .abstractmethod
62
+ def obtain_token_by_authorization_code (self , scopes , code , redirect_uri , client_secret = None , ** kwargs ):
63
+ pass
64
+
65
+ @abc .abstractmethod
66
+ def obtain_token_by_refresh_token (self , scopes , refresh_token , ** kwargs ):
67
+ pass
68
+
69
+ @abc .abstractmethod
70
+ def _build_pipeline (self , config = None , policies = None , transport = None , ** kwargs ):
71
+ pass
79
72
80
73
def _process_response (self , response , scopes , now ):
81
- # type: (dict, Iterable [str], int) -> AccessToken
74
+ # type: (dict, Sequence [str], int) -> AccessToken
82
75
_raise_for_error (response )
83
76
84
77
# TokenCache.add mutates the response. In particular, it removes tokens.
85
78
response_copy = copy .deepcopy (response )
86
79
87
- self ._cache .add (event = {"response" : response , "scope" : scopes }, now = now )
80
+ self ._cache .add (event = {"response" : response , "scope" : scopes , "client_id" : self . _client_id }, now = now )
88
81
if "expires_on" in response_copy :
89
82
expires_on = int (response_copy ["expires_on" ])
90
83
elif "expires_in" in response_copy :
@@ -96,17 +89,41 @@ def _process_response(self, response, scopes, now):
96
89
)
97
90
return AccessToken (response_copy ["access_token" ], expires_on )
98
91
99
- @abc .abstractmethod
100
- def _get_client_session (self , ** kwargs ):
101
- pass
102
-
103
- @abc .abstractmethod
104
- def _obtain_token (self , scopes , fn , ** kwargs ):
105
- # type: (Iterable[str], Callable, **Any) -> AccessToken
106
- pass
92
+ def _get_auth_code_request (self , scopes , code , redirect_uri , client_secret = None ):
93
+ # type: (str, str, Sequence[str], Optional[str]) -> HttpRequest
94
+
95
+ data = {
96
+ "client_id" : self ._client_id ,
97
+ "code" : code ,
98
+ "grant_type" : "authorization_code" ,
99
+ "redirect_uri" : redirect_uri ,
100
+ "scope" : " " .join (scopes ),
101
+ }
102
+ if client_secret :
103
+ data ["client_secret" ] = client_secret
104
+
105
+ request = HttpRequest (
106
+ "POST" , self ._token_endpoint , headers = {"Content-Type" : "application/x-www-form-urlencoded" }, data = data
107
+ )
108
+ return request
109
+
110
+ def _get_refresh_token_request (self , scopes , refresh_token ):
111
+ # type: (str, Sequence[str]) -> HttpRequest
112
+
113
+ data = {
114
+ "grant_type" : "refresh_token" ,
115
+ "refresh_token" : refresh_token ,
116
+ "scope" : " " .join (scopes ),
117
+ "client_id" : self ._client_id ,
118
+ }
119
+ request = HttpRequest (
120
+ "POST" , self ._token_endpoint , headers = {"Content-Type" : "application/x-www-form-urlencoded" }, data = data
121
+ )
122
+ return request
107
123
108
124
109
125
def _scrub_secrets (response ):
126
+ # type: (dict) -> None
110
127
for secret in ("access_token" , "refresh_token" ):
111
128
if secret in response :
112
129
response [secret ] = "***"
0 commit comments