Skip to content

Commit e4573de

Browse files
authored
[Key Vault] Support CAE in challenge auth policy (#37358)
1 parent b7934ec commit e4573de

File tree

22 files changed

+2060
-272
lines changed

22 files changed

+2060
-272
lines changed

sdk/keyvault/azure-keyvault-administration/CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
### Features Added
66
- Added support for service API version `7.6-preview.1`
7+
- Added support for Continuous Access Evaluation (CAE). `enable_cae=True` is passed to all `get_token` requests.
78

89
### Breaking Changes
910

@@ -12,6 +13,7 @@
1213
([#34744](https://github.com/Azure/azure-sdk-for-python/issues/34744))
1314

1415
### Other Changes
16+
- Updated minimum `azure-core` version to 1.31.0
1517
- Key Vault API version `7.6-preview.1` is now the default
1618

1719
## 4.4.0 (2024-02-22)

sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py

+163-21
Original file line numberDiff line numberDiff line change
@@ -16,33 +16,140 @@
1616

1717
from copy import deepcopy
1818
import time
19-
from typing import Any, Optional
19+
from typing import Any, Awaitable, Callable, cast, Optional, overload, TypeVar, Union
2020
from urllib.parse import urlparse
2121

22-
from azure.core.credentials import AccessToken
23-
from azure.core.credentials_async import AsyncTokenCredential
22+
from typing_extensions import ParamSpec
23+
24+
from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions
25+
from azure.core.credentials_async import AsyncSupportsTokenInfo, AsyncTokenCredential, AsyncTokenProvider
2426
from azure.core.pipeline import PipelineRequest, PipelineResponse
2527
from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy
26-
from azure.core.rest import HttpRequest
28+
from azure.core.rest import AsyncHttpResponse, HttpRequest
2729

30+
from .http_challenge import HttpChallenge
2831
from . import http_challenge_cache as ChallengeCache
29-
from .challenge_auth_policy import _enforce_tls, _update_challenge
32+
from .challenge_auth_policy import _enforce_tls, _has_claims, _update_challenge
33+
34+
35+
P = ParamSpec("P")
36+
T = TypeVar("T")
37+
38+
39+
@overload
40+
async def await_result(func: Callable[P, Awaitable[T]], *args: P.args, **kwargs: P.kwargs) -> T: ...
41+
42+
43+
@overload
44+
async def await_result(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: ...
45+
46+
47+
async def await_result(func: Callable[P, Union[T, Awaitable[T]]], *args: P.args, **kwargs: P.kwargs) -> T:
48+
"""If func returns an awaitable, await it.
49+
50+
:param func: The function to run.
51+
:type func: callable
52+
:param args: The positional arguments to pass to the function.
53+
:type args: list
54+
:rtype: any
55+
:return: The result of the function
56+
"""
57+
result = func(*args, **kwargs)
58+
if isinstance(result, Awaitable):
59+
return await result
60+
return result
61+
3062

3163
class AsyncChallengeAuthPolicy(AsyncBearerTokenCredentialPolicy):
3264
"""Policy for handling HTTP authentication challenges.
3365
3466
:param credential: An object which can provide an access token for the vault, such as a credential from
3567
:mod:`azure.identity.aio`
36-
:type credential: ~azure.core.credentials_async.AsyncTokenCredential
68+
:type credential: ~azure.core.credentials_async.AsyncTokenProvider
3769
"""
3870

39-
def __init__(self, credential: AsyncTokenCredential, *scopes: str, **kwargs: Any) -> None:
40-
super().__init__(credential, *scopes, **kwargs)
41-
self._credential: AsyncTokenCredential = credential
42-
self._token: Optional[AccessToken] = None
71+
def __init__(self, credential: AsyncTokenProvider, *scopes: str, **kwargs: Any) -> None:
72+
# Pass `enable_cae` so `enable_cae=True` is always passed through self.authorize_request
73+
super().__init__(credential, *scopes, enable_cae=True, **kwargs)
74+
self._credential: AsyncTokenProvider = credential
75+
self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None
4376
self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True)
4477
self._request_copy: Optional[HttpRequest] = None
4578

79+
async def send(
80+
self, request: PipelineRequest[HttpRequest]
81+
) -> PipelineResponse[HttpRequest, AsyncHttpResponse]:
82+
"""Authorize request with a bearer token and send it to the next policy.
83+
84+
We implement this method to account for the valid scenario where a Key Vault authentication challenge is
85+
immediately followed by a CAE claims challenge. The base class's implementation would return the second 401 to
86+
the caller, but we should handle that second challenge as well (and only return any third 401 response).
87+
88+
:param request: The pipeline request object
89+
:type request: ~azure.core.pipeline.PipelineRequest
90+
:return: The pipeline response object
91+
:rtype: ~azure.core.pipeline.PipelineResponse
92+
"""
93+
await await_result(self.on_request, request)
94+
response: PipelineResponse[HttpRequest, AsyncHttpResponse]
95+
try:
96+
response = await self.next.send(request)
97+
except Exception: # pylint:disable=broad-except
98+
await await_result(self.on_exception, request)
99+
raise
100+
await await_result(self.on_response, request, response)
101+
102+
if response.http_response.status_code == 401:
103+
return await self.handle_challenge_flow(request, response)
104+
return response
105+
106+
async def handle_challenge_flow(
107+
self,
108+
request: PipelineRequest[HttpRequest],
109+
response: PipelineResponse[HttpRequest, AsyncHttpResponse],
110+
consecutive_challenge: bool = False,
111+
) -> PipelineResponse[HttpRequest, AsyncHttpResponse]:
112+
"""Handle the challenge flow of Key Vault and CAE authentication.
113+
114+
:param request: The pipeline request object
115+
:type request: ~azure.core.pipeline.PipelineRequest
116+
:param response: The pipeline response object
117+
:type response: ~azure.core.pipeline.PipelineResponse
118+
:param bool consecutive_challenge: Whether the challenge is arriving immediately after another challenge.
119+
Consecutive challenges can only be valid if a Key Vault challenge is followed by a CAE claims challenge.
120+
True if the preceding challenge was a Key Vault challenge; False otherwise.
121+
122+
:return: The pipeline response object
123+
:rtype: ~azure.core.pipeline.PipelineResponse
124+
"""
125+
self._token = None # any cached token is invalid
126+
if "WWW-Authenticate" in response.http_response.headers:
127+
# If the previous challenge was a KV challenge and this one is too, return the 401
128+
claims_challenge = _has_claims(response.http_response.headers["WWW-Authenticate"])
129+
if consecutive_challenge and not claims_challenge:
130+
return response
131+
132+
request_authorized = await self.on_challenge(request, response)
133+
if request_authorized:
134+
# if we receive a challenge response, we retrieve a new token
135+
# which matches the new target. In this case, we don't want to remove
136+
# token from the request so clear the 'insecure_domain_change' tag
137+
request.context.options.pop("insecure_domain_change", False)
138+
try:
139+
response = await self.next.send(request)
140+
except Exception: # pylint:disable=broad-except
141+
await await_result(self.on_exception, request)
142+
raise
143+
144+
# If consecutive_challenge == True, this could be a third consecutive 401
145+
if response.http_response.status_code == 401 and not consecutive_challenge:
146+
# If the previous challenge wasn't from CAE, we can try this function one more time
147+
if not claims_challenge:
148+
return await self.handle_challenge_flow(request, response, consecutive_challenge=True)
149+
await await_result(self.on_response, request, response)
150+
return response
151+
152+
46153
async def on_request(self, request: PipelineRequest) -> None:
47154
_enforce_tls(request)
48155
challenge = ChallengeCache.get_challenge_for_url(request.http_request.url)
@@ -51,14 +158,10 @@ async def on_request(self, request: PipelineRequest) -> None:
51158
if self._need_new_token():
52159
# azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource
53160
scope = challenge.get_scope() or challenge.get_resource() + "/.default"
54-
# Exclude tenant for AD FS authentication
55-
if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"):
56-
self._token = await self._credential.get_token(scope)
57-
else:
58-
self._token = await self._credential.get_token(scope, tenant_id=challenge.tenant_id)
59-
60-
# ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token
61-
request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore
161+
await self._request_kv_token(scope, challenge)
162+
163+
bearer_token = cast(Union[AccessToken, AccessTokenInfo], self._token).token
164+
request.http_request.headers["Authorization"] = f"Bearer {bearer_token}"
62165
return
63166

64167
# else: discover authentication information by eliciting a challenge from Key Vault. Remove any request data,
@@ -78,7 +181,19 @@ async def on_request(self, request: PipelineRequest) -> None:
78181

79182
async def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> bool:
80183
try:
184+
# CAE challenges may not include a scope or tenant; cache from the previous challenge to use if necessary
185+
old_scope: Optional[str] = None
186+
old_tenant: Optional[str] = None
187+
cached_challenge = ChallengeCache.get_challenge_for_url(request.http_request.url)
188+
if cached_challenge:
189+
old_scope = cached_challenge.get_scope() or cached_challenge.get_resource() + "/.default"
190+
old_tenant = cached_challenge.tenant_id
191+
81192
challenge = _update_challenge(request, response)
193+
# CAE challenges may not include a scope or tenant; use the previous challenge's values if necessary
194+
if challenge.claims and old_scope:
195+
challenge._parameters["scope"] = old_scope # pylint:disable=protected-access
196+
challenge.tenant_id = old_tenant
82197
# azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource
83198
scope = challenge.get_scope() or challenge.get_resource() + "/.default"
84199
except ValueError:
@@ -104,11 +219,38 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons
104219
# The tenant parsed from AD FS challenges is "adfs"; we don't actually need a tenant for AD FS authentication
105220
# For AD FS we skip cross-tenant authentication per https://github.com/Azure/azure-sdk-for-python/issues/28648
106221
if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"):
107-
await self.authorize_request(request, scope)
222+
await self.authorize_request(request, scope, claims=challenge.claims)
108223
else:
109-
await self.authorize_request(request, scope, tenant_id=challenge.tenant_id)
224+
await self.authorize_request(
225+
request, scope, claims=challenge.claims, tenant_id=challenge.tenant_id
226+
)
110227

111228
return True
112229

113230
def _need_new_token(self) -> bool:
114-
return not self._token or self._token.expires_on - time.time() < 300
231+
now = time.time()
232+
refresh_on = getattr(self._token, "refresh_on", None)
233+
return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300
234+
235+
async def _request_kv_token(self, scope: str, challenge: HttpChallenge) -> None:
236+
"""Implementation of BearerTokenCredentialPolicy's _request_token method, but specific to Key Vault.
237+
238+
:param str scope: The scope for which to request a token.
239+
:param challenge: The challenge for the request being made.
240+
:type challenge: HttpChallenge
241+
"""
242+
# Exclude tenant for AD FS authentication
243+
exclude_tenant = challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs")
244+
# The AsyncSupportsTokenInfo protocol needs TokenRequestOptions for token requests instead of kwargs
245+
if hasattr(self._credential, "get_token_info"):
246+
options: TokenRequestOptions = {"enable_cae": True}
247+
if challenge.tenant_id and not exclude_tenant:
248+
options["tenant_id"] = challenge.tenant_id
249+
self._token = await cast(AsyncSupportsTokenInfo, self._credential).get_token_info(scope, options=options)
250+
else:
251+
if exclude_tenant:
252+
self._token = await self._credential.get_token(scope, enable_cae=True)
253+
else:
254+
self._token = await cast(AsyncTokenCredential, self._credential).get_token(
255+
scope, tenant_id=challenge.tenant_id, enable_cae=True
256+
)

0 commit comments

Comments
 (0)