Skip to content

Commit 93c7eaa

Browse files
committed
(Async)SupportsTokenInfo support/tests
1 parent 8e65726 commit 93c7eaa

File tree

10 files changed

+623
-275
lines changed

10 files changed

+623
-275
lines changed

sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py

+37-18
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,18 @@
1616

1717
from copy import deepcopy
1818
import time
19-
from typing import Any, Awaitable, Callable, Optional, overload, TypeVar, Union
19+
from typing import Any, Awaitable, Callable, cast, Optional, overload, TypeVar, Union
2020
from urllib.parse import urlparse
2121

2222
from typing_extensions import ParamSpec
2323

24-
from azure.core.credentials import AccessToken
25-
from azure.core.credentials_async import AsyncTokenCredential
24+
from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions
25+
from azure.core.credentials_async import AsyncSupportsTokenInfo, AsyncTokenCredential, AsyncTokenProvider
2626
from azure.core.pipeline import PipelineRequest, PipelineResponse
2727
from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy
2828
from azure.core.rest import AsyncHttpResponse, HttpRequest
2929

30+
from .http_challenge import HttpChallenge
3031
from . import http_challenge_cache as ChallengeCache
3132
from .challenge_auth_policy import _enforce_tls, _has_claims, _update_challenge
3233

@@ -64,14 +65,14 @@ class AsyncChallengeAuthPolicy(AsyncBearerTokenCredentialPolicy):
6465
6566
:param credential: An object which can provide an access token for the vault, such as a credential from
6667
:mod:`azure.identity.aio`
67-
:type credential: ~azure.core.credentials_async.AsyncTokenCredential
68+
:type credential: ~azure.core.credentials_async.AsyncTokenProvider
6869
"""
6970

70-
def __init__(self, credential: AsyncTokenCredential, *scopes: str, **kwargs: Any) -> None:
71+
def __init__(self, credential: AsyncTokenProvider, *scopes: str, **kwargs: Any) -> None:
7172
# Pass `enable_cae` so `enable_cae=True` is always passed through self.authorize_request
7273
super().__init__(credential, *scopes, enable_cae=True, **kwargs)
73-
self._credential: AsyncTokenCredential = credential
74-
self._token: Optional[AccessToken] = None
74+
self._credential: AsyncTokenProvider = credential
75+
self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None
7576
self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True)
7677
self._request_copy: Optional[HttpRequest] = None
7778

@@ -157,16 +158,10 @@ async def on_request(self, request: PipelineRequest) -> None:
157158
if self._need_new_token():
158159
# azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource
159160
scope = challenge.get_scope() or challenge.get_resource() + "/.default"
160-
# Exclude tenant for AD FS authentication
161-
if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"):
162-
self._token = await self._credential.get_token(scope, enable_cae=True)
163-
else:
164-
self._token = await self._credential.get_token(
165-
scope, tenant_id=challenge.tenant_id, enable_cae=True
166-
)
167-
168-
# ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token
169-
request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore
161+
await self._request_kv_token(scope, challenge)
162+
163+
bearer_token = cast(Union[AccessToken, AccessTokenInfo], self._token).token
164+
request.http_request.headers["Authorization"] = f"Bearer {bearer_token}"
170165
return
171166

172167
# else: discover authentication information by eliciting a challenge from Key Vault. Remove any request data,
@@ -233,4 +228,28 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons
233228
return True
234229

235230
def _need_new_token(self) -> bool:
236-
return not self._token or self._token.expires_on - time.time() < 300
231+
now = time.time()
232+
refresh_on = getattr(self._token, "refresh_on", None)
233+
return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300
234+
235+
async def _request_kv_token(self, scope: str, challenge: HttpChallenge, **kwargs: Any) -> None:
236+
"""Implementation of BearerTokenCredentialPolicy's _request_token method, but specific to Key Vault.
237+
238+
:param str scope: The scope for which to request a token.
239+
:param challenge: The challenge for the request being made.
240+
"""
241+
# Exclude tenant for AD FS authentication
242+
exclude_tenant = challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs")
243+
# The AsyncSupportsTokenInfo protocol needs TokenRequestOptions for token requests instead of kwargs
244+
if hasattr(self._credential, "get_token_info"):
245+
options: TokenRequestOptions = {"enable_cae": True}
246+
if challenge.tenant_id and not exclude_tenant:
247+
options["tenant_id"] = challenge.tenant_id
248+
self._token = await cast(AsyncSupportsTokenInfo, self._credential).get_token_info(scope, options=options)
249+
else:
250+
if exclude_tenant:
251+
self._token = await self._credential.get_token(scope, enable_cae=True)
252+
else:
253+
self._token = await cast(AsyncTokenCredential, self._credential).get_token(
254+
scope, tenant_id=challenge.tenant_id, enable_cae=True
255+
)

sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py

+43-15
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,17 @@
1616

1717
from copy import deepcopy
1818
import time
19-
from typing import Any, Optional
19+
from typing import Any, cast, Optional, Union
2020
from urllib.parse import urlparse
2121

22-
from azure.core.credentials import AccessToken, TokenCredential
22+
from azure.core.credentials import (
23+
AccessToken,
24+
AccessTokenInfo,
25+
TokenCredential,
26+
TokenProvider,
27+
TokenRequestOptions,
28+
SupportsTokenInfo,
29+
)
2330
from azure.core.exceptions import ServiceRequestError
2431
from azure.core.pipeline import PipelineRequest, PipelineResponse
2532
from azure.core.pipeline.policies import BearerTokenCredentialPolicy
@@ -76,14 +83,15 @@ class ChallengeAuthPolicy(BearerTokenCredentialPolicy):
7683
7784
:param credential: An object which can provide an access token for the vault, such as a credential from
7885
:mod:`azure.identity`
79-
:type credential: ~azure.core.credentials.TokenCredential
86+
:type credential: ~azure.core.credentials.TokenProvider
87+
:param str scopes: Lets you specify the type of access needed.
8088
"""
8189

82-
def __init__(self, credential: TokenCredential, *scopes: str, **kwargs: Any) -> None:
90+
def __init__(self, credential: TokenProvider, *scopes: str, **kwargs: Any) -> None:
8391
# Pass `enable_cae` so `enable_cae=True` is always passed through self.authorize_request
8492
super(ChallengeAuthPolicy, self).__init__(credential, *scopes, enable_cae=True, **kwargs)
85-
self._credential: TokenCredential = credential
86-
self._token: Optional[AccessToken] = None
93+
self._credential: TokenProvider = credential
94+
self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None
8795
self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True)
8896
self._request_copy: Optional[HttpRequest] = None
8997

@@ -166,14 +174,10 @@ def on_request(self, request: PipelineRequest) -> None:
166174
if self._need_new_token:
167175
# azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource
168176
scope = challenge.get_scope() or challenge.get_resource() + "/.default"
169-
# Exclude tenant for AD FS authentication
170-
if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"):
171-
self._token = self._credential.get_token(scope, enable_cae=True)
172-
else:
173-
self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id, enable_cae=True)
174-
175-
# ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token
176-
request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore
177+
self._request_kv_token(scope, challenge)
178+
179+
bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], self._token).token
180+
request.http_request.headers["Authorization"] = f"Bearer {bearer_token}"
177181
return
178182

179183
# else: discover authentication information by eliciting a challenge from Key Vault. Remove any request data,
@@ -238,4 +242,28 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) ->
238242

239243
@property
240244
def _need_new_token(self) -> bool:
241-
return not self._token or self._token.expires_on - time.time() < 300
245+
now = time.time()
246+
refresh_on = getattr(self._token, "refresh_on", None)
247+
return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300
248+
249+
def _request_kv_token(self, scope: str, challenge: HttpChallenge, **kwargs: Any) -> None:
250+
"""Implementation of BearerTokenCredentialPolicy's _request_token method, but specific to Key Vault.
251+
252+
:param str scope: The scope for which to request a token.
253+
:param challenge: The challenge for the request being made.
254+
"""
255+
# Exclude tenant for AD FS authentication
256+
exclude_tenant = challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs")
257+
# The SupportsTokenInfo protocol needs TokenRequestOptions for token requests instead of kwargs
258+
if hasattr(self._credential, "get_token_info"):
259+
options: TokenRequestOptions = {"enable_cae": True}
260+
if challenge.tenant_id and not exclude_tenant:
261+
options["tenant_id"] = challenge.tenant_id
262+
self._token = cast(SupportsTokenInfo, self._credential).get_token_info(scope, options=options)
263+
else:
264+
if exclude_tenant:
265+
self._token = self._credential.get_token(scope, enable_cae=True)
266+
else:
267+
self._token = cast(TokenCredential, self._credential).get_token(
268+
scope, tenant_id=challenge.tenant_id, enable_cae=True
269+
)

sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py

+37-18
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,18 @@
1616

1717
from copy import deepcopy
1818
import time
19-
from typing import Any, Awaitable, Callable, Optional, overload, TypeVar, Union
19+
from typing import Any, Awaitable, Callable, cast, Optional, overload, TypeVar, Union
2020
from urllib.parse import urlparse
2121

2222
from typing_extensions import ParamSpec
2323

24-
from azure.core.credentials import AccessToken
25-
from azure.core.credentials_async import AsyncTokenCredential
24+
from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions
25+
from azure.core.credentials_async import AsyncSupportsTokenInfo, AsyncTokenCredential, AsyncTokenProvider
2626
from azure.core.pipeline import PipelineRequest, PipelineResponse
2727
from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy
2828
from azure.core.rest import AsyncHttpResponse, HttpRequest
2929

30+
from .http_challenge import HttpChallenge
3031
from . import http_challenge_cache as ChallengeCache
3132
from .challenge_auth_policy import _enforce_tls, _has_claims, _update_challenge
3233

@@ -64,14 +65,14 @@ class AsyncChallengeAuthPolicy(AsyncBearerTokenCredentialPolicy):
6465
6566
:param credential: An object which can provide an access token for the vault, such as a credential from
6667
:mod:`azure.identity.aio`
67-
:type credential: ~azure.core.credentials_async.AsyncTokenCredential
68+
:type credential: ~azure.core.credentials_async.AsyncTokenProvider
6869
"""
6970

70-
def __init__(self, credential: AsyncTokenCredential, *scopes: str, **kwargs: Any) -> None:
71+
def __init__(self, credential: AsyncTokenProvider, *scopes: str, **kwargs: Any) -> None:
7172
# Pass `enable_cae` so `enable_cae=True` is always passed through self.authorize_request
7273
super().__init__(credential, *scopes, enable_cae=True, **kwargs)
73-
self._credential: AsyncTokenCredential = credential
74-
self._token: Optional[AccessToken] = None
74+
self._credential: AsyncTokenProvider = credential
75+
self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None
7576
self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True)
7677
self._request_copy: Optional[HttpRequest] = None
7778

@@ -157,16 +158,10 @@ async def on_request(self, request: PipelineRequest) -> None:
157158
if self._need_new_token():
158159
# azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource
159160
scope = challenge.get_scope() or challenge.get_resource() + "/.default"
160-
# Exclude tenant for AD FS authentication
161-
if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"):
162-
self._token = await self._credential.get_token(scope, enable_cae=True)
163-
else:
164-
self._token = await self._credential.get_token(
165-
scope, tenant_id=challenge.tenant_id, enable_cae=True
166-
)
167-
168-
# ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token
169-
request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore
161+
await self._request_kv_token(scope, challenge)
162+
163+
bearer_token = cast(Union[AccessToken, AccessTokenInfo], self._token).token
164+
request.http_request.headers["Authorization"] = f"Bearer {bearer_token}"
170165
return
171166

172167
# else: discover authentication information by eliciting a challenge from Key Vault. Remove any request data,
@@ -233,4 +228,28 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons
233228
return True
234229

235230
def _need_new_token(self) -> bool:
236-
return not self._token or self._token.expires_on - time.time() < 300
231+
now = time.time()
232+
refresh_on = getattr(self._token, "refresh_on", None)
233+
return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300
234+
235+
async def _request_kv_token(self, scope: str, challenge: HttpChallenge, **kwargs: Any) -> None:
236+
"""Implementation of BearerTokenCredentialPolicy's _request_token method, but specific to Key Vault.
237+
238+
:param str scope: The scope for which to request a token.
239+
:param challenge: The challenge for the request being made.
240+
"""
241+
# Exclude tenant for AD FS authentication
242+
exclude_tenant = challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs")
243+
# The AsyncSupportsTokenInfo protocol needs TokenRequestOptions for token requests instead of kwargs
244+
if hasattr(self._credential, "get_token_info"):
245+
options: TokenRequestOptions = {"enable_cae": True}
246+
if challenge.tenant_id and not exclude_tenant:
247+
options["tenant_id"] = challenge.tenant_id
248+
self._token = await cast(AsyncSupportsTokenInfo, self._credential).get_token_info(scope, options=options)
249+
else:
250+
if exclude_tenant:
251+
self._token = await self._credential.get_token(scope, enable_cae=True)
252+
else:
253+
self._token = await cast(AsyncTokenCredential, self._credential).get_token(
254+
scope, tenant_id=challenge.tenant_id, enable_cae=True
255+
)

sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py

+43-15
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,17 @@
1616

1717
from copy import deepcopy
1818
import time
19-
from typing import Any, Optional
19+
from typing import Any, cast, Optional, Union
2020
from urllib.parse import urlparse
2121

22-
from azure.core.credentials import AccessToken, TokenCredential
22+
from azure.core.credentials import (
23+
AccessToken,
24+
AccessTokenInfo,
25+
TokenCredential,
26+
TokenProvider,
27+
TokenRequestOptions,
28+
SupportsTokenInfo,
29+
)
2330
from azure.core.exceptions import ServiceRequestError
2431
from azure.core.pipeline import PipelineRequest, PipelineResponse
2532
from azure.core.pipeline.policies import BearerTokenCredentialPolicy
@@ -76,14 +83,15 @@ class ChallengeAuthPolicy(BearerTokenCredentialPolicy):
7683
7784
:param credential: An object which can provide an access token for the vault, such as a credential from
7885
:mod:`azure.identity`
79-
:type credential: ~azure.core.credentials.TokenCredential
86+
:type credential: ~azure.core.credentials.TokenProvider
87+
:param str scopes: Lets you specify the type of access needed.
8088
"""
8189

82-
def __init__(self, credential: TokenCredential, *scopes: str, **kwargs: Any) -> None:
90+
def __init__(self, credential: TokenProvider, *scopes: str, **kwargs: Any) -> None:
8391
# Pass `enable_cae` so `enable_cae=True` is always passed through self.authorize_request
8492
super(ChallengeAuthPolicy, self).__init__(credential, *scopes, enable_cae=True, **kwargs)
85-
self._credential: TokenCredential = credential
86-
self._token: Optional[AccessToken] = None
93+
self._credential: TokenProvider = credential
94+
self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None
8795
self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True)
8896
self._request_copy: Optional[HttpRequest] = None
8997

@@ -166,14 +174,10 @@ def on_request(self, request: PipelineRequest) -> None:
166174
if self._need_new_token:
167175
# azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource
168176
scope = challenge.get_scope() or challenge.get_resource() + "/.default"
169-
# Exclude tenant for AD FS authentication
170-
if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"):
171-
self._token = self._credential.get_token(scope, enable_cae=True)
172-
else:
173-
self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id, enable_cae=True)
174-
175-
# ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token
176-
request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore
177+
self._request_kv_token(scope, challenge)
178+
179+
bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], self._token).token
180+
request.http_request.headers["Authorization"] = f"Bearer {bearer_token}"
177181
return
178182

179183
# else: discover authentication information by eliciting a challenge from Key Vault. Remove any request data,
@@ -238,4 +242,28 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) ->
238242

239243
@property
240244
def _need_new_token(self) -> bool:
241-
return not self._token or self._token.expires_on - time.time() < 300
245+
now = time.time()
246+
refresh_on = getattr(self._token, "refresh_on", None)
247+
return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300
248+
249+
def _request_kv_token(self, scope: str, challenge: HttpChallenge, **kwargs: Any) -> None:
250+
"""Implementation of BearerTokenCredentialPolicy's _request_token method, but specific to Key Vault.
251+
252+
:param str scope: The scope for which to request a token.
253+
:param challenge: The challenge for the request being made.
254+
"""
255+
# Exclude tenant for AD FS authentication
256+
exclude_tenant = challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs")
257+
# The SupportsTokenInfo protocol needs TokenRequestOptions for token requests instead of kwargs
258+
if hasattr(self._credential, "get_token_info"):
259+
options: TokenRequestOptions = {"enable_cae": True}
260+
if challenge.tenant_id and not exclude_tenant:
261+
options["tenant_id"] = challenge.tenant_id
262+
self._token = cast(SupportsTokenInfo, self._credential).get_token_info(scope, options=options)
263+
else:
264+
if exclude_tenant:
265+
self._token = self._credential.get_token(scope, enable_cae=True)
266+
else:
267+
self._token = cast(TokenCredential, self._credential).get_token(
268+
scope, tenant_id=challenge.tenant_id, enable_cae=True
269+
)

0 commit comments

Comments
 (0)