44
44
import org .springframework .security .oauth2 .jwt .Jwt ;
45
45
import org .springframework .security .oauth2 .server .authorization .oidc .authentication .OidcUserInfoAuthenticationToken ;
46
46
import org .springframework .security .oauth2 .server .resource .authentication .JwtAuthenticationToken ;
47
+ import org .springframework .security .web .authentication .AuthenticationFailureHandler ;
48
+ import org .springframework .security .web .authentication .AuthenticationSuccessHandler ;
47
49
48
50
import static org .assertj .core .api .Assertions .assertThat ;
49
51
import static org .assertj .core .api .Assertions .assertThatIllegalArgumentException ;
@@ -84,6 +86,20 @@ public void constructorWhenUserInfoEndpointUriIsEmptyThenThrowIllegalArgumentExc
84
86
.withMessage ("userInfoEndpointUri cannot be empty" );
85
87
}
86
88
89
+ @ Test
90
+ public void setAuthenticationSuccessHandlerNullThenThrowIllegalArgumentException () {
91
+ assertThatIllegalArgumentException ()
92
+ .isThrownBy (() -> this .filter .setAuthenticationSuccessHandler (null ))
93
+ .withMessage ("authenticationSuccessHandler cannot be null" );
94
+ }
95
+
96
+ @ Test
97
+ public void setAuthenticationFailureHandlerNullThenThrowIllegalArgumentException () {
98
+ assertThatIllegalArgumentException ()
99
+ .isThrownBy (() -> this .filter .setAuthenticationFailureHandler (null ))
100
+ .withMessage ("authenticationFailureHandler cannot be null" );
101
+ }
102
+
87
103
@ Test
88
104
public void doFilterWhenNotUserInfoRequestThenNotProcessed () throws Exception {
89
105
String requestUri = "/path" ;
@@ -145,11 +161,21 @@ private void doFilterWhenUserInfoRequestThenSuccess(String httpMethod) throws Ex
145
161
146
162
@ Test
147
163
public void doFilterWhenUserInfoRequestInvalidTokenThenUnauthorizedError () throws Exception {
164
+ doFilterWhenAuthenticationExceptionThenError (OAuth2ErrorCodes .INVALID_TOKEN , HttpStatus .UNAUTHORIZED );
165
+ }
166
+
167
+ @ Test
168
+ public void doFilterWhenUserInfoRequestInsufficientScopeThenUnauthorizedError () throws Exception {
169
+ doFilterWhenAuthenticationExceptionThenError (OAuth2ErrorCodes .INSUFFICIENT_SCOPE , HttpStatus .FORBIDDEN );
170
+ }
171
+
172
+ private void doFilterWhenAuthenticationExceptionThenError (String oauth2ErrorCode , HttpStatus httpStatus )
173
+ throws Exception {
148
174
Authentication principal = new TestingAuthenticationToken ("principal" , "credentials" );
149
175
SecurityContextHolder .getContext ().setAuthentication (principal );
150
176
151
177
when (this .authenticationManager .authenticate (any ()))
152
- .thenThrow (new OAuth2AuthenticationException (OAuth2ErrorCodes . INVALID_TOKEN ));
178
+ .thenThrow (new OAuth2AuthenticationException (oauth2ErrorCode ));
153
179
154
180
String requestUri = DEFAULT_OIDC_USER_INFO_ENDPOINT_URI ;
155
181
MockHttpServletRequest request = new MockHttpServletRequest ("GET" , requestUri );
@@ -161,9 +187,57 @@ public void doFilterWhenUserInfoRequestInvalidTokenThenUnauthorizedError() throw
161
187
162
188
verifyNoInteractions (filterChain );
163
189
164
- assertThat (response .getStatus ()).isEqualTo (HttpStatus . UNAUTHORIZED .value ());
190
+ assertThat (response .getStatus ()).isEqualTo (httpStatus .value ());
165
191
OAuth2Error error = readError (response );
166
- assertThat (error .getErrorCode ()).isEqualTo (OAuth2ErrorCodes .INVALID_TOKEN );
192
+ assertThat (error .getErrorCode ()).isEqualTo (oauth2ErrorCode );
193
+ }
194
+
195
+ @ Test
196
+ public void doFilterWhenCustomAuthenticationSuccessHandlerThenUses () throws Exception {
197
+ AuthenticationSuccessHandler successHandler = mock (AuthenticationSuccessHandler .class );
198
+ this .filter .setAuthenticationSuccessHandler (successHandler );
199
+
200
+ Authentication principal = new TestingAuthenticationToken ("principal" , "credentials" );
201
+ SecurityContextHolder .getContext ().setAuthentication (principal );
202
+
203
+ OidcUserInfoAuthenticationToken authentication = new OidcUserInfoAuthenticationToken (principal , createUserInfo ());
204
+ when (this .authenticationManager .authenticate (any ())).thenReturn (authentication );
205
+
206
+ String requestUri = DEFAULT_OIDC_USER_INFO_ENDPOINT_URI ;
207
+ MockHttpServletRequest request = new MockHttpServletRequest ("GET" , requestUri );
208
+ request .setServletPath (requestUri );
209
+ MockHttpServletResponse response = new MockHttpServletResponse ();
210
+ FilterChain filterChain = mock (FilterChain .class );
211
+
212
+ this .filter .doFilter (request , response , filterChain );
213
+
214
+ verifyNoInteractions (filterChain );
215
+ verify (successHandler ).onAuthenticationSuccess (request , response , authentication );
216
+ }
217
+
218
+ @ Test
219
+ public void doFilterWhenCustomFailureHandlerThenUses () throws Exception {
220
+ AuthenticationFailureHandler failureHandler = mock (AuthenticationFailureHandler .class );
221
+ this .filter .setAuthenticationFailureHandler (failureHandler );
222
+
223
+ Authentication principal = new TestingAuthenticationToken ("principal" , "credentials" );
224
+ SecurityContextHolder .getContext ().setAuthentication (principal );
225
+
226
+ OAuth2AuthenticationException authenticationException =
227
+ new OAuth2AuthenticationException (OAuth2ErrorCodes .INVALID_TOKEN );
228
+ when (this .authenticationManager .authenticate (any ())).thenThrow (authenticationException );
229
+
230
+ String requestUri = DEFAULT_OIDC_USER_INFO_ENDPOINT_URI ;
231
+ MockHttpServletRequest request = new MockHttpServletRequest ("GET" , requestUri );
232
+ request .setServletPath (requestUri );
233
+ MockHttpServletResponse response = new MockHttpServletResponse ();
234
+ FilterChain filterChain = mock (FilterChain .class );
235
+
236
+ this .filter .doFilter (request , response , filterChain );
237
+
238
+ verifyNoInteractions (filterChain );
239
+
240
+ verify (failureHandler ).onAuthenticationFailure (request , response , authenticationException );
167
241
}
168
242
169
243
private OAuth2Error readError (MockHttpServletResponse response ) throws Exception {
0 commit comments