From 101076083fd6e40d866d0ac1567f0c567a85ed8f Mon Sep 17 00:00:00 2001 From: Daniel Garnier-Moiroux Date: Thu, 27 Oct 2022 14:11:11 +0200 Subject: [PATCH] Improve customizing OIDC Client Registration endpoint --- .../src/docs/asciidoc/protocol-endpoints.adoc | 18 +- ...cClientRegistrationEndpointConfigurer.java | 170 ++++++++++++++++-- .../OidcClientRegistrationEndpointFilter.java | 78 ++++++-- .../OidcClientRegistrationTests.java | 132 ++++++++++++++ ...ClientRegistrationEndpointFilterTests.java | 161 +++++++++++++---- 5 files changed, 498 insertions(+), 61 deletions(-) diff --git a/docs/src/docs/asciidoc/protocol-endpoints.adoc b/docs/src/docs/asciidoc/protocol-endpoints.adoc index 563471195..b745e047f 100644 --- a/docs/src/docs/asciidoc/protocol-endpoints.adoc +++ b/docs/src/docs/asciidoc/protocol-endpoints.adoc @@ -351,12 +351,26 @@ public SecurityFilterChain authorizationServerSecurityFilterChain(HttpSecurity h authorizationServerConfigurer .oidc(oidc -> oidc - .clientRegistrationEndpoint(Customizer.withDefaults()) + .clientRegistrationEndpoint(clientRegistrationEndpoint -> + clientRegistrationEndpoint + .clientRegistrationRequestConverter(clientRegistrationRequestConverter) <1> + .clientRegistrationRequestConverters(clientRegistrationRequestConvertersConsumers) <2> + .authenticationProvider(authenticationProvider) <3> + .authenticationProviders(authenticationProvidersConsumer) <4> + .clientRegistrationResponseHandler(clientRegistrationResponseHandler) <5> + .errorResponseHandler(errorResponseHandler) <6> + ) ); return http.build(); } ---- +<1> `clientRegistrationRequestConverter()`: Adds an `AuthenticationConverter` (_pre-processor_) used when attempting to extract a https://openid.net/specs/openid-connect-registration-1_0.html#RegistrationRequest[Client Registration Request] or https://openid.net/specs/openid-connect-registration-1_0.html#ReadRequest[Client Read Request] from `HttpServletRequest` to an instance of `OidcClientRegistrationAuthenticationToken`. +<2> `clientRegistrationRequestConverters()`: Sets the `Consumer` providing access to the `List` of default and (optionally) added ``AuthenticationConverter``'s allowing the ability to add, remove, or customize a specific `AuthenticationConverter`. +<3> `authenticationProvider()`: Adds an `AuthenticationProvider` (_main processor_) used for authenticating the `OidcClientRegistrationAuthenticationToken`. +<4> `authenticationProviders()`: Sets the `Consumer` providing access to the `List` of default and (optionally) added ``AuthenticationProvider``'s allowing the ability to add, remove, or customize a specific `AuthenticationProvider`. +<5> `clientRegistrationResponseHandler()`: The `AuthenticationSuccessHandler` (_post-processor_) used for handling an "`authenticated`" `OidcClientRegistrationAuthenticationToken` and returning the https://openid.net/specs/openid-connect-registration-1_0.html#RegistrationResponse[Client Registration Response] or https://openid.net/specs/openid-connect-registration-1_0.html#ReadResponse[Client Read Response]. +<6> `errorResponseHandler()`: The `AuthenticationFailureHandler` (_post-processor_) used for handling an `OAuth2AuthenticationException` and returning the https://openid.net/specs/openid-connect-registration-1_0.html#RegistrationError[Client Registration Error Response] or https://openid.net/specs/openid-connect-registration-1_0.html#ReadError[Client Read Error Response]. [NOTE] The OpenID Connect 1.0 Client Registration endpoint is disabled by default because many deployments do not require dynamic client registration. @@ -371,6 +385,8 @@ The OpenID Connect 1.0 Client Registration endpoint is disabled by default becau * `*AuthenticationConverter*` -- An `OidcClientRegistrationAuthenticationConverter`. * `*AuthenticationManager*` -- An `AuthenticationManager` composed of `OidcClientRegistrationAuthenticationProvider` and `OidcClientConfigurationAuthenticationProvider`. +* `*AuthenticationSuccessHandler*` -- An internal implementation that handles an "`authenticated`" `OidcClientRegistrationAuthenticationToken` and returns the Client Registration or Client Read response. +* `*AuthenticationFailureHandler*` -- An internal implementation that uses the `OAuth2Error` associated with the `OAuth2AuthenticationException` and returns the `OAuth2Error` response. The OpenID Connect 1.0 Client Registration endpoint is an https://openid.net/specs/openid-connect-registration-1_0.html#ClientRegistration[OAuth2 protected resource], which *REQUIRES* an access token to be sent as a bearer token in the Client Registration (or Client Read) request. diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcClientRegistrationEndpointConfigurer.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcClientRegistrationEndpointConfigurer.java index d631f85a5..a8bf0c95c 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcClientRegistrationEndpointConfigurer.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcClientRegistrationEndpointConfigurer.java @@ -15,29 +15,53 @@ */ package org.springframework.security.oauth2.server.authorization.config.annotation.web.configurers; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Consumer; + +import javax.servlet.http.HttpServletRequest; + import org.springframework.http.HttpMethod; import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.config.annotation.ObjectPostProcessor; import org.springframework.security.config.annotation.web.builders.HttpSecurity; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.oidc.OidcUserInfo; import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcClientConfigurationAuthenticationProvider; import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcClientRegistrationAuthenticationProvider; +import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcClientRegistrationAuthenticationToken; import org.springframework.security.oauth2.server.authorization.oidc.web.OidcClientRegistrationEndpointFilter; +import org.springframework.security.oauth2.server.authorization.oidc.web.authentication.OidcClientRegistrationAuthenticationConverter; import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings; +import org.springframework.security.oauth2.server.authorization.web.authentication.DelegatingAuthenticationConverter; import org.springframework.security.web.access.intercept.FilterSecurityInterceptor; +import org.springframework.security.web.authentication.AuthenticationConverter; +import org.springframework.security.web.authentication.AuthenticationFailureHandler; +import org.springframework.security.web.authentication.AuthenticationSuccessHandler; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.OrRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; +import org.springframework.util.Assert; /** * Configurer for OpenID Connect Dynamic Client Registration 1.0 Endpoint. * * @author Joe Grandja + * @author Daniel Garnier-Moiroux * @since 0.2.0 * @see OidcConfigurer#clientRegistrationEndpoint * @see OidcClientRegistrationEndpointFilter */ public final class OidcClientRegistrationEndpointConfigurer extends AbstractOAuth2Configurer { private RequestMatcher requestMatcher; + private final List clientRegistrationRequestConverters = new ArrayList<>(); + private Consumer> clientRegistrationRequestConvertersConsumer = (authenticationConverters) -> {}; + private final List authenticationProviders = new ArrayList<>(); + private Consumer> authenticationProvidersConsumer = (authenticationProviders) -> {}; + private AuthenticationSuccessHandler clientRegistrationResponseHandler; + private AuthenticationFailureHandler errorResponseHandler; /** * Restrict for internal use only. @@ -46,6 +70,93 @@ public final class OidcClientRegistrationEndpointConfigurer extends AbstractOAut super(objectPostProcessor); } + /** + * Sets the {@link AuthenticationConverter} used when attempting to extract the OIDC Client Registration Request + * from {@link HttpServletRequest} to an instance of {@link OidcClientRegistrationAuthenticationToken} used for + * creating the Client Registration or returning the Client Read Response. + * + * @param clientRegistrationRequestConverter the {@link AuthenticationConverter} used when attempting to extract an + * OIDC Client Registration Request from {@link HttpServletRequest} + * @return the {@link OidcClientRegistrationEndpointConfigurer} for further configuration + * @since 0.4.0 + */ + public OidcClientRegistrationEndpointConfigurer clientRegistrationRequestConverter( + AuthenticationConverter clientRegistrationRequestConverter) { + Assert.notNull(clientRegistrationRequestConverter, "clientRegistrationRequestConverter cannot be null"); + this.clientRegistrationRequestConverters.add(clientRegistrationRequestConverter); + return this; + } + + /** + * Sets the {@code Consumer} providing access to the {@code List} of default + * and (optionally) added {@link #clientRegistrationRequestConverter(AuthenticationConverter) AuthenticationConverter}'s + * allowing the ability to add, remove, or customize a specific {@link AuthenticationConverter}. + * + * @param clientRegistrationRequestConvertersConsumer the {@code Consumer} providing access to the {@code List} of default and (optionally) added {@link AuthenticationConverter}'s + * @return the {@link OidcUserInfoEndpointConfigurer} for further configuration + * @since 0.4.0 + */ + public OidcClientRegistrationEndpointConfigurer clientRegistrationRequestConverters(Consumer> clientRegistrationRequestConvertersConsumer) { + Assert.notNull(clientRegistrationRequestConvertersConsumer, "clientRegistrationRequestConvertersConsumer cannot be null"); + this.clientRegistrationRequestConvertersConsumer = clientRegistrationRequestConvertersConsumer; + return this; + } + + /** + * Adds an {@link AuthenticationProvider} used for authenticating a type of {@link OidcClientRegistrationAuthenticationToken}. + * + * @param authenticationProvider a {@link AuthenticationProvider} used for authenticating a type of {@link OidcClientRegistrationAuthenticationToken} + * @return the {@link OidcClientRegistrationEndpointConfigurer} for further configuration + * @since 0.4.0 + */ + public OidcClientRegistrationEndpointConfigurer authenticationProvider(AuthenticationProvider authenticationProvider) { + Assert.notNull(authenticationProvider, "authenticationProvider cannot be null"); + this.authenticationProviders.add(authenticationProvider); + return this; + } + + /** + * Sets the {@code Consumer} providing access to the {@code List} of default + * and (optionally) added {@link #authenticationProvider(AuthenticationProvider) AuthenticationProvider}'s + * allowing the ability to add, remove, or customize a specific {@link AuthenticationProvider}. + * + * @param authenticationProvidersConsumer the {@code Consumer} providing access to the {@code List} of default and (optionally) added {@link AuthenticationProvider}'s + * @return the {@link OidcClientRegistrationEndpointConfigurer} for further configuration + * @since 0.4.0 + */ + public OidcClientRegistrationEndpointConfigurer authenticationProviders( + Consumer> authenticationProvidersConsumer) { + Assert.notNull(authenticationProvidersConsumer, "authenticationProvidersConsumer cannot be null"); + this.authenticationProvidersConsumer = authenticationProvidersConsumer; + return this; + } + + /** + * Sets the {@link AuthenticationSuccessHandler} used for handling an {@link OidcClientRegistrationAuthenticationToken} and + * returning the {@link OidcUserInfo User Info Response}. + * + * @param clientRegistrationResponseHandler the {@link AuthenticationSuccessHandler} used for handling an {@link OidcClientRegistrationAuthenticationToken} + * @return the {@link OidcClientRegistrationEndpointConfigurer} for further configuration + * @since 0.4.0 + */ + public OidcClientRegistrationEndpointConfigurer clientRegistrationResponseHandler(AuthenticationSuccessHandler clientRegistrationResponseHandler) { + this.clientRegistrationResponseHandler = clientRegistrationResponseHandler; + return this; + } + + /** + * Sets the {@link AuthenticationFailureHandler} used for handling an {@link OAuth2AuthenticationException} and + * returning the {@link OAuth2Error Error Response}. + * + * @param errorResponseHandler the {@link AuthenticationFailureHandler} used for handling an {@link OAuth2AuthenticationException} + * @return the {@link OidcClientRegistrationEndpointConfigurer} for further configuration + * @since 0.4.0 + */ + public OidcClientRegistrationEndpointConfigurer errorResponseHandler(AuthenticationFailureHandler errorResponseHandler) { + this.errorResponseHandler = errorResponseHandler; + return this; + } + @Override void init(HttpSecurity httpSecurity) { AuthorizationServerSettings authorizationServerSettings = OAuth2ConfigurerUtils.getAuthorizationServerSettings(httpSecurity); @@ -54,18 +165,15 @@ void init(HttpSecurity httpSecurity) { new AntPathRequestMatcher(authorizationServerSettings.getOidcClientRegistrationEndpoint(), HttpMethod.GET.name()) ); - OidcClientRegistrationAuthenticationProvider oidcClientRegistrationAuthenticationProvider = - new OidcClientRegistrationAuthenticationProvider( - OAuth2ConfigurerUtils.getRegisteredClientRepository(httpSecurity), - OAuth2ConfigurerUtils.getAuthorizationService(httpSecurity), - OAuth2ConfigurerUtils.getTokenGenerator(httpSecurity)); - httpSecurity.authenticationProvider(postProcess(oidcClientRegistrationAuthenticationProvider)); + List authenticationProviders = createDefaultAuthenticationProviders(httpSecurity); - OidcClientConfigurationAuthenticationProvider oidcClientConfigurationAuthenticationProvider = - new OidcClientConfigurationAuthenticationProvider( - OAuth2ConfigurerUtils.getRegisteredClientRepository(httpSecurity), - OAuth2ConfigurerUtils.getAuthorizationService(httpSecurity)); - httpSecurity.authenticationProvider(postProcess(oidcClientConfigurationAuthenticationProvider)); + if (!this.authenticationProviders.isEmpty()) { + authenticationProviders.addAll(0, this.authenticationProviders); + } + this.authenticationProvidersConsumer.accept(authenticationProviders); + + authenticationProviders.forEach(authenticationProvider -> + httpSecurity.authenticationProvider(postProcess(authenticationProvider))); } @Override @@ -77,6 +185,22 @@ void configure(HttpSecurity httpSecurity) { new OidcClientRegistrationEndpointFilter( authenticationManager, authorizationServerSettings.getOidcClientRegistrationEndpoint()); + + List authenticationConverters = createDefaultAuthenticationConverters(); + if (!this.clientRegistrationRequestConverters.isEmpty()) { + authenticationConverters.addAll(0, this.clientRegistrationRequestConverters); + } + this.clientRegistrationRequestConvertersConsumer.accept(authenticationConverters); + oidcClientRegistrationEndpointFilter.setAuthenticationConverter( + new DelegatingAuthenticationConverter(authenticationConverters)); + + if (this.clientRegistrationResponseHandler != null) { + oidcClientRegistrationEndpointFilter + .setAuthenticationSuccessHandler(this.clientRegistrationResponseHandler); + } + if (this.errorResponseHandler != null) { + oidcClientRegistrationEndpointFilter.setAuthenticationFailureHandler(this.errorResponseHandler); + } httpSecurity.addFilterAfter(postProcess(oidcClientRegistrationEndpointFilter), FilterSecurityInterceptor.class); } @@ -85,4 +209,28 @@ RequestMatcher getRequestMatcher() { return this.requestMatcher; } + private static List createDefaultAuthenticationProviders(HttpSecurity httpSecurity) { + List authenticationProviders = new ArrayList<>(); + + OidcClientRegistrationAuthenticationProvider oidcClientRegistrationAuthenticationProvider = + new OidcClientRegistrationAuthenticationProvider( + OAuth2ConfigurerUtils.getRegisteredClientRepository(httpSecurity), + OAuth2ConfigurerUtils.getAuthorizationService(httpSecurity), + OAuth2ConfigurerUtils.getTokenGenerator(httpSecurity)); + authenticationProviders.add(oidcClientRegistrationAuthenticationProvider); + + OidcClientConfigurationAuthenticationProvider oidcClientConfigurationAuthenticationProvider = + new OidcClientConfigurationAuthenticationProvider( + OAuth2ConfigurerUtils.getRegisteredClientRepository(httpSecurity), + OAuth2ConfigurerUtils.getAuthorizationService(httpSecurity)); + authenticationProviders.add(oidcClientConfigurationAuthenticationProvider); + return authenticationProviders; + } + + private static List createDefaultAuthenticationConverters() { + List authenticationConverters = new ArrayList<>(); + authenticationConverters.add(new OidcClientRegistrationAuthenticationConverter()); + return authenticationConverters; + } + } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilter.java index af408b707..f737b01d6 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilter.java @@ -27,6 +27,8 @@ import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; @@ -40,6 +42,8 @@ import org.springframework.security.oauth2.server.authorization.oidc.http.converter.OidcClientRegistrationHttpMessageConverter; import org.springframework.security.oauth2.server.authorization.oidc.web.authentication.OidcClientRegistrationAuthenticationConverter; import org.springframework.security.web.authentication.AuthenticationConverter; +import org.springframework.security.web.authentication.AuthenticationFailureHandler; +import org.springframework.security.web.authentication.AuthenticationSuccessHandler; import org.springframework.security.web.util.matcher.AndRequestMatcher; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.OrRequestMatcher; @@ -53,6 +57,7 @@ * * @author Ovidiu Popa * @author Joe Grandja + * @author Daniel Garnier-Moiroux * @since 0.1.1 * @see OidcClientRegistration * @see OidcClientRegistrationAuthenticationConverter @@ -69,11 +74,13 @@ public final class OidcClientRegistrationEndpointFilter extends OncePerRequestFi private final AuthenticationManager authenticationManager; private final RequestMatcher clientRegistrationEndpointMatcher; + private AuthenticationConverter authenticationConverter = new OidcClientRegistrationAuthenticationConverter(); private final HttpMessageConverter clientRegistrationHttpMessageConverter = new OidcClientRegistrationHttpMessageConverter(); private final HttpMessageConverter errorHttpResponseConverter = new OAuth2ErrorHttpMessageConverter(); - private AuthenticationConverter authenticationConverter; + private AuthenticationSuccessHandler authenticationSuccessHandler = this::sendClientRegistrationResponse; + private AuthenticationFailureHandler authenticationFailureHandler = this::sendErrorResponse; /** * Constructs an {@code OidcClientRegistrationEndpointFilter} using the provided parameters. @@ -99,7 +106,6 @@ public OidcClientRegistrationEndpointFilter(AuthenticationManager authentication new AntPathRequestMatcher( clientRegistrationEndpointUri, HttpMethod.POST.name()), createClientConfigurationMatcher(clientRegistrationEndpointUri)); - this.authenticationConverter = new OidcClientRegistrationAuthenticationConverter(); } private static RequestMatcher createClientConfigurationMatcher(String clientRegistrationEndpointUri) { @@ -130,33 +136,77 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse OidcClientRegistrationAuthenticationToken clientRegistrationAuthenticationResult = (OidcClientRegistrationAuthenticationToken) this.authenticationManager.authenticate(clientRegistrationAuthentication); - HttpStatus httpStatus = HttpStatus.OK; - if (clientRegistrationAuthentication.getClientRegistration() != null) { - httpStatus = HttpStatus.CREATED; - } - - sendClientRegistrationResponse(response, httpStatus, clientRegistrationAuthenticationResult.getClientRegistration()); - + this.authenticationSuccessHandler.onAuthenticationSuccess(request, response, clientRegistrationAuthenticationResult); } catch (OAuth2AuthenticationException ex) { - sendErrorResponse(response, ex.getError()); + this.authenticationFailureHandler.onAuthenticationFailure(request, response, ex); } catch (Exception ex) { OAuth2Error error = new OAuth2Error( OAuth2ErrorCodes.INVALID_REQUEST, "OpenID Client Registration Error: " + ex.getMessage(), "https://openid.net/specs/openid-connect-registration-1_0.html#RegistrationError"); - sendErrorResponse(response, error); + this.authenticationFailureHandler.onAuthenticationFailure(request, response, + new OAuth2AuthenticationException(error)); } finally { SecurityContextHolder.clearContext(); } } - private void sendClientRegistrationResponse(HttpServletResponse response, HttpStatus httpStatus, OidcClientRegistration clientRegistration) throws IOException { + /** + * Sets the {@link AuthenticationConverter} used when attempting to extract the OIDC Client Registration Request + * from {@link HttpServletRequest} to an instance of {@link OidcClientRegistrationAuthenticationToken} used for + * creating the Client Registration or returning the Client Read Response. + * + * @param authenticationConverter the {@link AuthenticationConverter} used when attempting to extract an + * OIDC Client Registration Request from {@link HttpServletRequest} + * @since 0.4.0 + */ + public void setAuthenticationConverter(AuthenticationConverter authenticationConverter) { + Assert.notNull(authenticationConverter, "authenticationConverter cannot be null"); + this.authenticationConverter = authenticationConverter; + } + + /** + * Sets the {@link AuthenticationSuccessHandler} used for handling an {@link OidcClientRegistrationAuthenticationToken} + * and returning the {@link OidcClientRegistration Client Registration Response}. + * + * @param authenticationSuccessHandler the {@link AuthenticationSuccessHandler} used for handling an {@link OidcClientRegistrationAuthenticationToken} + * @see 0.4.0 + */ + public void setAuthenticationSuccessHandler(AuthenticationSuccessHandler authenticationSuccessHandler) { + Assert.notNull(authenticationSuccessHandler, "authenticationSuccessHandler cannot be null"); + this.authenticationSuccessHandler = authenticationSuccessHandler; + } + + /** + * Sets the {@link AuthenticationFailureHandler} used for handling an + * {@link OAuth2AuthenticationException} and returning the {@link OAuth2Error Error + * Response}. + * @param authenticationFailureHandler the {@link AuthenticationFailureHandler} used + * for handling an {@link OAuth2AuthenticationException} + * @since 0.4.0 + */ + public void setAuthenticationFailureHandler(AuthenticationFailureHandler authenticationFailureHandler) { + Assert.notNull(authenticationFailureHandler, "authenticationFailureHandler cannot be null"); + this.authenticationFailureHandler = authenticationFailureHandler; + } + + private void sendClientRegistrationResponse(HttpServletRequest request, HttpServletResponse response, + Authentication authentication) throws IOException { + OidcClientRegistration clientRegistration = ((OidcClientRegistrationAuthenticationToken) authentication) + .getClientRegistration(); ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response); - httpResponse.setStatusCode(httpStatus); + if (HttpMethod.POST.name().equals(request.getMethod())) { + httpResponse.setStatusCode(HttpStatus.CREATED); + } + else { + httpResponse.setStatusCode(HttpStatus.OK); + } this.clientRegistrationHttpMessageConverter.write(clientRegistration, null, httpResponse); } - private void sendErrorResponse(HttpServletResponse response, OAuth2Error error) throws IOException { + private void sendErrorResponse(HttpServletRequest request, HttpServletResponse response, + AuthenticationException authenticationException) throws IOException { + OAuth2Error error = ((OAuth2AuthenticationException) authenticationException).getError(); HttpStatus httpStatus = HttpStatus.BAD_REQUEST; if (OAuth2ErrorCodes.INVALID_TOKEN.equals(error.getErrorCode())) { httpStatus = HttpStatus.UNAUTHORIZED; diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcClientRegistrationTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcClientRegistrationTests.java index 2fb92bf66..787a143ba 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcClientRegistrationTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcClientRegistrationTests.java @@ -18,6 +18,10 @@ import java.time.Instant; import java.time.temporal.ChronoUnit; import java.util.Collections; +import java.util.List; +import java.util.function.Consumer; + +import javax.servlet.http.HttpServletResponse; import com.nimbusds.jose.jwk.JWKSet; import com.nimbusds.jose.jwk.source.JWKSource; @@ -30,6 +34,7 @@ import org.junit.BeforeClass; import org.junit.Rule; import org.junit.Test; +import org.mockito.ArgumentCaptor; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; @@ -37,6 +42,7 @@ import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.jdbc.core.JdbcOperations; import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase; @@ -45,6 +51,7 @@ import org.springframework.mock.http.MockHttpOutputMessage; import org.springframework.mock.http.client.MockClientHttpResponse; import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.config.Customizer; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; @@ -54,6 +61,7 @@ import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; @@ -76,11 +84,17 @@ import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; import org.springframework.security.oauth2.server.authorization.config.annotation.web.configuration.OAuth2AuthorizationServerConfiguration; import org.springframework.security.oauth2.server.authorization.oidc.OidcClientRegistration; +import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcClientConfigurationAuthenticationProvider; +import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcClientRegistrationAuthenticationProvider; +import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcClientRegistrationAuthenticationToken; import org.springframework.security.oauth2.server.authorization.oidc.http.converter.OidcClientRegistrationHttpMessageConverter; import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings; import org.springframework.security.oauth2.server.authorization.settings.ClientSettings; import org.springframework.security.oauth2.server.authorization.test.SpringTestRule; import org.springframework.security.web.SecurityFilterChain; +import org.springframework.security.web.authentication.AuthenticationConverter; +import org.springframework.security.web.authentication.AuthenticationFailureHandler; +import org.springframework.security.web.authentication.AuthenticationSuccessHandler; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; @@ -88,6 +102,14 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.hamcrest.CoreMatchers.containsString; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; +import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.jwt; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.header; @@ -127,6 +149,18 @@ public class OidcClientRegistrationTests { @Autowired private AuthorizationServerSettings authorizationServerSettings; + private static AuthenticationConverter authenticationConverter; + + private static Consumer> authenticationConvertersConsumer; + + private static AuthenticationProvider authenticationProvider; + + private static Consumer> authenticationProvidersConsumer; + + private static AuthenticationSuccessHandler authenticationSuccessHandler; + + private static AuthenticationFailureHandler authenticationFailureHandler; + private MockWebServer server; private String clientJwkSetUrl; @@ -144,6 +178,12 @@ public static void init() { .addScript("org/springframework/security/oauth2/server/authorization/oauth2-authorization-schema.sql") .addScript("org/springframework/security/oauth2/server/authorization/client/oauth2-registered-client-schema.sql") .build(); + authenticationConverter = mock(AuthenticationConverter.class); + authenticationConvertersConsumer = mock(Consumer.class); + authenticationProvider = mock(AuthenticationProvider.class); + authenticationProvidersConsumer = mock(Consumer.class); + authenticationSuccessHandler = mock(AuthenticationSuccessHandler.class); + authenticationFailureHandler = mock(AuthenticationFailureHandler.class); } @Before @@ -157,6 +197,7 @@ public void setup() throws Exception { .setBody(clientJwkSet.toString()); // @formatter:on this.server.enqueue(response); + when(authenticationProvider.supports(OidcClientRegistrationAuthenticationToken.class)).thenReturn(true); } @After @@ -164,6 +205,12 @@ public void tearDown() throws Exception { this.server.shutdown(); jdbcOperations.update("truncate table oauth2_authorization"); jdbcOperations.update("truncate table oauth2_registered_client"); + reset(authenticationConverter); + reset(authenticationConvertersConsumer); + reset(authenticationProvider); + reset(authenticationProvidersConsumer); + reset(authenticationSuccessHandler); + reset(authenticationFailureHandler); } @AfterClass @@ -260,6 +307,65 @@ public void requestWhenClientConfigurationRequestAuthorizedThenClientRegistratio assertThat(clientConfigurationResponse.getRegistrationAccessToken()).isNull(); } + @Test + public void requestWhenUserInfoEndpointCustomizedThenUsed() throws Exception { + this.spring.register(CustomClientRegistrationConfiguration.class).autowire(); + + // @formatter:off + OidcClientRegistration clientRegistration = OidcClientRegistration.builder() + .clientName("client-name") + .redirectUri("https://client.example.com") + .grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()) + .grantType(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()) + .scope("scope1") + .scope("scope2") + .build(); + // @formatter:on + + doAnswer(invocation -> { + HttpServletResponse response = invocation.getArgument(1, HttpServletResponse.class); + ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response); + httpResponse.setStatusCode(HttpStatus.CREATED); + new OidcClientRegistrationHttpMessageConverter().write(clientRegistration, null, httpResponse); + return null; + }).when(authenticationSuccessHandler).onAuthenticationSuccess(any(), any(), any()); + + registerClient(clientRegistration); + + verify(authenticationConverter).convert(any()); + ArgumentCaptor> authenticationConvertersCaptor = ArgumentCaptor + .forClass(List.class); + verify(authenticationConvertersConsumer).accept(authenticationConvertersCaptor.capture()); + List authenticationConverters = authenticationConvertersCaptor.getValue(); + assertThat(authenticationConverters).hasSize(2).contains(authenticationConverter); + + verify(authenticationProvider).authenticate(any()); + ArgumentCaptor> authenticationProvidersCaptor = ArgumentCaptor + .forClass(List.class); + verify(authenticationProvidersConsumer).accept(authenticationProvidersCaptor.capture()); + List authenticationProviders = authenticationProvidersCaptor.getValue(); + assertThat(authenticationProviders).hasSize(3) + .allMatch(provider -> provider == authenticationProvider + || provider instanceof OidcClientRegistrationAuthenticationProvider + || provider instanceof OidcClientConfigurationAuthenticationProvider); + + verify(authenticationSuccessHandler).onAuthenticationSuccess(any(), any(), any()); + verifyNoInteractions(authenticationFailureHandler); + } + + @Test + public void requestWhenUserInfoEndpointCustomizedAndErrorThenUsed() throws Exception { + this.spring.register(CustomClientRegistrationConfiguration.class).autowire(); + + when(authenticationProvider.authenticate(any())).thenThrow(new OAuth2AuthenticationException("error")); + + this.mvc.perform(get(DEFAULT_OIDC_CLIENT_REGISTRATION_ENDPOINT_URI) + .param(OAuth2ParameterNames.CLIENT_ID, "invalid").with(jwt())); + + verify(authenticationFailureHandler).onAuthenticationFailure(any(), any(), any()); + verifyNoInteractions(authenticationSuccessHandler); + } + private OidcClientRegistration registerClient(OidcClientRegistration clientRegistration) throws Exception { // ***** (1) Obtain the "initial" access token used for registering the client @@ -352,6 +458,32 @@ private static OidcClientRegistration readClientRegistrationResponse(MockHttpSer return clientRegistrationHttpMessageConverter.read(OidcClientRegistration.class, httpResponse); } + @EnableWebSecurity + static class CustomClientRegistrationConfiguration extends AuthorizationServerConfiguration { + + @Bean + @Override + public SecurityFilterChain authorizationServerSecurityFilterChain(HttpSecurity http) throws Exception { + OAuth2AuthorizationServerConfigurer authorizationServerConfigurer = new OAuth2AuthorizationServerConfigurer(); + authorizationServerConfigurer.oidc(oidc -> oidc.clientRegistrationEndpoint( + clientRegistration -> clientRegistration.clientRegistrationRequestConverter(authenticationConverter) + .clientRegistrationRequestConverters(authenticationConvertersConsumer) + .authenticationProvider(authenticationProvider) + .authenticationProviders(authenticationProvidersConsumer) + .clientRegistrationResponseHandler(authenticationSuccessHandler) + .errorResponseHandler(authenticationFailureHandler))); + RequestMatcher endpointsMatcher = authorizationServerConfigurer.getEndpointsMatcher(); + + http.requestMatcher(endpointsMatcher) + .authorizeRequests(authorizeRequests -> authorizeRequests.anyRequest().authenticated()) + .csrf(csrf -> csrf.ignoringRequestMatchers(endpointsMatcher)) + .oauth2ResourceServer(OAuth2ResourceServerConfigurer::jwt).apply(authorizationServerConfigurer); + return http.build(); + + } + + } + @EnableWebSecurity static class AuthorizationServerConfiguration { diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilterTests.java index 5a25a5870..f6de9fe96 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilterTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilterTests.java @@ -15,10 +15,12 @@ */ package org.springframework.security.oauth2.server.authorization.oidc.web; +import java.io.IOException; import java.time.Instant; import java.util.Collections; import javax.servlet.FilterChain; +import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -33,6 +35,8 @@ import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; @@ -54,10 +58,14 @@ import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcClientRegistrationAuthenticationToken; import org.springframework.security.oauth2.server.authorization.oidc.http.converter.OidcClientRegistrationHttpMessageConverter; import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken; +import org.springframework.security.web.authentication.AuthenticationConverter; +import org.springframework.security.web.authentication.AuthenticationFailureHandler; +import org.springframework.security.web.authentication.AuthenticationSuccessHandler; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; @@ -68,6 +76,7 @@ * * @author Ovidiu Popa * @author Joe Grandja + * @author Daniel Garnier-Moiroux */ public class OidcClientRegistrationEndpointFilterTests { private static final String DEFAULT_OIDC_CLIENT_REGISTRATION_ENDPOINT_URI = "/connect/register"; @@ -103,6 +112,27 @@ public void constructorWhenClientRegistrationEndpointUriNullThenThrowIllegalArgu .withMessage("clientRegistrationEndpointUri cannot be empty"); } + @Test + public void setAuthenticationConverterWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> this.filter.setAuthenticationConverter(null)) + .withMessage("authenticationConverter cannot be null"); + } + + @Test + public void setAuthenticationSuccessHandlerWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> this.filter.setAuthenticationSuccessHandler(null)) + .withMessage("authenticationSuccessHandler cannot be null"); + } + + @Test + public void setAuthenticationFailureHandlerWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> this.filter.setAuthenticationFailureHandler(null)) + .withMessage("authenticationFailureHandler cannot be null"); + } + @Test public void doFilterWhenNotClientRegistrationRequestThenNotProcessed() throws Exception { String requestUri = "/path"; @@ -203,25 +233,13 @@ private void doFilterWhenClientRegistrationRequestInvalidThenError( @Test public void doFilterWhenClientRegistrationRequestValidThenSuccessResponse() throws Exception { // @formatter:off - OidcClientRegistration.Builder clientRegistrationBuilder = OidcClientRegistration.builder() - .clientName("client-name") - .redirectUri("https://client.example.com") - .grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()) - .grantType(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()) - .scope("scope1") - .scope("scope2"); + OidcClientRegistration expectedClientRegistrationResponse = createClientRegistration(); - OidcClientRegistration clientRegistrationRequest = clientRegistrationBuilder.build(); - - OidcClientRegistration expectedClientRegistrationResponse = clientRegistrationBuilder - .clientId("client-id") - .clientIdIssuedAt(Instant.now()) - .clientSecret("client-secret") - .tokenEndpointAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC.getValue()) - .responseType(OAuth2AuthorizationResponseType.CODE.getValue()) - .idTokenSignedResponseAlgorithm(SignatureAlgorithm.RS256.getName()) - .registrationAccessToken("registration-access-token") - .registrationClientUrl("https://auth-server:9000/connect/register?client_id=client-id") + OidcClientRegistration clientRegistrationRequest = OidcClientRegistration.builder() + .clientName(expectedClientRegistrationResponse.getClientName()) + .redirectUris(redirectUris -> redirectUris.addAll(expectedClientRegistrationResponse.getRedirectUris())) + .grantTypes(grantTypes -> grantTypes.addAll(expectedClientRegistrationResponse.getGrantTypes())) + .scopes(scopes -> scopes.addAll(expectedClientRegistrationResponse.getScopes())) .build(); // @formatter:on @@ -353,6 +371,27 @@ public void doFilterWhenClientConfigurationRequestInvalidClientThenUnauthorizedE OAuth2ErrorCodes.INVALID_CLIENT, HttpStatus.UNAUTHORIZED); } + @Test + public void doFilterWhenCustomAuthenticationFailureHandlerThenUsed() throws Exception { + AuthenticationFailureHandler authenticationFailureHandler = mock(AuthenticationFailureHandler.class); + this.filter.setAuthenticationFailureHandler(authenticationFailureHandler); + + when(this.authenticationManager.authenticate(any())) + .thenThrow(new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_TOKEN)); + + String requestUri = DEFAULT_OIDC_CLIENT_REGISTRATION_ENDPOINT_URI; + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + request.setParameter(OAuth2ParameterNames.CLIENT_ID, "client1"); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verify(authenticationFailureHandler).onAuthenticationFailure(eq(request), eq(response), + any(OAuth2AuthenticationException.class)); + } + private void doFilterWhenClientConfigurationRequestInvalidThenError( String errorCode, HttpStatus status) throws Exception { Jwt jwt = createJwt("client.read"); @@ -384,23 +423,7 @@ private void doFilterWhenClientConfigurationRequestInvalidThenError( @Test public void doFilterWhenClientConfigurationRequestValidThenSuccessResponse() throws Exception { - // @formatter:off - OidcClientRegistration expectedClientRegistrationResponse = OidcClientRegistration.builder() - .clientId("client-id") - .clientIdIssuedAt(Instant.now()) - .clientSecret("client-secret") - .clientName("client-name") - .redirectUri("https://client.example.com") - .grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()) - .grantType(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()) - .tokenEndpointAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC.getValue()) - .responseType(OAuth2AuthorizationResponseType.CODE.getValue()) - .idTokenSignedResponseAlgorithm(SignatureAlgorithm.RS256.getName()) - .scope("scope1") - .scope("scope2") - .registrationClientUrl("https://auth-server:9000/connect/register?client_id=client-id") - .build(); - // @formatter:on + OidcClientRegistration expectedClientRegistrationResponse = createClientRegistration(); Jwt jwt = createJwt("client.read"); JwtAuthenticationToken principal = new JwtAuthenticationToken( @@ -452,6 +475,74 @@ public void doFilterWhenClientConfigurationRequestValidThenSuccessResponse() thr .isEqualTo(expectedClientRegistrationResponse.getRegistrationClientUrl()); } + @Test + public void doFilterWhenCustomAuthenticationSuccessHandlerThenUsed() throws Exception { + OidcClientRegistration expectedClientRegistrationResponse = createClientRegistration(); + Authentication principal = new TestingAuthenticationToken("principal", "Credentials"); + + OidcClientRegistrationAuthenticationToken clientRegistrationAuthenticationResult = + new OidcClientRegistrationAuthenticationToken(principal, expectedClientRegistrationResponse); + + when(this.authenticationManager.authenticate(any())).thenReturn(clientRegistrationAuthenticationResult); + AuthenticationSuccessHandler successHandler = mock(AuthenticationSuccessHandler.class); + this.filter.setAuthenticationSuccessHandler(successHandler); + + SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); + securityContext.setAuthentication(principal); + SecurityContextHolder.setContext(securityContext); + + String requestUri = DEFAULT_OIDC_CLIENT_REGISTRATION_ENDPOINT_URI; + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + request.setParameter(OAuth2ParameterNames.CLIENT_ID, expectedClientRegistrationResponse.getClientId()); + + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verify(successHandler).onAuthenticationSuccess(request, response, clientRegistrationAuthenticationResult); + } + + private static OidcClientRegistration createClientRegistration() { + // @formatter:off + OidcClientRegistration expectedClientRegistrationResponse = OidcClientRegistration.builder() + .clientId("client-id") + .clientIdIssuedAt(Instant.now()) + .clientSecret("client-secret") + .clientName("client-name") + .redirectUri("https://client.example.com") + .grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()) + .grantType(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()) + .tokenEndpointAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC.getValue()) + .responseType(OAuth2AuthorizationResponseType.CODE.getValue()) + .idTokenSignedResponseAlgorithm(SignatureAlgorithm.RS256.getName()) + .scope("scope1") + .scope("scope2") + .registrationClientUrl("https://auth-server:9000/connect/register?client_id=client-id") + .build(); + return expectedClientRegistrationResponse; + // @formatter:on + } + + @Test + public void doFilterWhenCustomAuthenticationConverterThenUsed() throws ServletException, IOException { + AuthenticationConverter authenticationConverter = mock(AuthenticationConverter.class); + this.filter.setAuthenticationConverter(authenticationConverter); + + String requestUri = DEFAULT_OIDC_CLIENT_REGISTRATION_ENDPOINT_URI; + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + request.setParameter(OAuth2ParameterNames.CLIENT_ID, "client-id"); + + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verify(authenticationConverter).convert(request); + } + private OAuth2Error readError(MockHttpServletResponse response) throws Exception { MockClientHttpResponse httpResponse = new MockClientHttpResponse( response.getContentAsByteArray(), HttpStatus.valueOf(response.getStatus()));