|
1 | 1 | /*
|
2 |
| - * Copyright 2002-2022 the original author or authors. |
| 2 | + * Copyright 2002-2024 the original author or authors. |
3 | 3 | *
|
4 | 4 | * Licensed under the Apache License, Version 2.0 (the "License");
|
5 | 5 | * you may not use this file except in compliance with the License.
|
|
17 | 17 | package org.springframework.security.config.web.server;
|
18 | 18 |
|
19 | 19 | import java.net.URI;
|
| 20 | +import java.util.Set; |
20 | 21 |
|
21 | 22 | import org.junit.jupiter.api.Test;
|
22 | 23 | import org.junit.jupiter.api.extension.ExtendWith;
|
| 24 | +import org.mockito.ArgumentCaptor; |
23 | 25 | import reactor.core.publisher.Mono;
|
24 | 26 |
|
25 | 27 | import org.springframework.beans.factory.annotation.Autowired;
|
|
31 | 33 | import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity;
|
32 | 34 | import org.springframework.security.config.test.SpringTestContext;
|
33 | 35 | import org.springframework.security.config.test.SpringTestContextExtension;
|
| 36 | +import org.springframework.security.core.Authentication; |
34 | 37 | import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
|
35 | 38 | import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
|
36 | 39 | import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken;
|
| 40 | +import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest; |
| 41 | +import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient; |
37 | 42 | import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
38 | 43 | import org.springframework.security.oauth2.client.registration.InMemoryReactiveClientRegistrationRepository;
|
39 | 44 | import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
|
40 | 45 | import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
|
41 | 46 | import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository;
|
42 | 47 | import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationRequestResolver;
|
43 | 48 | import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
|
| 49 | +import org.springframework.security.oauth2.core.AuthorizationGrantType; |
44 | 50 | import org.springframework.security.oauth2.core.OAuth2AccessToken;
|
45 | 51 | import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
|
| 52 | +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; |
46 | 53 | import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
|
47 | 54 | import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
|
48 | 55 | import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
|
|
59 | 66 | import org.springframework.web.bind.annotation.GetMapping;
|
60 | 67 | import org.springframework.web.bind.annotation.RestController;
|
61 | 68 | import org.springframework.web.reactive.config.EnableWebFlux;
|
| 69 | +import org.springframework.web.server.ServerWebExchange; |
62 | 70 |
|
| 71 | +import static org.assertj.core.api.Assertions.assertThat; |
63 | 72 | import static org.mockito.ArgumentMatchers.any;
|
64 | 73 | import static org.mockito.BDDMockito.given;
|
65 | 74 | import static org.mockito.Mockito.mock;
|
@@ -215,6 +224,62 @@ public void oauth2ClientWhenCustomObjectsInLambdaThenUsed() {
|
215 | 224 | verify(requestCache).getRedirectUri(any());
|
216 | 225 | }
|
217 | 226 |
|
| 227 | + @Test |
| 228 | + @SuppressWarnings("unchecked") |
| 229 | + public void oauth2ClientWhenCustomAccessTokenResponseClientThenUsed() { |
| 230 | + this.spring.register(OAuth2ClientBeanConfig.class, AuthorizedClientController.class).autowire(); |
| 231 | + ReactiveClientRegistrationRepository clientRegistrationRepository = this.spring.getContext() |
| 232 | + .getBean(ReactiveClientRegistrationRepository.class); |
| 233 | + given(clientRegistrationRepository.findByRegistrationId(any())).willReturn(Mono.just(this.registration)); |
| 234 | + ServerOAuth2AuthorizedClientRepository authorizedClientRepository = this.spring.getContext() |
| 235 | + .getBean(ServerOAuth2AuthorizedClientRepository.class); |
| 236 | + given(authorizedClientRepository.saveAuthorizedClient(any(OAuth2AuthorizedClient.class), |
| 237 | + any(Authentication.class), any(ServerWebExchange.class))) |
| 238 | + .willReturn(Mono.empty()); |
| 239 | + ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository = this.spring |
| 240 | + .getContext() |
| 241 | + .getBean(ServerAuthorizationRequestRepository.class); |
| 242 | + OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request() |
| 243 | + .redirectUri("/authorize/oauth2/code/registration-id") |
| 244 | + .build(); |
| 245 | + given(authorizationRequestRepository.loadAuthorizationRequest(any(ServerWebExchange.class))) |
| 246 | + .willReturn(Mono.just(authorizationRequest)); |
| 247 | + given(authorizationRequestRepository.removeAuthorizationRequest(any(ServerWebExchange.class))) |
| 248 | + .willReturn(Mono.just(authorizationRequest)); |
| 249 | + ReactiveOAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient = this.spring |
| 250 | + .getContext() |
| 251 | + .getBean(ReactiveOAuth2AccessTokenResponseClient.class); |
| 252 | + OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("token") |
| 253 | + .tokenType(OAuth2AccessToken.TokenType.BEARER) |
| 254 | + .scopes(Set.of()) |
| 255 | + .expiresIn(300) |
| 256 | + .build(); |
| 257 | + given(accessTokenResponseClient.getTokenResponse(any(OAuth2AuthorizationCodeGrantRequest.class))) |
| 258 | + .willReturn(Mono.just(accessTokenResponse)); |
| 259 | + // @formatter:off |
| 260 | + this.client.get() |
| 261 | + .uri((uriBuilder) -> uriBuilder |
| 262 | + .path("/authorize/oauth2/code/registration-id") |
| 263 | + .queryParam(OAuth2ParameterNames.CODE, "code") |
| 264 | + .queryParam(OAuth2ParameterNames.STATE, "state") |
| 265 | + .build() |
| 266 | + ) |
| 267 | + .exchange() |
| 268 | + .expectStatus().is3xxRedirection(); |
| 269 | + // @formatter:on |
| 270 | + ArgumentCaptor<OAuth2AuthorizationCodeGrantRequest> grantRequestArgumentCaptor = ArgumentCaptor |
| 271 | + .forClass(OAuth2AuthorizationCodeGrantRequest.class); |
| 272 | + verify(accessTokenResponseClient).getTokenResponse(grantRequestArgumentCaptor.capture()); |
| 273 | + OAuth2AuthorizationCodeGrantRequest grantRequest = grantRequestArgumentCaptor.getValue(); |
| 274 | + assertThat(grantRequest.getClientRegistration()).isEqualTo(this.registration); |
| 275 | + assertThat(grantRequest.getGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE); |
| 276 | + assertThat(grantRequest.getAuthorizationExchange().getAuthorizationRequest()).isEqualTo(authorizationRequest); |
| 277 | + assertThat(grantRequest.getAuthorizationExchange().getAuthorizationResponse().getCode()).isEqualTo("code"); |
| 278 | + assertThat(grantRequest.getAuthorizationExchange().getAuthorizationResponse().getState()).isEqualTo("state"); |
| 279 | + assertThat(grantRequest.getAuthorizationExchange().getAuthorizationResponse().getRedirectUri()) |
| 280 | + .startsWith("/authorize/oauth2/code/registration-id"); |
| 281 | + } |
| 282 | + |
218 | 283 | @Configuration
|
219 | 284 | @EnableWebFlux
|
220 | 285 | @EnableWebFluxSecurity
|
@@ -324,4 +389,44 @@ SecurityWebFilterChain springSecurityFilter(ServerHttpSecurity http) {
|
324 | 389 |
|
325 | 390 | }
|
326 | 391 |
|
| 392 | + @Configuration |
| 393 | + @EnableWebFlux |
| 394 | + @EnableWebFluxSecurity |
| 395 | + static class OAuth2ClientBeanConfig { |
| 396 | + |
| 397 | + @Bean |
| 398 | + SecurityWebFilterChain securityWebFilterChain(ServerHttpSecurity http) { |
| 399 | + // @formatter:off |
| 400 | + http |
| 401 | + .oauth2Client((oauth2Client) -> oauth2Client |
| 402 | + .authorizationRequestRepository(authorizationRequestRepository()) |
| 403 | + ); |
| 404 | + // @formatter:on |
| 405 | + return http.build(); |
| 406 | + } |
| 407 | + |
| 408 | + @Bean |
| 409 | + @SuppressWarnings("unchecked") |
| 410 | + ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository() { |
| 411 | + return mock(ServerAuthorizationRequestRepository.class); |
| 412 | + } |
| 413 | + |
| 414 | + @Bean |
| 415 | + @SuppressWarnings("unchecked") |
| 416 | + ReactiveOAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> authorizationCodeAccessTokenResponseClient() { |
| 417 | + return mock(ReactiveOAuth2AccessTokenResponseClient.class); |
| 418 | + } |
| 419 | + |
| 420 | + @Bean |
| 421 | + ReactiveClientRegistrationRepository clientRegistrationRepository() { |
| 422 | + return mock(ReactiveClientRegistrationRepository.class); |
| 423 | + } |
| 424 | + |
| 425 | + @Bean |
| 426 | + ServerOAuth2AuthorizedClientRepository authorizedClientRepository() { |
| 427 | + return mock(ServerOAuth2AuthorizedClientRepository.class); |
| 428 | + } |
| 429 | + |
| 430 | + } |
| 431 | + |
327 | 432 | }
|
0 commit comments