Skip to content

Commit 0441570

Browse files
vboulayeSteve Riesenberg
authored and
Steve Riesenberg
committed
Enable customizing headers in token requests
Adds the possibility to customize the headers of the access token request in AbstractWebClientReactiveOAuth2AccessTokenResponseClient, similarly to what is done in the AbstractOAuth2AuthorizationGrantRequestEntityConverter. Closes gh-10130
1 parent 1806ceb commit 0441570

File tree

5 files changed

+326
-6
lines changed

5 files changed

+326
-6
lines changed

Diff for: oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractWebClientReactiveOAuth2AccessTokenResponseClient.java

+64-3
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
import reactor.core.publisher.Mono;
2626

27+
import org.springframework.core.convert.converter.Converter;
2728
import org.springframework.http.HttpHeaders;
2829
import org.springframework.http.MediaType;
2930
import org.springframework.security.oauth2.client.registration.ClientRegistration;
@@ -65,6 +66,8 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient<T
6566

6667
private WebClient webClient = WebClient.builder().build();
6768

69+
private Converter<T, HttpHeaders> headersConverter = this::populateTokenRequestHeaders;
70+
6871
AbstractWebClientReactiveOAuth2AccessTokenResponseClient() {
6972
}
7073

@@ -74,7 +77,12 @@ public Mono<OAuth2AccessTokenResponse> getTokenResponse(T grantRequest) {
7477
// @formatter:off
7578
return Mono.defer(() -> this.webClient.post()
7679
.uri(clientRegistration(grantRequest).getProviderDetails().getTokenUri())
77-
.headers((headers) -> populateTokenRequestHeaders(grantRequest, headers))
80+
.headers((headers) -> {
81+
HttpHeaders headersToAdd = getHeadersConverter().convert(grantRequest);
82+
if (headersToAdd != null) {
83+
headers.addAll(headersToAdd);
84+
}
85+
})
7886
.body(createTokenRequestBody(grantRequest))
7987
.exchange()
8088
.flatMap((response) -> readTokenResponse(grantRequest, response))
@@ -92,9 +100,10 @@ public Mono<OAuth2AccessTokenResponse> getTokenResponse(T grantRequest) {
92100
/**
93101
* Populates the headers for the token request.
94102
* @param grantRequest the grant request
95-
* @param headers the headers to populate
103+
* @return the headers populated for the token request
96104
*/
97-
private void populateTokenRequestHeaders(T grantRequest, HttpHeaders headers) {
105+
private HttpHeaders populateTokenRequestHeaders(T grantRequest) {
106+
HttpHeaders headers = new HttpHeaders();
98107
ClientRegistration clientRegistration = clientRegistration(grantRequest);
99108
headers.setContentType(MediaType.APPLICATION_FORM_URLENCODED);
100109
headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON));
@@ -104,6 +113,7 @@ private void populateTokenRequestHeaders(T grantRequest, HttpHeaders headers) {
104113
String clientSecret = encodeClientCredential(clientRegistration.getClientSecret());
105114
headers.setBasicAuth(clientId, clientSecret);
106115
}
116+
return headers;
107117
}
108118

109119
private static String encodeClientCredential(String clientCredential) {
@@ -230,4 +240,55 @@ public void setWebClient(WebClient webClient) {
230240
this.webClient = webClient;
231241
}
232242

243+
/**
244+
* Returns the {@link Converter} used for converting the
245+
* {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link HttpHeaders}
246+
* used in the OAuth 2.0 Access Token Request headers.
247+
* @return the {@link Converter} used for converting the
248+
* {@link AbstractOAuth2AuthorizationGrantRequest} to {@link HttpHeaders}
249+
*/
250+
final Converter<T, HttpHeaders> getHeadersConverter() {
251+
return this.headersConverter;
252+
}
253+
254+
/**
255+
* Sets the {@link Converter} used for converting the
256+
* {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link HttpHeaders}
257+
* used in the OAuth 2.0 Access Token Request headers.
258+
* @param headersConverter the {@link Converter} used for converting the
259+
* {@link AbstractOAuth2AuthorizationGrantRequest} to {@link HttpHeaders}
260+
* @since 5.6
261+
*/
262+
public final void setHeadersConverter(Converter<T, HttpHeaders> headersConverter) {
263+
Assert.notNull(headersConverter, "headersConverter cannot be null");
264+
this.headersConverter = headersConverter;
265+
}
266+
267+
/**
268+
* Add (compose) the provided {@code headersConverter} to the current
269+
* {@link Converter} used for converting the
270+
* {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link HttpHeaders}
271+
* used in the OAuth 2.0 Access Token Request headers.
272+
* @param headersConverter the {@link Converter} to add (compose) to the current
273+
* {@link Converter} used for converting the
274+
* {@link AbstractOAuth2AuthorizationGrantRequest} to a {@link HttpHeaders}
275+
* @since 5.6
276+
*/
277+
public final void addHeadersConverter(Converter<T, HttpHeaders> headersConverter) {
278+
Assert.notNull(headersConverter, "headersConverter cannot be null");
279+
Converter<T, HttpHeaders> currentHeadersConverter = this.headersConverter;
280+
this.headersConverter = (authorizationGrantRequest) -> {
281+
// Append headers using a Composite Converter
282+
HttpHeaders headers = currentHeadersConverter.convert(authorizationGrantRequest);
283+
if (headers == null) {
284+
headers = new HttpHeaders();
285+
}
286+
HttpHeaders headersToAdd = headersConverter.convert(authorizationGrantRequest);
287+
if (headersToAdd != null) {
288+
headers.addAll(headersToAdd);
289+
}
290+
return headers;
291+
};
292+
}
293+
233294
}

Diff for: oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClientTests.java

+65-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2020 the original author or authors.
2+
* Copyright 2002-2021 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -17,15 +17,18 @@
1717
package org.springframework.security.oauth2.client.endpoint;
1818

1919
import java.time.Instant;
20+
import java.util.Collections;
2021
import java.util.HashMap;
2122
import java.util.Map;
2223

2324
import okhttp3.mockwebserver.MockResponse;
2425
import okhttp3.mockwebserver.MockWebServer;
26+
import okhttp3.mockwebserver.RecordedRequest;
2527
import org.junit.jupiter.api.AfterEach;
2628
import org.junit.jupiter.api.BeforeEach;
2729
import org.junit.jupiter.api.Test;
2830

31+
import org.springframework.core.convert.converter.Converter;
2932
import org.springframework.http.HttpHeaders;
3033
import org.springframework.http.HttpStatus;
3134
import org.springframework.http.MediaType;
@@ -340,4 +343,65 @@ private OAuth2AuthorizationCodeGrantRequest pkceAuthorizationCodeGrantRequest()
340343
return new OAuth2AuthorizationCodeGrantRequest(registration, authorizationExchange);
341344
}
342345

346+
@Test
347+
public void setHeadersConverterWhenNullThenThrowIllegalArgumentException() {
348+
assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.setHeadersConverter(null))
349+
.withMessage("headersConverter cannot be null");
350+
}
351+
352+
@Test
353+
public void addHeadersConverterWhenNullThenThrowIllegalArgumentException() {
354+
assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.addHeadersConverter(null))
355+
.withMessage("headersConverter cannot be null");
356+
}
357+
358+
@Test
359+
public void convertWhenHeadersConverterAddedThenCalled() throws Exception {
360+
OAuth2AuthorizationCodeGrantRequest request = authorizationCodeGrantRequest();
361+
Converter<OAuth2AuthorizationCodeGrantRequest, HttpHeaders> addedHeadersConverter = mock(Converter.class);
362+
final HttpHeaders headers = new HttpHeaders();
363+
headers.put("CUSTOM_AUTHORIZATION", Collections.singletonList("Basic CUSTOM"));
364+
given(addedHeadersConverter.convert(request)).willReturn(headers);
365+
this.tokenResponseClient.addHeadersConverter(addedHeadersConverter);
366+
// @formatter:off
367+
String accessTokenSuccessResponse = "{\n"
368+
+ " \"access_token\": \"access-token-1234\",\n"
369+
+ " \"token_type\": \"bearer\",\n"
370+
+ " \"expires_in\": \"3600\",\n"
371+
+ " \"scope\": \"openid profile\"\n"
372+
+ "}\n";
373+
// @formatter:on
374+
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
375+
this.tokenResponseClient.getTokenResponse(request).block();
376+
verify(addedHeadersConverter).convert(request);
377+
RecordedRequest actualRequest = this.server.takeRequest();
378+
assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION))
379+
.isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=");
380+
assertThat(actualRequest.getHeader("CUSTOM_AUTHORIZATION")).isEqualTo("Basic CUSTOM");
381+
}
382+
383+
@Test
384+
public void convertWhenHeadersConverterSetThenCalled() throws Exception {
385+
OAuth2AuthorizationCodeGrantRequest request = authorizationCodeGrantRequest();
386+
Converter<OAuth2AuthorizationCodeGrantRequest, HttpHeaders> headersConverter = mock(Converter.class);
387+
final HttpHeaders headers = new HttpHeaders();
388+
headers.put(HttpHeaders.AUTHORIZATION, Collections.singletonList("Basic CUSTOM"));
389+
given(headersConverter.convert(request)).willReturn(headers);
390+
this.tokenResponseClient.setHeadersConverter(headersConverter);
391+
// @formatter:off
392+
String accessTokenSuccessResponse = "{\n"
393+
+ " \"access_token\": \"access-token-1234\",\n"
394+
+ " \"token_type\": \"bearer\",\n"
395+
+ " \"expires_in\": \"3600\",\n"
396+
+ " \"scope\": \"openid profile\"\n"
397+
+ "}\n";
398+
// @formatter:on
399+
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
400+
this.tokenResponseClient.getTokenResponse(request).block();
401+
verify(headersConverter).convert(request);
402+
RecordedRequest actualRequest = this.server.takeRequest();
403+
assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Basic CUSTOM");
404+
405+
}
406+
343407
}

Diff for: oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClientTests.java

+62
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import java.net.URLEncoder;
2020
import java.nio.charset.StandardCharsets;
2121
import java.util.Base64;
22+
import java.util.Collections;
2223

2324
import okhttp3.mockwebserver.MockResponse;
2425
import okhttp3.mockwebserver.MockWebServer;
@@ -27,6 +28,7 @@
2728
import org.junit.jupiter.api.BeforeEach;
2829
import org.junit.jupiter.api.Test;
2930

31+
import org.springframework.core.convert.converter.Converter;
3032
import org.springframework.http.HttpHeaders;
3133
import org.springframework.http.MediaType;
3234
import org.springframework.security.oauth2.client.registration.ClientRegistration;
@@ -212,4 +214,64 @@ private void enqueueJson(String body) {
212214
this.server.enqueue(response);
213215
}
214216

217+
@Test
218+
public void setHeadersConverterWhenNullThenThrowIllegalArgumentException() {
219+
assertThatIllegalArgumentException().isThrownBy(() -> this.client.setHeadersConverter(null))
220+
.withMessage("headersConverter cannot be null");
221+
}
222+
223+
@Test
224+
public void addHeadersConverterWhenNullThenThrowIllegalArgumentException() {
225+
assertThatIllegalArgumentException().isThrownBy(() -> this.client.addHeadersConverter(null))
226+
.withMessage("headersConverter cannot be null");
227+
}
228+
229+
@Test
230+
public void convertWhenHeadersConverterAddedThenCalled() throws Exception {
231+
OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(
232+
this.clientRegistration.build());
233+
Converter<OAuth2ClientCredentialsGrantRequest, HttpHeaders> addedHeadersConverter = mock(Converter.class);
234+
final HttpHeaders headers = new HttpHeaders();
235+
headers.put("CUSTOM_AUTHORIZATION", Collections.singletonList("Basic CUSTOM"));
236+
given(addedHeadersConverter.convert(request)).willReturn(headers);
237+
this.client.addHeadersConverter(addedHeadersConverter);
238+
// @formatter:off
239+
enqueueJson("{\n"
240+
+ " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
241+
+ " \"token_type\":\"bearer\",\n"
242+
+ " \"expires_in\":3600,\n"
243+
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n"
244+
+ "}");
245+
// @formatter:on
246+
this.client.getTokenResponse(request).block();
247+
verify(addedHeadersConverter).convert(request);
248+
RecordedRequest actualRequest = this.server.takeRequest();
249+
assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION))
250+
.isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=");
251+
assertThat(actualRequest.getHeader("CUSTOM_AUTHORIZATION")).isEqualTo("Basic CUSTOM");
252+
}
253+
254+
@Test
255+
public void convertWhenHeadersConverterSetThenCalled() throws Exception {
256+
OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(
257+
this.clientRegistration.build());
258+
Converter<OAuth2ClientCredentialsGrantRequest, HttpHeaders> headersConverter = mock(Converter.class);
259+
final HttpHeaders headers = new HttpHeaders();
260+
headers.put(HttpHeaders.AUTHORIZATION, Collections.singletonList("Basic CUSTOM"));
261+
given(headersConverter.convert(request)).willReturn(headers);
262+
this.client.setHeadersConverter(headersConverter);
263+
// @formatter:off
264+
enqueueJson("{\n"
265+
+ " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
266+
+ " \"token_type\":\"bearer\",\n"
267+
+ " \"expires_in\":3600,\n"
268+
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n"
269+
+ "}");
270+
// @formatter:on
271+
this.client.getTokenResponse(request).block();
272+
verify(headersConverter).convert(request);
273+
RecordedRequest actualRequest = this.server.takeRequest();
274+
assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Basic CUSTOM");
275+
}
276+
215277
}

Diff for: oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactivePasswordTokenResponseClientTests.java

+68-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2020 the original author or authors.
2+
* Copyright 2002-2021 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -17,6 +17,7 @@
1717
package org.springframework.security.oauth2.client.endpoint;
1818

1919
import java.time.Instant;
20+
import java.util.Collections;
2021

2122
import okhttp3.mockwebserver.MockResponse;
2223
import okhttp3.mockwebserver.MockWebServer;
@@ -25,6 +26,7 @@
2526
import org.junit.jupiter.api.BeforeEach;
2627
import org.junit.jupiter.api.Test;
2728

29+
import org.springframework.core.convert.converter.Converter;
2830
import org.springframework.http.HttpHeaders;
2931
import org.springframework.http.HttpMethod;
3032
import org.springframework.http.MediaType;
@@ -38,6 +40,9 @@
3840
import static org.assertj.core.api.Assertions.assertThat;
3941
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
4042
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
43+
import static org.mockito.BDDMockito.given;
44+
import static org.mockito.Mockito.mock;
45+
import static org.mockito.Mockito.verify;
4146

4247
/**
4348
* Tests for {@link WebClientReactivePasswordTokenResponseClient}.
@@ -213,4 +218,66 @@ private MockResponse jsonResponse(String json) {
213218
// @formatter:on
214219
}
215220

221+
@Test
222+
public void setHeadersConverterWhenNullThenThrowIllegalArgumentException() {
223+
assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.setHeadersConverter(null))
224+
.withMessage("headersConverter cannot be null");
225+
}
226+
227+
@Test
228+
public void addHeadersConverterWhenNullThenThrowIllegalArgumentException() {
229+
assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.addHeadersConverter(null))
230+
.withMessage("headersConverter cannot be null");
231+
}
232+
233+
@Test
234+
public void convertWhenHeadersConverterAddedThenCalled() throws Exception {
235+
OAuth2PasswordGrantRequest request = new OAuth2PasswordGrantRequest(this.clientRegistrationBuilder.build(),
236+
this.username, this.password);
237+
Converter<OAuth2PasswordGrantRequest, HttpHeaders> addedHeadersConverter = mock(Converter.class);
238+
final HttpHeaders headers = new HttpHeaders();
239+
headers.put("CUSTOM_AUTHORIZATION", Collections.singletonList("Basic CUSTOM"));
240+
given(addedHeadersConverter.convert(request)).willReturn(headers);
241+
this.tokenResponseClient.addHeadersConverter(addedHeadersConverter);
242+
// @formatter:off
243+
String accessTokenSuccessResponse = "{\n"
244+
+ " \"access_token\": \"access-token-1234\",\n"
245+
+ " \"token_type\": \"bearer\",\n"
246+
+ " \"expires_in\": \"3600\",\n"
247+
+ " \"scope\": \"read\"\n"
248+
+ "}\n";
249+
// @formatter:on
250+
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
251+
this.tokenResponseClient.getTokenResponse(request).block();
252+
verify(addedHeadersConverter).convert(request);
253+
RecordedRequest actualRequest = this.server.takeRequest();
254+
assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION))
255+
.isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=");
256+
assertThat(actualRequest.getHeader("CUSTOM_AUTHORIZATION")).isEqualTo("Basic CUSTOM");
257+
}
258+
259+
@Test
260+
public void convertWhenHeadersConverterSetThenCalled() throws Exception {
261+
OAuth2PasswordGrantRequest request = new OAuth2PasswordGrantRequest(this.clientRegistrationBuilder.build(),
262+
this.username, this.password);
263+
Converter<OAuth2PasswordGrantRequest, HttpHeaders> headersConverter = mock(Converter.class);
264+
final HttpHeaders headers = new HttpHeaders();
265+
headers.put(HttpHeaders.AUTHORIZATION, Collections.singletonList("Basic CUSTOM"));
266+
given(headersConverter.convert(request)).willReturn(headers);
267+
this.tokenResponseClient.setHeadersConverter(headersConverter);
268+
// @formatter:off
269+
String accessTokenSuccessResponse = "{\n"
270+
+ " \"access_token\": \"access-token-1234\",\n"
271+
+ " \"token_type\": \"bearer\",\n"
272+
+ " \"expires_in\": \"3600\",\n"
273+
+ " \"scope\": \"read\"\n"
274+
+ "}\n";
275+
// @formatter:on
276+
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
277+
this.tokenResponseClient.getTokenResponse(request).block();
278+
verify(headersConverter).convert(request);
279+
RecordedRequest actualRequest = this.server.takeRequest();
280+
assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Basic CUSTOM");
281+
}
282+
216283
}

0 commit comments

Comments
 (0)