Skip to content

Commit 5c77e01

Browse files
committed
Make OidcUserInfoEndpointFilter configurable
1 parent a9cf857 commit 5c77e01

File tree

2 files changed

+118
-10
lines changed

2 files changed

+118
-10
lines changed

oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcUserInfoEndpointFilter.java

+41-7
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.springframework.http.server.ServletServerHttpResponse;
2929
import org.springframework.security.authentication.AuthenticationManager;
3030
import org.springframework.security.core.Authentication;
31+
import org.springframework.security.core.AuthenticationException;
3132
import org.springframework.security.core.context.SecurityContextHolder;
3233
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
3334
import org.springframework.security.oauth2.core.OAuth2Error;
@@ -36,6 +37,8 @@
3637
import org.springframework.security.oauth2.core.oidc.OidcUserInfo;
3738
import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationToken;
3839
import org.springframework.security.oauth2.server.authorization.oidc.http.converter.OidcUserInfoHttpMessageConverter;
40+
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
41+
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
3942
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
4043
import org.springframework.security.web.util.matcher.OrRequestMatcher;
4144
import org.springframework.security.web.util.matcher.RequestMatcher;
@@ -66,6 +69,10 @@ public final class OidcUserInfoEndpointFilter extends OncePerRequestFilter {
6669
private final HttpMessageConverter<OAuth2Error> errorHttpResponseConverter =
6770
new OAuth2ErrorHttpMessageConverter();
6871

72+
private AuthenticationSuccessHandler authenticationSuccessHandler = this::onAuthenticationSuccess;
73+
74+
private AuthenticationFailureHandler authenticationFailureHandler = this::onAuthenticationFailure;
75+
6976
/**
7077
* Constructs an {@code OidcUserInfoEndpointFilter} using the provided parameters.
7178
*
@@ -107,27 +114,53 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
107114
OidcUserInfoAuthenticationToken userInfoAuthenticationResult =
108115
(OidcUserInfoAuthenticationToken) this.authenticationManager.authenticate(userInfoAuthentication);
109116

110-
sendUserInfoResponse(response, userInfoAuthenticationResult.getUserInfo());
111-
117+
this.authenticationSuccessHandler.onAuthenticationSuccess(request, response, userInfoAuthenticationResult);
112118
} catch (OAuth2AuthenticationException ex) {
113-
sendErrorResponse(response, ex.getError());
119+
this.authenticationFailureHandler.onAuthenticationFailure(request, response, ex);
114120
} catch (Exception ex) {
115121
OAuth2Error error = new OAuth2Error(
116122
OAuth2ErrorCodes.INVALID_REQUEST,
117123
"OpenID Connect 1.0 UserInfo Error: " + ex.getMessage(),
118124
"https://openid.net/specs/openid-connect-core-1_0.html#UserInfoError");
119-
sendErrorResponse(response, error);
125+
this.authenticationFailureHandler.onAuthenticationFailure(request, response,
126+
new OAuth2AuthenticationException(error));
120127
} finally {
121128
SecurityContextHolder.clearContext();
122129
}
123130
}
124131

125-
private void sendUserInfoResponse(HttpServletResponse response, OidcUserInfo userInfo) throws IOException {
132+
/**
133+
* Sets the {@link AuthenticationSuccessHandler} used for handling an {@link OidcUserInfoAuthenticationToken} and
134+
* returning the {@link OidcUserInfo OIDC User Info}.
135+
*
136+
* @param authenticationSuccessHandler the {@link AuthenticationSuccessHandler} for handling an {@link OidcUserInfoAuthenticationToken}
137+
*/
138+
public void setAuthenticationSuccessHandler(AuthenticationSuccessHandler authenticationSuccessHandler) {
139+
Assert.notNull(authenticationSuccessHandler, "authenticationSuccessHandler cannot be null");
140+
this.authenticationSuccessHandler = authenticationSuccessHandler;
141+
}
142+
143+
/**
144+
* Sets the {@link AuthenticationFailureHandler} used for handling an {@link OAuth2AuthenticationException}
145+
* and returning the {@link OAuth2Error Error Response}.
146+
*
147+
* @param authenticationFailureHandler the {@link AuthenticationFailureHandler} used for handling an {@link OAuth2AuthenticationException}
148+
*/
149+
public void setAuthenticationFailureHandler(AuthenticationFailureHandler authenticationFailureHandler) {
150+
Assert.notNull(authenticationFailureHandler, "authenticationFailureHandler cannot be null");
151+
this.authenticationFailureHandler = authenticationFailureHandler;
152+
}
153+
154+
private void onAuthenticationSuccess(HttpServletRequest request, HttpServletResponse response,
155+
Authentication authentication) throws IOException {
156+
OidcUserInfoAuthenticationToken userInfoAuthenticationToken = (OidcUserInfoAuthenticationToken) authentication;
126157
ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
127-
this.userInfoHttpMessageConverter.write(userInfo, null, httpResponse);
158+
this.userInfoHttpMessageConverter.write(userInfoAuthenticationToken.getUserInfo(), null, httpResponse);
128159
}
129160

130-
private void sendErrorResponse(HttpServletResponse response, OAuth2Error error) throws IOException {
161+
private void onAuthenticationFailure(HttpServletRequest request, HttpServletResponse response,
162+
AuthenticationException authenticationException) throws IOException {
163+
OAuth2Error error = ((OAuth2AuthenticationException) authenticationException).getError();
131164
HttpStatus httpStatus = HttpStatus.BAD_REQUEST;
132165
if (error.getErrorCode().equals(OAuth2ErrorCodes.INVALID_TOKEN)) {
133166
httpStatus = HttpStatus.UNAUTHORIZED;
@@ -138,4 +171,5 @@ private void sendErrorResponse(HttpServletResponse response, OAuth2Error error)
138171
httpResponse.setStatusCode(httpStatus);
139172
this.errorHttpResponseConverter.write(error, null, httpResponse);
140173
}
174+
141175
}

oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcUserInfoEndpointFilterTests.java

+77-3
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@
4444
import org.springframework.security.oauth2.jwt.Jwt;
4545
import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationToken;
4646
import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken;
47+
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
48+
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
4749

4850
import static org.assertj.core.api.Assertions.assertThat;
4951
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
@@ -84,6 +86,20 @@ public void constructorWhenUserInfoEndpointUriIsEmptyThenThrowIllegalArgumentExc
8486
.withMessage("userInfoEndpointUri cannot be empty");
8587
}
8688

89+
@Test
90+
public void setAuthenticationSuccessHandlerNullThenThrowIllegalArgumentException() {
91+
assertThatIllegalArgumentException()
92+
.isThrownBy(() -> this.filter.setAuthenticationSuccessHandler(null))
93+
.withMessage("authenticationSuccessHandler cannot be null");
94+
}
95+
96+
@Test
97+
public void setAuthenticationFailureHandlerNullThenThrowIllegalArgumentException() {
98+
assertThatIllegalArgumentException()
99+
.isThrownBy(() -> this.filter.setAuthenticationFailureHandler(null))
100+
.withMessage("authenticationFailureHandler cannot be null");
101+
}
102+
87103
@Test
88104
public void doFilterWhenNotUserInfoRequestThenNotProcessed() throws Exception {
89105
String requestUri = "/path";
@@ -145,11 +161,21 @@ private void doFilterWhenUserInfoRequestThenSuccess(String httpMethod) throws Ex
145161

146162
@Test
147163
public void doFilterWhenUserInfoRequestInvalidTokenThenUnauthorizedError() throws Exception {
164+
doFilterWhenAuthenticationExceptionThenError(OAuth2ErrorCodes.INVALID_TOKEN, HttpStatus.UNAUTHORIZED);
165+
}
166+
167+
@Test
168+
public void doFilterWhenUserInfoRequestInsufficientScopeThenUnauthorizedError() throws Exception {
169+
doFilterWhenAuthenticationExceptionThenError(OAuth2ErrorCodes.INSUFFICIENT_SCOPE, HttpStatus.FORBIDDEN);
170+
}
171+
172+
private void doFilterWhenAuthenticationExceptionThenError(String oauth2ErrorCode, HttpStatus httpStatus)
173+
throws Exception {
148174
Authentication principal = new TestingAuthenticationToken("principal", "credentials");
149175
SecurityContextHolder.getContext().setAuthentication(principal);
150176

151177
when(this.authenticationManager.authenticate(any()))
152-
.thenThrow(new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_TOKEN));
178+
.thenThrow(new OAuth2AuthenticationException(oauth2ErrorCode));
153179

154180
String requestUri = DEFAULT_OIDC_USER_INFO_ENDPOINT_URI;
155181
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
@@ -161,9 +187,57 @@ public void doFilterWhenUserInfoRequestInvalidTokenThenUnauthorizedError() throw
161187

162188
verifyNoInteractions(filterChain);
163189

164-
assertThat(response.getStatus()).isEqualTo(HttpStatus.UNAUTHORIZED.value());
190+
assertThat(response.getStatus()).isEqualTo(httpStatus.value());
165191
OAuth2Error error = readError(response);
166-
assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_TOKEN);
192+
assertThat(error.getErrorCode()).isEqualTo(oauth2ErrorCode);
193+
}
194+
195+
@Test
196+
public void doFilterWhenCustomAuthenticationSuccessHandlerThenUses() throws Exception {
197+
AuthenticationSuccessHandler successHandler = mock(AuthenticationSuccessHandler.class);
198+
this.filter.setAuthenticationSuccessHandler(successHandler);
199+
200+
Authentication principal = new TestingAuthenticationToken("principal", "credentials");
201+
SecurityContextHolder.getContext().setAuthentication(principal);
202+
203+
OidcUserInfoAuthenticationToken authentication = new OidcUserInfoAuthenticationToken(principal, createUserInfo());
204+
when(this.authenticationManager.authenticate(any())).thenReturn(authentication);
205+
206+
String requestUri = DEFAULT_OIDC_USER_INFO_ENDPOINT_URI;
207+
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
208+
request.setServletPath(requestUri);
209+
MockHttpServletResponse response = new MockHttpServletResponse();
210+
FilterChain filterChain = mock(FilterChain.class);
211+
212+
this.filter.doFilter(request, response, filterChain);
213+
214+
verifyNoInteractions(filterChain);
215+
verify(successHandler).onAuthenticationSuccess(request, response, authentication);
216+
}
217+
218+
@Test
219+
public void doFilterWhenCustomFailureHandlerThenUses() throws Exception {
220+
AuthenticationFailureHandler failureHandler = mock(AuthenticationFailureHandler.class);
221+
this.filter.setAuthenticationFailureHandler(failureHandler);
222+
223+
Authentication principal = new TestingAuthenticationToken("principal", "credentials");
224+
SecurityContextHolder.getContext().setAuthentication(principal);
225+
226+
OAuth2AuthenticationException authenticationException =
227+
new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_TOKEN);
228+
when(this.authenticationManager.authenticate(any())).thenThrow(authenticationException);
229+
230+
String requestUri = DEFAULT_OIDC_USER_INFO_ENDPOINT_URI;
231+
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
232+
request.setServletPath(requestUri);
233+
MockHttpServletResponse response = new MockHttpServletResponse();
234+
FilterChain filterChain = mock(FilterChain.class);
235+
236+
this.filter.doFilter(request, response, filterChain);
237+
238+
verifyNoInteractions(filterChain);
239+
240+
verify(failureHandler).onAuthenticationFailure(request, response, authenticationException);
167241
}
168242

169243
private OAuth2Error readError(MockHttpServletResponse response) throws Exception {

0 commit comments

Comments
 (0)