Skip to content

Commit 117a6f5

Browse files
authored
token refresh offset (#12136)
* token refresh offset
1 parent 8a41f87 commit 117a6f5

23 files changed

+290
-82
lines changed

sdk/identity/azure-identity/azure/identity/_authn_client.py

+18-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
UserAgentPolicy,
2323
)
2424
from azure.core.pipeline.transport import RequestsTransport, HttpRequest
25-
from ._constants import AZURE_CLI_CLIENT_ID
25+
from ._constants import AZURE_CLI_CLIENT_ID, DEFAULT_REFRESH_OFFSET, DEFAULT_TOKEN_REFRESH_RETRY_DELAY
2626
from ._internal import get_default_authority, normalize_authority
2727
from ._internal.user_agent import USER_AGENT
2828

@@ -65,17 +65,32 @@ def __init__(self, endpoint=None, authority=None, tenant=None, **kwargs): # pyl
6565
authority = normalize_authority(authority) if authority else get_default_authority()
6666
self._auth_url = "/".join((authority, tenant.strip("/"), "oauth2/v2.0/token"))
6767
self._cache = kwargs.get("cache") or TokenCache() # type: TokenCache
68+
self._token_refresh_retry_delay = DEFAULT_TOKEN_REFRESH_RETRY_DELAY
69+
self._token_refresh_offset = DEFAULT_REFRESH_OFFSET
70+
self._last_refresh_time = 0
6871

6972
@property
7073
def auth_url(self):
7174
return self._auth_url
7275

76+
def should_refresh(self, token):
77+
# type: (AccessToken) -> bool
78+
""" check if the token needs refresh or not
79+
"""
80+
expires_on = int(token.expires_on)
81+
now = int(time.time())
82+
if expires_on - now > self._token_refresh_offset:
83+
return False
84+
if now - self._last_refresh_time < self._token_refresh_retry_delay:
85+
return False
86+
return True
87+
7388
def get_cached_token(self, scopes):
7489
# type: (Iterable[str]) -> Optional[AccessToken]
7590
tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN, target=list(scopes))
7691
for token in tokens:
7792
expires_on = int(token["expires_on"])
78-
if expires_on - 300 > int(time.time()):
93+
if expires_on > int(time.time()):
7994
return AccessToken(token["secret"], expires_on)
8095
return None
8196

@@ -217,6 +232,7 @@ def request_token(
217232
# type: (...) -> AccessToken
218233
request = self._prepare_request(method, headers=headers, form_data=form_data, params=params)
219234
request_time = int(time.time())
235+
self._last_refresh_time = request_time # no matter succeed or not, update the last refresh time
220236
response = self._pipeline.run(request, stream=False, **kwargs)
221237
token = self._deserialize_and_cache_token(response=response, scopes=scopes, request_time=request_time)
222238
return token

sdk/identity/azure-identity/azure/identity/_constants.py

+2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
AZURE_CLI_CLIENT_ID = "04b07795-8ddb-461a-bbee-02f9e1bf7b46"
88
AZURE_VSCODE_CLIENT_ID = "aebc6443-996d-45c2-90f0-388ff96faa56"
99
VSCODE_CREDENTIALS_SECTION = "VS Code Azure"
10+
DEFAULT_REFRESH_OFFSET = 300
11+
DEFAULT_TOKEN_REFRESH_RETRY_DELAY = 30
1012

1113

1214
class KnownAuthorities:

sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,15 @@ def get_token(self, *scopes, **kwargs):
6464
self._authorization_code = None # auth codes are single-use
6565
return token
6666

67-
token = self._client.get_cached_access_token(scopes) or self._redeem_refresh_token(scopes, **kwargs)
67+
token = self._client.get_cached_access_token(scopes)
68+
if not token:
69+
token = self._redeem_refresh_token(scopes, **kwargs)
70+
elif self._client.should_refresh(token):
71+
try:
72+
self._redeem_refresh_token(scopes, **kwargs)
73+
except Exception: # pylint: disable=broad-except
74+
pass
75+
6876
if not token:
6977
raise ClientAuthenticationError(
7078
message="No authorization code, cached access token, or refresh token available."

sdk/identity/azure-identity/azure/identity/_credentials/certificate.py

+5
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
4848
token = self._client.get_cached_access_token(scopes, query={"client_id": self._client_id})
4949
if not token:
5050
token = self._client.obtain_token_by_client_certificate(scopes, self._certificate, **kwargs)
51+
elif self._client.should_refresh(token):
52+
try:
53+
self._client.obtain_token_by_client_certificate(scopes, self._certificate, **kwargs)
54+
except Exception: # pylint: disable=broad-except
55+
pass
5156
return token
5257

5358
def _get_auth_client(self, tenant_id, client_id, **kwargs):

sdk/identity/azure-identity/azure/identity/_credentials/client_secret.py

+5
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,11 @@ def get_token(self, *scopes, **kwargs):
4949
token = self._client.get_cached_access_token(scopes, query={"client_id": self._client_id})
5050
if not token:
5151
token = self._client.obtain_token_by_client_secret(scopes, self._secret, **kwargs)
52+
elif self._client.should_refresh(token):
53+
try:
54+
self._client.obtain_token_by_client_secret(scopes, self._secret, **kwargs)
55+
except Exception: # pylint: disable=broad-except
56+
pass
5257
return token
5358

5459
def _get_auth_client(self, tenant_id, client_id, **kwargs):

sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py

+48-30
Original file line numberDiff line numberDiff line change
@@ -170,28 +170,37 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
170170

171171
token = self._client.get_cached_token(scopes)
172172
if not token:
173-
resource = scopes[0]
174-
if resource.endswith("/.default"):
175-
resource = resource[: -len("/.default")]
176-
params = dict({"api-version": "2018-02-01", "resource": resource}, **self._identity_config)
177-
173+
token = self._refresh_token(*scopes)
174+
elif self._client.should_refresh(token):
178175
try:
179-
token = self._client.request_token(scopes, method="GET", params=params)
180-
except HttpResponseError as ex:
181-
# 400 in response to a token request indicates managed identity is disabled,
182-
# or the identity with the specified client_id is not available
183-
if ex.status_code == 400:
184-
self._endpoint_available = False
185-
message = "ManagedIdentityCredential authentication unavailable. "
186-
if self._identity_config:
187-
message += "The requested identity has not been assigned to this resource."
188-
else:
189-
message += "No identity has been assigned to this resource."
190-
six.raise_from(CredentialUnavailableError(message=message), ex)
191-
192-
# any other error is unexpected
193-
six.raise_from(ClientAuthenticationError(message=ex.message, response=ex.response), None)
176+
token = self._refresh_token(*scopes)
177+
except Exception: # pylint: disable=broad-except
178+
pass
179+
180+
return token
194181

182+
def _refresh_token(self, *scopes):
183+
resource = scopes[0]
184+
if resource.endswith("/.default"):
185+
resource = resource[: -len("/.default")]
186+
params = dict({"api-version": "2018-02-01", "resource": resource}, **self._identity_config)
187+
188+
try:
189+
token = self._client.request_token(scopes, method="GET", params=params)
190+
except HttpResponseError as ex:
191+
# 400 in response to a token request indicates managed identity is disabled,
192+
# or the identity with the specified client_id is not available
193+
if ex.status_code == 400:
194+
self._endpoint_available = False
195+
message = "ManagedIdentityCredential authentication unavailable. "
196+
if self._identity_config:
197+
message += "The requested identity has not been assigned to this resource."
198+
else:
199+
message += "No identity has been assigned to this resource."
200+
six.raise_from(CredentialUnavailableError(message=message), ex)
201+
202+
# any other error is unexpected
203+
six.raise_from(ClientAuthenticationError(message=ex.message, response=ex.response), None)
195204
return token
196205

197206

@@ -227,16 +236,25 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
227236

228237
token = self._client.get_cached_token(scopes)
229238
if not token:
230-
resource = scopes[0]
231-
if resource.endswith("/.default"):
232-
resource = resource[: -len("/.default")]
233-
secret = os.environ.get(EnvironmentVariables.MSI_SECRET)
234-
if secret:
235-
# MSI_ENDPOINT and MSI_SECRET set -> App Service
236-
token = self._request_app_service_token(scopes=scopes, resource=resource, secret=secret)
237-
else:
238-
# only MSI_ENDPOINT set -> legacy-style MSI (Cloud Shell)
239-
token = self._request_legacy_token(scopes=scopes, resource=resource)
239+
token = self._refresh_token(*scopes)
240+
elif self._client.should_refresh(token):
241+
try:
242+
token = self._refresh_token(*scopes)
243+
except Exception: # pylint: disable=broad-except
244+
pass
245+
return token
246+
247+
def _refresh_token(self, *scopes):
248+
resource = scopes[0]
249+
if resource.endswith("/.default"):
250+
resource = resource[: -len("/.default")]
251+
secret = os.environ.get(EnvironmentVariables.MSI_SECRET)
252+
if secret:
253+
# MSI_ENDPOINT and MSI_SECRET set -> App Service
254+
token = self._request_app_service_token(scopes=scopes, resource=resource, secret=secret)
255+
else:
256+
# only MSI_ENDPOINT set -> legacy-style MSI (Cloud Shell)
257+
token = self._request_legacy_token(scopes=scopes, resource=resource)
240258
return token
241259

242260
def _request_app_service_token(self, scopes, resource, secret):

sdk/identity/azure-identity/azure/identity/_credentials/vscode_credential.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@
99
from .._internal.aad_client import AadClient
1010

1111
if sys.platform.startswith("win"):
12-
from .win_vscode_adapter import get_credentials
12+
from .._internal.win_vscode_adapter import get_credentials
1313
elif sys.platform.startswith("darwin"):
14-
from .macos_vscode_adapter import get_credentials
14+
from .._internal.macos_vscode_adapter import get_credentials
1515
else:
16-
from .linux_vscode_adapter import get_credentials
16+
from .._internal.linux_vscode_adapter import get_credentials
1717

1818
if TYPE_CHECKING:
1919
# pylint:disable=unused-import,ungrouped-imports
@@ -47,9 +47,17 @@ def get_token(self, *scopes, **kwargs):
4747

4848
token = self._client.get_cached_access_token(scopes)
4949

50-
if token:
51-
return token
50+
if not token:
51+
token = self._redeem_refresh_token(scopes, **kwargs)
52+
elif self._client.should_refresh(token):
53+
try:
54+
self._redeem_refresh_token(scopes, **kwargs)
55+
except Exception: # pylint: disable=broad-except
56+
pass
57+
return token
5258

59+
def _redeem_refresh_token(self, scopes, **kwargs):
60+
# type: (Sequence[str], **Any) -> Optional[AccessToken]
5361
if not self._refresh_token:
5462
self._refresh_token = get_credentials()
5563
if not self._refresh_token:

sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from azure.core.credentials import AccessToken
1717
from azure.core.exceptions import ClientAuthenticationError
1818
from . import get_default_authority, normalize_authority
19+
from .._constants import DEFAULT_TOKEN_REFRESH_RETRY_DELAY, DEFAULT_REFRESH_OFFSET
1920

2021
try:
2122
from typing import TYPE_CHECKING
@@ -48,13 +49,16 @@ def __init__(self, tenant_id, client_id, authority=None, cache=None, **kwargs):
4849
self._cache = cache or TokenCache()
4950
self._client_id = client_id
5051
self._pipeline = self._build_pipeline(**kwargs)
52+
self._token_refresh_retry_delay = DEFAULT_TOKEN_REFRESH_RETRY_DELAY
53+
self._token_refresh_offset = DEFAULT_REFRESH_OFFSET
54+
self._last_refresh_time = 0
5155

5256
def get_cached_access_token(self, scopes, query=None):
5357
# type: (Iterable[str], Optional[dict]) -> Optional[AccessToken]
5458
tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN, target=list(scopes), query=query)
5559
for token in tokens:
5660
expires_on = int(token["expires_on"])
57-
if expires_on - 300 > int(time.time()):
61+
if expires_on > int(time.time()):
5862
return AccessToken(token["secret"], expires_on)
5963
return None
6064

@@ -63,6 +67,19 @@ def get_cached_refresh_tokens(self, scopes):
6367
"""Assumes all cached refresh tokens belong to the same user"""
6468
return self._cache.find(TokenCache.CredentialType.REFRESH_TOKEN, target=list(scopes))
6569

70+
def should_refresh(self, token):
71+
# type: (AccessToken) -> bool
72+
""" check if the token needs refresh or not
73+
"""
74+
expires_on = int(token.expires_on)
75+
now = int(time.time())
76+
if expires_on - now > self._token_refresh_offset:
77+
return False
78+
if now - self._last_refresh_time < self._token_refresh_retry_delay:
79+
return False
80+
return True
81+
82+
6683
@abc.abstractmethod
6784
def obtain_token_by_authorization_code(self, scopes, code, redirect_uri, client_secret=None, **kwargs):
6885
pass
@@ -85,6 +102,7 @@ def _build_pipeline(self, config=None, policies=None, transport=None, **kwargs):
85102

86103
def _process_response(self, response, request_time):
87104
# type: (PipelineResponse, int) -> AccessToken
105+
self._last_refresh_time = request_time # no matter succeed or not, update the last refresh time
88106

89107
content = ContentDecodePolicy.deserialize_from_http_generics(response.http_response)
90108

sdk/identity/azure-identity/azure/identity/aio/_authn_client.py

+1
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ async def request_token( # pylint:disable=invalid-overridden-method
7575
) -> AccessToken:
7676
request = self._prepare_request(method, headers=headers, form_data=form_data, params=params)
7777
request_time = int(time.time())
78+
self._last_refresh_time = request_time # no matter succeed or not, update the last refresh time
7879
response = await self._pipeline.run(request, stream=False, **kwargs)
7980
token = self._deserialize_and_cache_token(response=response, scopes=scopes, request_time=request_time)
8081
return token

sdk/identity/azure-identity/azure/identity/aio/_credentials/authorization_code.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,11 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken":
8080
token = self._client.get_cached_access_token(scopes)
8181
if not token:
8282
token = await self._redeem_refresh_token(scopes, **kwargs)
83-
83+
elif self._client.should_refresh(token):
84+
try:
85+
await self._redeem_refresh_token(scopes, **kwargs)
86+
except Exception: # pylint: disable=broad-except
87+
pass
8488
if not token:
8589
raise ClientAuthenticationError(
8690
message="No authorization code, cached access token, or refresh token available."

sdk/identity/azure-identity/azure/identity/aio/_credentials/certificate.py

+5
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,11 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken":
5454
token = self._client.get_cached_access_token(scopes, query={"client_id": self._client_id})
5555
if not token:
5656
token = await self._client.obtain_token_by_client_certificate(scopes, self._certificate, **kwargs)
57+
elif self._client.should_refresh(token):
58+
try:
59+
await self._client.obtain_token_by_client_certificate(scopes, self._certificate, **kwargs)
60+
except Exception: # pylint: disable=broad-except
61+
pass
5762
return token
5863

5964
def _get_auth_client(self, tenant_id, client_id, **kwargs):

sdk/identity/azure-identity/azure/identity/aio/_credentials/client_secret.py

+5
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,11 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken":
5555
token = self._client.get_cached_access_token(scopes, query={"client_id": self._client_id})
5656
if not token:
5757
token = await self._client.obtain_token_by_client_secret(scopes, self._secret, **kwargs)
58+
elif self._client.should_refresh(token):
59+
try:
60+
await self._client.obtain_token_by_client_secret(scopes, self._secret, **kwargs)
61+
except Exception: # pylint: disable=broad-except
62+
pass
5863
return token
5964

6065
def _get_auth_client(self, tenant_id, client_id, **kwargs):

0 commit comments

Comments
 (0)