From b50006de1f5d5b23bd0f5328b36639b87fac2b7d Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Fri, 9 May 2025 00:34:37 +0000 Subject: [PATCH 1/5] Remove AccessToken::is_expired() --- sdk/core/azure_core/src/credentials.rs | 9 +--- .../src/http/policies/bearer_token_policy.rs | 50 +++++++++++++++---- .../azure_identity/src/credentials/cache.rs | 12 +++-- 3 files changed, 50 insertions(+), 21 deletions(-) diff --git a/sdk/core/azure_core/src/credentials.rs b/sdk/core/azure_core/src/credentials.rs index e34ae3b4a8..b1532fac6b 100644 --- a/sdk/core/azure_core/src/credentials.rs +++ b/sdk/core/azure_core/src/credentials.rs @@ -4,7 +4,7 @@ //! Azure authentication and authorization. use serde::{Deserialize, Serialize}; -use std::{borrow::Cow, fmt::Debug, time::Duration}; +use std::{borrow::Cow, fmt::Debug}; use typespec_client_core::date::OffsetDateTime; /// Default Azure authorization scope. @@ -85,13 +85,6 @@ impl AccessToken { expires_on, } } - - /// Check if the token is expired within a given duration. - /// - /// If no duration is provided, then the default duration of 30 seconds is used. - pub fn is_expired(&self, window: Option) -> bool { - self.expires_on < OffsetDateTime::now_utc() + window.unwrap_or(Duration::from_secs(30)) - } } /// Represents a credential capable of providing an OAuth token. diff --git a/sdk/core/azure_core/src/http/policies/bearer_token_policy.rs b/sdk/core/azure_core/src/http/policies/bearer_token_policy.rs index eed1551cf7..9ff1511b09 100644 --- a/sdk/core/azure_core/src/http/policies/bearer_token_policy.rs +++ b/sdk/core/azure_core/src/http/policies/bearer_token_policy.rs @@ -13,6 +13,7 @@ use async_lock::RwLock; use async_trait::async_trait; use std::sync::Arc; use std::time::Duration; +use typespec_client_core::date::OffsetDateTime; use typespec_client_core::http::{Context, Request}; /// Authentication policy for a bearer token. @@ -23,9 +24,6 @@ pub struct BearerTokenCredentialPolicy { access_token: Arc>>, } -/// Default timeout in seconds before refreshing a new token. -const DEFAULT_REFRESH_TIME: Duration = Duration::from_secs(120); - impl BearerTokenCredentialPolicy { pub fn new(credential: Arc, scopes: A) -> Self where @@ -63,16 +61,44 @@ impl Policy for BearerTokenCredentialPolicy { ) -> PolicyResult { let access_token = self.access_token.read().await; - if let Some(token) = &(*access_token) { - if token.is_expired(Some(DEFAULT_REFRESH_TIME)) { + match access_token.as_ref() { + None => { + // cache is empty. Upgrade the lock and acquire a token, provided another thread hasn't already done so + drop(access_token); + let mut access_token = self.access_token.write().await; + if access_token.is_none() { + *access_token = Some(self.credential.get_token(&self.scopes()).await?); + } + } + Some(token) if should_refresh(&token.expires_on) => { + // token is expired or within its refresh window. Upgrade the lock and + // acquire a new token, provided another thread hasn't already done so + let expires_on = token.expires_on; drop(access_token); let mut access_token = self.access_token.write().await; - *access_token = Some(self.credential.get_token(&self.scopes()).await?); + // access_token shouldn't be None here, but check anyway to guarantee unwrap won't panic + if access_token.is_none() || access_token.as_ref().unwrap().expires_on == expires_on + { + match self.credential.get_token(&self.scopes()).await { + Ok(new_token) => { + *access_token = Some(new_token); + } + Err(e) + if access_token.is_none() + || expires_on <= OffsetDateTime::now_utc() => + { + // propagate this error because we can't proceed without a new token + return Err(e); + } + Err(_) => { + // ignore this error because the cached token is still valid + } + } + } + } + Some(_) => { + // do nothing; cached token is valid and not within its refresh window } - } else { - drop(access_token); - let mut access_token = self.access_token.write().await; - *access_token = Some(self.credential.get_token(&self.scopes()).await?); } let access_token = self.access_token().await.ok_or_else(|| { @@ -86,3 +112,7 @@ impl Policy for BearerTokenCredentialPolicy { next[0].send(ctx, request, &next[1..]).await } } + +fn should_refresh(expires_on: &OffsetDateTime) -> bool { + *expires_on <= OffsetDateTime::now_utc() + Duration::from_secs(300) +} diff --git a/sdk/identity/azure_identity/src/credentials/cache.rs b/sdk/identity/azure_identity/src/credentials/cache.rs index e18c96adfd..63ba14c5f4 100644 --- a/sdk/identity/azure_identity/src/credentials/cache.rs +++ b/sdk/identity/azure_identity/src/credentials/cache.rs @@ -5,7 +5,9 @@ use async_lock::RwLock; use azure_core::credentials::AccessToken; use futures::Future; use std::collections::HashMap; +use std::time::Duration; use tracing::trace; +use typespec_client_core::date::OffsetDateTime; #[derive(Debug)] pub(crate) struct TokenCache(RwLock, AccessToken>>); @@ -24,7 +26,7 @@ impl TokenCache { let token_cache = self.0.read().await; let scopes = scopes.iter().map(ToString::to_string).collect::>(); if let Some(token) = token_cache.get(&scopes) { - if !token.is_expired(None) { + if !should_refresh(token) { trace!("returning cached token"); return Ok(token.clone()); } @@ -37,7 +39,7 @@ impl TokenCache { // check again in case another thread refreshed the token while we were // waiting on the write lock if let Some(token) = token_cache.get(&scopes) { - if !token.is_expired(None) { + if !should_refresh(token) { trace!("returning token that was updated while waiting on write lock"); return Ok(token.clone()); } @@ -61,6 +63,10 @@ impl Default for TokenCache { } } +fn should_refresh(token: &AccessToken) -> bool { + token.expires_on <= OffsetDateTime::now_utc() + Duration::from_secs(300) +} + #[cfg(test)] mod tests { use super::*; @@ -106,7 +112,7 @@ mod tests { let resource1 = &[STORAGE_TOKEN_SCOPE]; let resource2 = &[IOTHUB_TOKEN_SCOPE]; let secret_string = "test-token"; - let expires_on = OffsetDateTime::now_utc() + Duration::from_secs(300); + let expires_on = OffsetDateTime::now_utc() + Duration::from_secs(3600); let access_token = AccessToken::new(Secret::new(secret_string), expires_on); let mock_credential = MockCredential::new(access_token); From 9e17d7eeed8dd1821ffcf17e97c44384d3420981 Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Fri, 9 May 2025 00:34:43 +0000 Subject: [PATCH 2/5] changelog --- sdk/core/azure_core/CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sdk/core/azure_core/CHANGELOG.md b/sdk/core/azure_core/CHANGELOG.md index 120b0fd10d..8a896c3afe 100644 --- a/sdk/core/azure_core/CHANGELOG.md +++ b/sdk/core/azure_core/CHANGELOG.md @@ -6,8 +6,12 @@ ### Breaking Changes +- Removed `AccessToken::is_expired()` + ### Bugs Fixed +- `BearerTokenCredentialPolicy` returns an error when a proactive token refresh attempt fails + ### Other Changes ## 0.24.0 (2025-05-02) From 7774e3fb85d6bbad0cbdca23f070763cd058cfa5 Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Wed, 14 May 2025 17:30:33 +0000 Subject: [PATCH 3/5] tests --- .../src/http/policies/bearer_token_policy.rs | 148 ++++++++++++++++++ 1 file changed, 148 insertions(+) diff --git a/sdk/core/azure_core/src/http/policies/bearer_token_policy.rs b/sdk/core/azure_core/src/http/policies/bearer_token_policy.rs index 9ff1511b09..ea3334f14b 100644 --- a/sdk/core/azure_core/src/http/policies/bearer_token_policy.rs +++ b/sdk/core/azure_core/src/http/policies/bearer_token_policy.rs @@ -116,3 +116,151 @@ impl Policy for BearerTokenCredentialPolicy { fn should_refresh(expires_on: &OffsetDateTime) -> bool { *expires_on <= OffsetDateTime::now_utc() + Duration::from_secs(300) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + credentials::{Secret, TokenCredential}, + http::{ + headers::{Headers, AUTHORIZATION}, + policies::Policy, + Request, Response, StatusCode, + }, + Bytes, Result, + }; + use async_trait::async_trait; + use azure_core_test::http::MockHttpClient; + use futures::FutureExt; + use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }; + use std::time::Duration; + use time::OffsetDateTime; + use typespec_client_core::http::{policies::TransportPolicy, Method, TransportOptions}; + + #[derive(Debug, Clone)] + struct MockCredential { + calls: Arc, + tokens: Arc<[AccessToken]>, + } + + impl MockCredential { + fn new(tokens: &[AccessToken]) -> Self { + Self { + tokens: tokens.into(), + calls: Arc::new(AtomicUsize::new(0)), + } + } + + fn get_token_calls(&self) -> usize { + self.calls.load(Ordering::SeqCst).saturating_sub(1) + } + } + + #[cfg_attr(target_arch = "wasm32", async_trait(?Send))] + #[cfg_attr(not(target_arch = "wasm32"), async_trait)] + impl TokenCredential for MockCredential { + async fn get_token(&self, _scopes: &[&str]) -> Result { + let i = self.calls.fetch_add(1, Ordering::SeqCst); + self.tokens + .get(i) + .ok_or_else(|| Error::message(ErrorKind::Credential, "no more mock tokens")) + .cloned() + } + } + + #[tokio::test] + async fn authn_error() { + // this mock's get_token() will return an error because it has no tokens + let credential = MockCredential::new(&[]); + let policy = BearerTokenCredentialPolicy::new(Arc::new(credential), ["scope"]); + let client = MockHttpClient::new(|_| panic!("expected an error from get_token")); + let transport = Arc::new(TransportPolicy::new(TransportOptions::new(Arc::new( + client, + )))); + let mut req = Request::new("https://localhost".parse().unwrap(), Method::Get); + + let err = policy + .send(&Context::default(), &mut req, &[transport.clone()]) + .await + .expect_err("request should fail"); + + assert_eq!(ErrorKind::Credential, *err.kind()); + } + + #[tokio::test] + async fn caches_token() { + let credential = MockCredential::new(&[AccessToken { + token: Secret::new("fake".to_string()), + expires_on: OffsetDateTime::now_utc() + Duration::from_secs(3600), + }]); + let policy = BearerTokenCredentialPolicy::new(Arc::new(credential), ["scope"]); + let client = Arc::new(MockHttpClient::new(|actual| { + async move { + let authz = actual.headers().get_str(&AUTHORIZATION)?; + + assert_eq!("Bearer fake", authz); + + Ok(Response::from_bytes( + StatusCode::Ok, + Headers::new(), + Bytes::new(), + )) + } + .boxed() + })); + let transport = Arc::new(TransportPolicy::new(TransportOptions::new(client))); + + for _ in 0..3 { + let ctx = Context::default(); + let mut req = Request::new("https://localhost".parse().unwrap(), Method::Get); + policy + .send(&ctx, &mut req, &[transport.clone()]) + .await + .unwrap(); + } + } + + #[tokio::test] + async fn refreshes_token() { + let credential = Arc::new(MockCredential::new(&[ + AccessToken { + token: Secret::new("1".to_string()), + expires_on: OffsetDateTime::now_utc() - Duration::from_secs(1), + }, + AccessToken { + token: Secret::new("2".to_string()), + expires_on: OffsetDateTime::now_utc() + Duration::from_secs(3600), + }, + ])); + let policy = BearerTokenCredentialPolicy::new(credential.clone(), ["scope"]); + let client = Arc::new(MockHttpClient::new(move |actual| { + let credential = credential.clone(); + async move { + let authz = actual.headers().get_str(&AUTHORIZATION)?; + + let expected = &credential.tokens[credential.get_token_calls()]; + assert_eq!(format!("Bearer {}", expected.token.secret()), authz); + + Ok(Response::from_bytes( + StatusCode::Ok, + Headers::new(), + Bytes::new(), + )) + } + .boxed() + })); + let transport = Arc::new(TransportPolicy::new(TransportOptions::new(client))); + + for _ in 0..3 { + let ctx = Context::default(); + let mut req = Request::new("https://localhost".parse().unwrap(), Method::Get); + policy + .send(&ctx, &mut req, &[transport.clone()]) + .await + .unwrap(); + } + } +} From 7fbb979d7acad56fd3cf6989c986ce363baeda58 Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Fri, 16 May 2025 16:00:47 +0000 Subject: [PATCH 4/5] test concurrency --- .../src/http/policies/bearer_token_policy.rs | 98 ++++++++++--------- 1 file changed, 52 insertions(+), 46 deletions(-) diff --git a/sdk/core/azure_core/src/http/policies/bearer_token_policy.rs b/sdk/core/azure_core/src/http/policies/bearer_token_policy.rs index ea3334f14b..6af1999c34 100644 --- a/sdk/core/azure_core/src/http/policies/bearer_token_policy.rs +++ b/sdk/core/azure_core/src/http/policies/bearer_token_policy.rs @@ -149,13 +149,23 @@ mod tests { impl MockCredential { fn new(tokens: &[AccessToken]) -> Self { Self { - tokens: tokens.into(), calls: Arc::new(AtomicUsize::new(0)), + tokens: tokens.into(), } } fn get_token_calls(&self) -> usize { - self.calls.load(Ordering::SeqCst).saturating_sub(1) + self.calls.load(Ordering::SeqCst) + } + } + + // ensure the number of get_token() calls matches the number of tokens + // in a test case i.e., that the policy called get_token() as expected + impl Drop for MockCredential { + fn drop(&mut self) { + if self.tokens.len() > 0 { + assert_eq!(self.calls.load(Ordering::SeqCst), self.tokens.len()); + } } } @@ -190,18 +200,18 @@ mod tests { assert_eq!(ErrorKind::Credential, *err.kind()); } - #[tokio::test] - async fn caches_token() { - let credential = MockCredential::new(&[AccessToken { - token: Secret::new("fake".to_string()), - expires_on: OffsetDateTime::now_utc() + Duration::from_secs(3600), - }]); - let policy = BearerTokenCredentialPolicy::new(Arc::new(credential), ["scope"]); - let client = Arc::new(MockHttpClient::new(|actual| { + async fn run_test(tokens: &[AccessToken]) { + let credential = Arc::new(MockCredential::new(tokens)); + let policy = BearerTokenCredentialPolicy::new(credential.clone(), ["scope"]); + let client = Arc::new(MockHttpClient::new(move |actual| { + let credential = credential.clone(); async move { let authz = actual.headers().get_str(&AUTHORIZATION)?; + // e.g. if this is the first request, we expect 1 get_token call and tokens[0] in the header + let i = credential.get_token_calls().saturating_sub(1); + let expected = &credential.tokens[i]; - assert_eq!("Bearer fake", authz); + assert_eq!(format!("Bearer {}", expected.token.secret()), authz); Ok(Response::from_bytes( StatusCode::Ok, @@ -213,19 +223,41 @@ mod tests { })); let transport = Arc::new(TransportPolicy::new(TransportOptions::new(client))); - for _ in 0..3 { - let ctx = Context::default(); - let mut req = Request::new("https://localhost".parse().unwrap(), Method::Get); - policy - .send(&ctx, &mut req, &[transport.clone()]) + let mut handles = vec![]; + for _ in 0..4 { + let policy = policy.clone(); + let transport = transport.clone(); + let handle = tokio::spawn(async move { + let ctx = Context::default(); + let mut req = Request::new("https://localhost".parse().unwrap(), Method::Get); + policy + .send(&ctx, &mut req, &[transport.clone()]) + .await + .expect("successful request"); + }); + handles.push(handle); + } + + for handle in handles { + tokio::time::timeout(Duration::from_secs(2), handle) .await - .unwrap(); + .expect("task timed out after 2 seconds") + .expect("completed task"); } } + #[tokio::test] + async fn caches_token() { + run_test(&[AccessToken { + token: Secret::new("fake".to_string()), + expires_on: OffsetDateTime::now_utc() + Duration::from_secs(3600), + }]) + .await; + } + #[tokio::test] async fn refreshes_token() { - let credential = Arc::new(MockCredential::new(&[ + run_test(&[ AccessToken { token: Secret::new("1".to_string()), expires_on: OffsetDateTime::now_utc() - Duration::from_secs(1), @@ -234,33 +266,7 @@ mod tests { token: Secret::new("2".to_string()), expires_on: OffsetDateTime::now_utc() + Duration::from_secs(3600), }, - ])); - let policy = BearerTokenCredentialPolicy::new(credential.clone(), ["scope"]); - let client = Arc::new(MockHttpClient::new(move |actual| { - let credential = credential.clone(); - async move { - let authz = actual.headers().get_str(&AUTHORIZATION)?; - - let expected = &credential.tokens[credential.get_token_calls()]; - assert_eq!(format!("Bearer {}", expected.token.secret()), authz); - - Ok(Response::from_bytes( - StatusCode::Ok, - Headers::new(), - Bytes::new(), - )) - } - .boxed() - })); - let transport = Arc::new(TransportPolicy::new(TransportOptions::new(client))); - - for _ in 0..3 { - let ctx = Context::default(); - let mut req = Request::new("https://localhost".parse().unwrap(), Method::Get); - policy - .send(&ctx, &mut req, &[transport.clone()]) - .await - .unwrap(); - } + ]) + .await; } } From eb1bc3b32c73b20ef45d14fca5e9c9d9448392a6 Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Fri, 16 May 2025 16:37:54 +0000 Subject: [PATCH 5/5] =?UTF-8?q?=F0=9F=93=8E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sdk/core/azure_core/src/http/policies/bearer_token_policy.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sdk/core/azure_core/src/http/policies/bearer_token_policy.rs b/sdk/core/azure_core/src/http/policies/bearer_token_policy.rs index 6af1999c34..36845ee13f 100644 --- a/sdk/core/azure_core/src/http/policies/bearer_token_policy.rs +++ b/sdk/core/azure_core/src/http/policies/bearer_token_policy.rs @@ -163,8 +163,8 @@ mod tests { // in a test case i.e., that the policy called get_token() as expected impl Drop for MockCredential { fn drop(&mut self) { - if self.tokens.len() > 0 { - assert_eq!(self.calls.load(Ordering::SeqCst), self.tokens.len()); + if !self.tokens.is_empty() { + assert_eq!(self.tokens.len(), self.calls.load(Ordering::SeqCst)); } } }