Skip to content

WebClient support should get new access token when expired and client_credentials #6127

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ private Mono<OAuth2AuthorizedClient> authorizedClientNotLoaded(String clientRegi
});
}

private Mono<? extends OAuth2AuthorizedClient> clientCredentials(
Mono<OAuth2AuthorizedClient> clientCredentials(
ClientRegistration clientRegistration, Authentication authentication, ServerWebExchange exchange) {
OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration);
return this.clientCredentialsTokenResponseClient.getTokenResponse(grantRequest)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,12 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
private final OAuth2AuthorizedClientResolver authorizedClientResolver;

public ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveClientRegistrationRepository clientRegistrationRepository, ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
this(authorizedClientRepository, new OAuth2AuthorizedClientResolver(clientRegistrationRepository, authorizedClientRepository));
}

ServerOAuth2AuthorizedClientExchangeFilterFunction(ServerOAuth2AuthorizedClientRepository authorizedClientRepository, OAuth2AuthorizedClientResolver authorizedClientResolver) {
this.authorizedClientRepository = authorizedClientRepository;
this.authorizedClientResolver = new OAuth2AuthorizedClientResolver(clientRegistrationRepository, authorizedClientRepository);
this.authorizedClientResolver = authorizedClientResolver;
}

/**
Expand Down Expand Up @@ -245,13 +249,30 @@ private Mono<OAuth2AuthorizedClientResolver.Request> createRequest(ClientRequest
}

private Mono<OAuth2AuthorizedClient> refreshIfNecessary(ClientRequest request, ExchangeFunction next, OAuth2AuthorizedClient authorizedClient) {
if (shouldRefresh(authorizedClient)) {
ClientRegistration clientRegistration = authorizedClient.getClientRegistration();
if (isClientCredentialsGrantType(clientRegistration) && hasTokenExpired(authorizedClient)) {
return createRequest(request)
.flatMap(r -> authorizeWithClientCredentials(clientRegistration, r));
} else if (shouldRefresh(authorizedClient)) {
return createRequest(request)
.flatMap(r -> refreshAuthorizedClient(next, authorizedClient, r));
}
return Mono.just(authorizedClient);
}

private boolean isClientCredentialsGrantType(ClientRegistration clientRegistration) {
return AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType());
}

private Mono<OAuth2AuthorizedClient> authorizeWithClientCredentials(ClientRegistration clientRegistration, OAuth2AuthorizedClientResolver.Request request) {
Authentication authentication = request.getAuthentication();
ServerWebExchange exchange = request.getExchange();

return this.authorizedClientResolver.clientCredentials(clientRegistration, authentication, exchange).
flatMap(result -> this.authorizedClientRepository.saveAuthorizedClient(result, authentication, exchange)
.thenReturn(result));
}

private Mono<OAuth2AuthorizedClient> refreshAuthorizedClient(ExchangeFunction next,
OAuth2AuthorizedClient authorizedClient, OAuth2AuthorizedClientResolver.Request r) {
ServerWebExchange exchange = r.getExchange();
Expand Down Expand Up @@ -280,6 +301,10 @@ private boolean shouldRefresh(OAuth2AuthorizedClient authorizedClient) {
if (refreshToken == null) {
return false;
}
return hasTokenExpired(authorizedClient);
}

private boolean hasTokenExpired(OAuth2AuthorizedClient authorizedClient) {
Instant now = this.clock.instant();
Instant expiresAt = authorizedClient.getAccessToken().getExpiresAt();
if (now.isAfter(expiresAt.minus(this.accessTokenExpiresSkew))) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -332,12 +332,16 @@ private OAuth2AuthorizedClient getAuthorizedClient(String clientRegistrationId,
if (clientRegistration == null) {
throw new IllegalArgumentException("Could not find ClientRegistration with id " + clientRegistrationId);
}
if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) {
if (isClientCredentialsGrantType(clientRegistration)) {
return getAuthorizedClient(clientRegistration, attrs);
}
throw new ClientAuthorizationRequiredException(clientRegistrationId);
}

private boolean isClientCredentialsGrantType(ClientRegistration clientRegistration) {
return AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType());
}


private OAuth2AuthorizedClient getAuthorizedClient(ClientRegistration clientRegistration,
Map<String, Object> attrs) {
Expand Down Expand Up @@ -366,7 +370,11 @@ private OAuth2AuthorizedClient getAuthorizedClient(ClientRegistration clientRegi
}

private Mono<OAuth2AuthorizedClient> authorizedClient(ClientRequest request, ExchangeFunction next, OAuth2AuthorizedClient authorizedClient) {
if (shouldRefresh(authorizedClient)) {
ClientRegistration clientRegistration = authorizedClient.getClientRegistration();
if (isClientCredentialsGrantType(clientRegistration) && hasTokenExpired(authorizedClient)) {
//Client credentials grant do not have refresh tokens but can expire so we need to get another one
return Mono.fromSupplier(() -> getAuthorizedClient(clientRegistration, request.attributes()));
} else if (shouldRefresh(authorizedClient)) {
return refreshAuthorizedClient(request, next, authorizedClient);
}
return Mono.just(authorizedClient);
Expand Down Expand Up @@ -407,6 +415,10 @@ private boolean shouldRefresh(OAuth2AuthorizedClient authorizedClient) {
if (refreshToken == null) {
return false;
}
return hasTokenExpired(authorizedClient);
}

private boolean hasTokenExpired(OAuth2AuthorizedClient authorizedClient) {
Instant now = this.clock.instant();
Instant expiresAt = authorizedClient.getAccessToken().getExpiresAt();
if (now.isAfter(expiresAt.minus(this.accessTokenExpiresSkew))) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.client.web.reactive.function.client.OAuth2AuthorizedClientResolver.Request;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
Expand All @@ -67,6 +68,7 @@
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
Expand All @@ -86,6 +88,9 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Mock
private ReactiveClientRegistrationRepository clientRegistrationRepository;

@Mock
private OAuth2AuthorizedClientResolver oAuth2AuthorizedClientResolver;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's rename to authorizedClientResolver

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Look's like this one change was missed


@Mock
private ServerWebExchange serverWebExchange;

Expand Down Expand Up @@ -144,6 +149,88 @@ public void filterWhenExistingAuthorizationThenSingleAuthorizationHeader() {
assertThat(headers.get(HttpHeaders.AUTHORIZATION)).containsOnly("Bearer " + this.accessToken.getTokenValue());
}

@Test
public void filterWhenClientCredentialsTokenExpiredThenGetNewToken() {
TestingAuthenticationToken authentication = new TestingAuthenticationToken("test", "this");
ClientRegistration registration = TestClientRegistrations.clientCredentials().build();
String clientRegistrationId = registration.getClientId();

this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository, this.oAuth2AuthorizedClientResolver);

OAuth2AccessToken newAccessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
"new-token",
Instant.now(),
Instant.now().plus(Duration.ofDays(1)));
OAuth2AuthorizedClient newAuthorizedClient = new OAuth2AuthorizedClient(registration,
"principalName", newAccessToken, null);
Request r = new Request(clientRegistrationId, authentication, null);
when(this.oAuth2AuthorizedClientResolver.clientCredentials(any(), any(), any())).thenReturn(Mono.just(newAuthorizedClient));
when(this.oAuth2AuthorizedClientResolver.createDefaultedRequest(any(), any(), any())).thenReturn(Mono.just(r));

when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this needed?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, otherwise I end up with a null pointer in the authorizedWithClientCredentials method line 273


Instant issuedAt = Instant.now().minus(Duration.ofDays(1));
Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1));

OAuth2AccessToken accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(),
this.accessToken.getTokenValue(),
issuedAt,
accessTokenExpiresAt);


OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(registration,
"principalName", accessToken, null);
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
.attributes(oauth2AuthorizedClient(authorizedClient))
.build();


this.function.filter(request, this.exchange)
.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication))
.block();

verify(this.authorizedClientRepository).saveAuthorizedClient(any(), eq(authentication), any());
verify(this.oAuth2AuthorizedClientResolver).clientCredentials(any(), any(), any());
verify(this.oAuth2AuthorizedClientResolver).createDefaultedRequest(any(), any(), any());

List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(1);
ClientRequest request1 = requests.get(0);
assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer new-token");
assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com");
assertThat(request1.method()).isEqualTo(HttpMethod.GET);
assertThat(getBody(request1)).isEmpty();
}

@Test
public void filterWhenClientCredentialsTokenNotExpiredThenUseCurrentToken() {
TestingAuthenticationToken authentication = new TestingAuthenticationToken("test", "this");
ClientRegistration registration = TestClientRegistrations.clientCredentials().build();

this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository, this.oAuth2AuthorizedClientResolver);

OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(registration,
"principalName", this.accessToken, null);
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
.attributes(oauth2AuthorizedClient(authorizedClient))
.build();

this.function.filter(request, this.exchange)
.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication))
.block();

verify(this.oAuth2AuthorizedClientResolver, never()).clientCredentials(any(), any(), any());
verify(this.oAuth2AuthorizedClientResolver, never()).createDefaultedRequest(any(), any(), any());

List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(1);
ClientRequest request1 = requests.get(0);
assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0");
assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com");
assertThat(request1.method()).isEqualTo(HttpMethod.GET);
assertThat(getBody(request1)).isEmpty();
}

@Test
public void filterWhenRefreshRequiredThenRefresh() {
when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
Expand Down Expand Up @@ -423,6 +424,80 @@ public void filterWhenRefreshRequiredThenRefresh() {
assertThat(getBody(request1)).isEmpty();
}

@Test
public void filterWhenClientCredentialsTokenNotExpiredThenUseCurrentToken() {
this.registration = TestClientRegistrations.clientCredentials().build();

this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
this.authorizedClientRepository);
this.function.setClientCredentialsTokenResponseClient(this.clientCredentialsTokenResponseClient);

OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
"principalName", this.accessToken, null);
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
.attributes(oauth2AuthorizedClient(authorizedClient))
.attributes(authentication(this.authentication))
.build();

this.function.filter(request, this.exchange).block();

verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), eq(this.authentication), any(), any());

verify(clientCredentialsTokenResponseClient, never()).getTokenResponse(any());

List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(1);

ClientRequest request1 = requests.get(0);
assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0");
assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com");
assertThat(request1.method()).isEqualTo(HttpMethod.GET);
assertThat(getBody(request1)).isEmpty();
}

@Test
public void filterWhenClientCredentialsTokenExpiredThenGetNewToken() {
this.registration = TestClientRegistrations.clientCredentials().build();

OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses
.accessTokenResponse().build();
when(this.clientCredentialsTokenResponseClient.getTokenResponse(any())).thenReturn(
accessTokenResponse);

Instant issuedAt = Instant.now().minus(Duration.ofDays(1));
Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1));

this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(),
this.accessToken.getTokenValue(),
issuedAt,
accessTokenExpiresAt);
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
this.authorizedClientRepository);
this.function.setClientCredentialsTokenResponseClient(this.clientCredentialsTokenResponseClient);

OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
"principalName", this.accessToken, null);
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
.attributes(oauth2AuthorizedClient(authorizedClient))
.attributes(authentication(this.authentication))
.build();

this.function.filter(request, this.exchange).block();

verify(this.authorizedClientRepository).saveAuthorizedClient(any(), eq(this.authentication), any(), any());

verify(clientCredentialsTokenResponseClient).getTokenResponse(any());

List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(1);

ClientRequest request1 = requests.get(0);
assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token");
assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com");
assertThat(request1.method()).isEqualTo(HttpMethod.GET);
assertThat(getBody(request1)).isEmpty();
}

@Test
public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved() {
OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1")
Expand Down