16
16
17
17
from copy import deepcopy
18
18
import time
19
- from typing import Any , Optional
19
+ from typing import Any , Awaitable , Callable , cast , Optional , overload , TypeVar , Union
20
20
from urllib .parse import urlparse
21
21
22
- from azure .core .credentials import AccessToken
23
- from azure .core .credentials_async import AsyncTokenCredential
22
+ from typing_extensions import ParamSpec
23
+
24
+ from azure .core .credentials import AccessToken , AccessTokenInfo , TokenRequestOptions
25
+ from azure .core .credentials_async import AsyncSupportsTokenInfo , AsyncTokenCredential , AsyncTokenProvider
24
26
from azure .core .pipeline import PipelineRequest , PipelineResponse
25
27
from azure .core .pipeline .policies import AsyncBearerTokenCredentialPolicy
26
- from azure .core .rest import HttpRequest
28
+ from azure .core .rest import AsyncHttpResponse , HttpRequest
27
29
30
+ from .http_challenge import HttpChallenge
28
31
from . import http_challenge_cache as ChallengeCache
29
- from .challenge_auth_policy import _enforce_tls , _update_challenge
32
+ from .challenge_auth_policy import _enforce_tls , _has_claims , _update_challenge
33
+
34
+
35
+ P = ParamSpec ("P" )
36
+ T = TypeVar ("T" )
37
+
38
+
39
+ @overload
40
+ async def await_result (func : Callable [P , Awaitable [T ]], * args : P .args , ** kwargs : P .kwargs ) -> T : ...
41
+
42
+
43
+ @overload
44
+ async def await_result (func : Callable [P , T ], * args : P .args , ** kwargs : P .kwargs ) -> T : ...
45
+
46
+
47
+ async def await_result (func : Callable [P , Union [T , Awaitable [T ]]], * args : P .args , ** kwargs : P .kwargs ) -> T :
48
+ """If func returns an awaitable, await it.
49
+
50
+ :param func: The function to run.
51
+ :type func: callable
52
+ :param args: The positional arguments to pass to the function.
53
+ :type args: list
54
+ :rtype: any
55
+ :return: The result of the function
56
+ """
57
+ result = func (* args , ** kwargs )
58
+ if isinstance (result , Awaitable ):
59
+ return await result
60
+ return result
61
+
30
62
31
63
class AsyncChallengeAuthPolicy (AsyncBearerTokenCredentialPolicy ):
32
64
"""Policy for handling HTTP authentication challenges.
33
65
34
66
:param credential: An object which can provide an access token for the vault, such as a credential from
35
67
:mod:`azure.identity.aio`
36
- :type credential: ~azure.core.credentials_async.AsyncTokenCredential
68
+ :type credential: ~azure.core.credentials_async.AsyncTokenProvider
37
69
"""
38
70
39
- def __init__ (self , credential : AsyncTokenCredential , * scopes : str , ** kwargs : Any ) -> None :
40
- super ().__init__ (credential , * scopes , ** kwargs )
41
- self ._credential : AsyncTokenCredential = credential
42
- self ._token : Optional [AccessToken ] = None
71
+ def __init__ (self , credential : AsyncTokenProvider , * scopes : str , ** kwargs : Any ) -> None :
72
+ # Pass `enable_cae` so `enable_cae=True` is always passed through self.authorize_request
73
+ super ().__init__ (credential , * scopes , enable_cae = True , ** kwargs )
74
+ self ._credential : AsyncTokenProvider = credential
75
+ self ._token : Optional [Union ["AccessToken" , "AccessTokenInfo" ]] = None
43
76
self ._verify_challenge_resource = kwargs .pop ("verify_challenge_resource" , True )
44
77
self ._request_copy : Optional [HttpRequest ] = None
45
78
79
+ async def send (
80
+ self , request : PipelineRequest [HttpRequest ]
81
+ ) -> PipelineResponse [HttpRequest , AsyncHttpResponse ]:
82
+ """Authorize request with a bearer token and send it to the next policy.
83
+
84
+ We implement this method to account for the valid scenario where a Key Vault authentication challenge is
85
+ immediately followed by a CAE claims challenge. The base class's implementation would return the second 401 to
86
+ the caller, but we should handle that second challenge as well (and only return any third 401 response).
87
+
88
+ :param request: The pipeline request object
89
+ :type request: ~azure.core.pipeline.PipelineRequest
90
+ :return: The pipeline response object
91
+ :rtype: ~azure.core.pipeline.PipelineResponse
92
+ """
93
+ await await_result (self .on_request , request )
94
+ response : PipelineResponse [HttpRequest , AsyncHttpResponse ]
95
+ try :
96
+ response = await self .next .send (request )
97
+ except Exception : # pylint:disable=broad-except
98
+ await await_result (self .on_exception , request )
99
+ raise
100
+ await await_result (self .on_response , request , response )
101
+
102
+ if response .http_response .status_code == 401 :
103
+ return await self .handle_challenge_flow (request , response )
104
+ return response
105
+
106
+ async def handle_challenge_flow (
107
+ self ,
108
+ request : PipelineRequest [HttpRequest ],
109
+ response : PipelineResponse [HttpRequest , AsyncHttpResponse ],
110
+ consecutive_challenge : bool = False ,
111
+ ) -> PipelineResponse [HttpRequest , AsyncHttpResponse ]:
112
+ """Handle the challenge flow of Key Vault and CAE authentication.
113
+
114
+ :param request: The pipeline request object
115
+ :type request: ~azure.core.pipeline.PipelineRequest
116
+ :param response: The pipeline response object
117
+ :type response: ~azure.core.pipeline.PipelineResponse
118
+ :param bool consecutive_challenge: Whether the challenge is arriving immediately after another challenge.
119
+ Consecutive challenges can only be valid if a Key Vault challenge is followed by a CAE claims challenge.
120
+ True if the preceding challenge was a Key Vault challenge; False otherwise.
121
+
122
+ :return: The pipeline response object
123
+ :rtype: ~azure.core.pipeline.PipelineResponse
124
+ """
125
+ self ._token = None # any cached token is invalid
126
+ if "WWW-Authenticate" in response .http_response .headers :
127
+ # If the previous challenge was a KV challenge and this one is too, return the 401
128
+ claims_challenge = _has_claims (response .http_response .headers ["WWW-Authenticate" ])
129
+ if consecutive_challenge and not claims_challenge :
130
+ return response
131
+
132
+ request_authorized = await self .on_challenge (request , response )
133
+ if request_authorized :
134
+ # if we receive a challenge response, we retrieve a new token
135
+ # which matches the new target. In this case, we don't want to remove
136
+ # token from the request so clear the 'insecure_domain_change' tag
137
+ request .context .options .pop ("insecure_domain_change" , False )
138
+ try :
139
+ response = await self .next .send (request )
140
+ except Exception : # pylint:disable=broad-except
141
+ await await_result (self .on_exception , request )
142
+ raise
143
+
144
+ # If consecutive_challenge == True, this could be a third consecutive 401
145
+ if response .http_response .status_code == 401 and not consecutive_challenge :
146
+ # If the previous challenge wasn't from CAE, we can try this function one more time
147
+ if not claims_challenge :
148
+ return await self .handle_challenge_flow (request , response , consecutive_challenge = True )
149
+ await await_result (self .on_response , request , response )
150
+ return response
151
+
152
+
46
153
async def on_request (self , request : PipelineRequest ) -> None :
47
154
_enforce_tls (request )
48
155
challenge = ChallengeCache .get_challenge_for_url (request .http_request .url )
@@ -51,14 +158,10 @@ async def on_request(self, request: PipelineRequest) -> None:
51
158
if self ._need_new_token ():
52
159
# azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource
53
160
scope = challenge .get_scope () or challenge .get_resource () + "/.default"
54
- # Exclude tenant for AD FS authentication
55
- if challenge .tenant_id and challenge .tenant_id .lower ().endswith ("adfs" ):
56
- self ._token = await self ._credential .get_token (scope )
57
- else :
58
- self ._token = await self ._credential .get_token (scope , tenant_id = challenge .tenant_id )
59
-
60
- # ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token
61
- request .http_request .headers ["Authorization" ] = f"Bearer { self ._token .token } " # type: ignore
161
+ await self ._request_kv_token (scope , challenge )
162
+
163
+ bearer_token = cast (Union [AccessToken , AccessTokenInfo ], self ._token ).token
164
+ request .http_request .headers ["Authorization" ] = f"Bearer { bearer_token } "
62
165
return
63
166
64
167
# else: discover authentication information by eliciting a challenge from Key Vault. Remove any request data,
@@ -78,7 +181,19 @@ async def on_request(self, request: PipelineRequest) -> None:
78
181
79
182
async def on_challenge (self , request : PipelineRequest , response : PipelineResponse ) -> bool :
80
183
try :
184
+ # CAE challenges may not include a scope or tenant; cache from the previous challenge to use if necessary
185
+ old_scope : Optional [str ] = None
186
+ old_tenant : Optional [str ] = None
187
+ cached_challenge = ChallengeCache .get_challenge_for_url (request .http_request .url )
188
+ if cached_challenge :
189
+ old_scope = cached_challenge .get_scope () or cached_challenge .get_resource () + "/.default"
190
+ old_tenant = cached_challenge .tenant_id
191
+
81
192
challenge = _update_challenge (request , response )
193
+ # CAE challenges may not include a scope or tenant; use the previous challenge's values if necessary
194
+ if challenge .claims and old_scope :
195
+ challenge ._parameters ["scope" ] = old_scope # pylint:disable=protected-access
196
+ challenge .tenant_id = old_tenant
82
197
# azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource
83
198
scope = challenge .get_scope () or challenge .get_resource () + "/.default"
84
199
except ValueError :
@@ -104,11 +219,38 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons
104
219
# The tenant parsed from AD FS challenges is "adfs"; we don't actually need a tenant for AD FS authentication
105
220
# For AD FS we skip cross-tenant authentication per https://github.com/Azure/azure-sdk-for-python/issues/28648
106
221
if challenge .tenant_id and challenge .tenant_id .lower ().endswith ("adfs" ):
107
- await self .authorize_request (request , scope )
222
+ await self .authorize_request (request , scope , claims = challenge . claims )
108
223
else :
109
- await self .authorize_request (request , scope , tenant_id = challenge .tenant_id )
224
+ await self .authorize_request (
225
+ request , scope , claims = challenge .claims , tenant_id = challenge .tenant_id
226
+ )
110
227
111
228
return True
112
229
113
230
def _need_new_token (self ) -> bool :
114
- return not self ._token or self ._token .expires_on - time .time () < 300
231
+ now = time .time ()
232
+ refresh_on = getattr (self ._token , "refresh_on" , None )
233
+ return not self ._token or (refresh_on and refresh_on <= now ) or self ._token .expires_on - now < 300
234
+
235
+ async def _request_kv_token (self , scope : str , challenge : HttpChallenge ) -> None :
236
+ """Implementation of BearerTokenCredentialPolicy's _request_token method, but specific to Key Vault.
237
+
238
+ :param str scope: The scope for which to request a token.
239
+ :param challenge: The challenge for the request being made.
240
+ :type challenge: HttpChallenge
241
+ """
242
+ # Exclude tenant for AD FS authentication
243
+ exclude_tenant = challenge .tenant_id and challenge .tenant_id .lower ().endswith ("adfs" )
244
+ # The AsyncSupportsTokenInfo protocol needs TokenRequestOptions for token requests instead of kwargs
245
+ if hasattr (self ._credential , "get_token_info" ):
246
+ options : TokenRequestOptions = {"enable_cae" : True }
247
+ if challenge .tenant_id and not exclude_tenant :
248
+ options ["tenant_id" ] = challenge .tenant_id
249
+ self ._token = await cast (AsyncSupportsTokenInfo , self ._credential ).get_token_info (scope , options = options )
250
+ else :
251
+ if exclude_tenant :
252
+ self ._token = await self ._credential .get_token (scope , enable_cae = True )
253
+ else :
254
+ self ._token = await cast (AsyncTokenCredential , self ._credential ).get_token (
255
+ scope , tenant_id = challenge .tenant_id , enable_cae = True
256
+ )
0 commit comments