26
26
import org .springframework .security .core .context .SecurityContext ;
27
27
import org .springframework .security .core .context .SecurityContextHolder ;
28
28
import org .springframework .security .oauth2 .core .AuthorizationGrantType ;
29
+ import org .springframework .security .oauth2 .core .OAuth2ErrorCodes ;
29
30
import org .springframework .security .oauth2 .core .endpoint .OAuth2AuthorizationRequest ;
30
31
import org .springframework .security .oauth2 .core .endpoint .OAuth2AuthorizationResponseType ;
31
32
import org .springframework .security .oauth2 .core .endpoint .OAuth2ParameterNames ;
41
42
import javax .servlet .http .HttpServletRequest ;
42
43
import javax .servlet .http .HttpServletResponse ;
43
44
import java .util .Set ;
45
+ import java .util .function .Consumer ;
44
46
45
47
import static org .assertj .core .api .Assertions .assertThat ;
46
48
import static org .assertj .core .api .Assertions .assertThatThrownBy ;
@@ -130,53 +132,29 @@ public void doFilterWhenAuthorizationRequestPostThenNotProcessed() throws Except
130
132
131
133
@ Test
132
134
public void doFilterWhenAuthorizationRequestMissingClientIdThenInvalidRequestError () throws Exception {
133
- RegisteredClient registeredClient = TestRegisteredClients .registeredClient ().build ();
134
-
135
- MockHttpServletRequest request = createAuthorizationRequest (registeredClient );
136
- request .removeParameter (OAuth2ParameterNames .CLIENT_ID );
137
- MockHttpServletResponse response = new MockHttpServletResponse ();
138
- FilterChain filterChain = mock (FilterChain .class );
139
-
140
- this .filter .doFilter (request , response , filterChain );
141
-
142
- verifyNoInteractions (filterChain );
143
-
144
- assertThat (response .getStatus ()).isEqualTo (HttpStatus .BAD_REQUEST .value ());
145
- assertThat (response .getErrorMessage ()).isEqualTo ("[invalid_request] OAuth 2.0 Parameter: client_id" );
135
+ doFilterWhenAuthorizationRequestInvalidParameterThenError (
136
+ TestRegisteredClients .registeredClient ().build (),
137
+ OAuth2ParameterNames .CLIENT_ID ,
138
+ OAuth2ErrorCodes .INVALID_REQUEST ,
139
+ request -> request .removeParameter (OAuth2ParameterNames .CLIENT_ID ));
146
140
}
147
141
148
142
@ Test
149
143
public void doFilterWhenAuthorizationRequestMultipleClientIdThenInvalidRequestError () throws Exception {
150
- RegisteredClient registeredClient = TestRegisteredClients .registeredClient ().build ();
151
-
152
- MockHttpServletRequest request = createAuthorizationRequest (registeredClient );
153
- request .addParameter (OAuth2ParameterNames .CLIENT_ID , registeredClient .getClientId ());
154
- MockHttpServletResponse response = new MockHttpServletResponse ();
155
- FilterChain filterChain = mock (FilterChain .class );
156
-
157
- this .filter .doFilter (request , response , filterChain );
158
-
159
- verifyNoInteractions (filterChain );
160
-
161
- assertThat (response .getStatus ()).isEqualTo (HttpStatus .BAD_REQUEST .value ());
162
- assertThat (response .getErrorMessage ()).isEqualTo ("[invalid_request] OAuth 2.0 Parameter: client_id" );
144
+ doFilterWhenAuthorizationRequestInvalidParameterThenError (
145
+ TestRegisteredClients .registeredClient ().build (),
146
+ OAuth2ParameterNames .CLIENT_ID ,
147
+ OAuth2ErrorCodes .INVALID_REQUEST ,
148
+ request -> request .addParameter (OAuth2ParameterNames .CLIENT_ID , "client-2" ));
163
149
}
164
150
165
151
@ Test
166
152
public void doFilterWhenAuthorizationRequestInvalidClientIdThenInvalidRequestError () throws Exception {
167
- RegisteredClient registeredClient = TestRegisteredClients .registeredClient ().build ();
168
-
169
- MockHttpServletRequest request = createAuthorizationRequest (registeredClient );
170
- request .setParameter (OAuth2ParameterNames .CLIENT_ID , "invalid" );
171
- MockHttpServletResponse response = new MockHttpServletResponse ();
172
- FilterChain filterChain = mock (FilterChain .class );
173
-
174
- this .filter .doFilter (request , response , filterChain );
175
-
176
- verifyNoInteractions (filterChain );
177
-
178
- assertThat (response .getStatus ()).isEqualTo (HttpStatus .BAD_REQUEST .value ());
179
- assertThat (response .getErrorMessage ()).isEqualTo ("[invalid_request] OAuth 2.0 Parameter: client_id" );
153
+ doFilterWhenAuthorizationRequestInvalidParameterThenError (
154
+ TestRegisteredClients .registeredClient ().build (),
155
+ OAuth2ParameterNames .CLIENT_ID ,
156
+ OAuth2ErrorCodes .INVALID_REQUEST ,
157
+ request -> request .setParameter (OAuth2ParameterNames .CLIENT_ID , "invalid" ));
180
158
}
181
159
182
160
@ Test
@@ -188,16 +166,10 @@ public void doFilterWhenAuthorizationRequestAndClientNotAuthorizedToRequestCodeT
188
166
when (this .registeredClientRepository .findByClientId ((eq (registeredClient .getClientId ()))))
189
167
.thenReturn (registeredClient );
190
168
191
- MockHttpServletRequest request = createAuthorizationRequest (registeredClient );
192
- MockHttpServletResponse response = new MockHttpServletResponse ();
193
- FilterChain filterChain = mock (FilterChain .class );
194
-
195
- this .filter .doFilter (request , response , filterChain );
196
-
197
- verifyNoInteractions (filterChain );
198
-
199
- assertThat (response .getStatus ()).isEqualTo (HttpStatus .BAD_REQUEST .value ());
200
- assertThat (response .getErrorMessage ()).isEqualTo ("[unauthorized_client] OAuth 2.0 Parameter: client_id" );
169
+ doFilterWhenAuthorizationRequestInvalidParameterThenError (
170
+ registeredClient ,
171
+ OAuth2ParameterNames .CLIENT_ID ,
172
+ OAuth2ErrorCodes .UNAUTHORIZED_CLIENT );
201
173
}
202
174
203
175
@ Test
@@ -206,17 +178,11 @@ public void doFilterWhenAuthorizationRequestInvalidRedirectUriThenInvalidRequest
206
178
when (this .registeredClientRepository .findByClientId ((eq (registeredClient .getClientId ()))))
207
179
.thenReturn (registeredClient );
208
180
209
- MockHttpServletRequest request = createAuthorizationRequest (registeredClient );
210
- request .setParameter (OAuth2ParameterNames .REDIRECT_URI , "https://invalid-example.com" );
211
- MockHttpServletResponse response = new MockHttpServletResponse ();
212
- FilterChain filterChain = mock (FilterChain .class );
213
-
214
- this .filter .doFilter (request , response , filterChain );
215
-
216
- verifyNoInteractions (filterChain );
217
-
218
- assertThat (response .getStatus ()).isEqualTo (HttpStatus .BAD_REQUEST .value ());
219
- assertThat (response .getErrorMessage ()).isEqualTo ("[invalid_request] OAuth 2.0 Parameter: redirect_uri" );
181
+ doFilterWhenAuthorizationRequestInvalidParameterThenError (
182
+ registeredClient ,
183
+ OAuth2ParameterNames .REDIRECT_URI ,
184
+ OAuth2ErrorCodes .INVALID_REQUEST ,
185
+ request -> request .setParameter (OAuth2ParameterNames .REDIRECT_URI , "https://invalid-example.com" ));
220
186
}
221
187
222
188
@ Test
@@ -225,17 +191,11 @@ public void doFilterWhenAuthorizationRequestMultipleRedirectUriThenInvalidReques
225
191
when (this .registeredClientRepository .findByClientId ((eq (registeredClient .getClientId ()))))
226
192
.thenReturn (registeredClient );
227
193
228
- MockHttpServletRequest request = createAuthorizationRequest (registeredClient );
229
- request .addParameter (OAuth2ParameterNames .REDIRECT_URI , "https://example2.com" );
230
- MockHttpServletResponse response = new MockHttpServletResponse ();
231
- FilterChain filterChain = mock (FilterChain .class );
232
-
233
- this .filter .doFilter (request , response , filterChain );
234
-
235
- verifyNoInteractions (filterChain );
236
-
237
- assertThat (response .getStatus ()).isEqualTo (HttpStatus .BAD_REQUEST .value ());
238
- assertThat (response .getErrorMessage ()).isEqualTo ("[invalid_request] OAuth 2.0 Parameter: redirect_uri" );
194
+ doFilterWhenAuthorizationRequestInvalidParameterThenError (
195
+ registeredClient ,
196
+ OAuth2ParameterNames .REDIRECT_URI ,
197
+ OAuth2ErrorCodes .INVALID_REQUEST ,
198
+ request -> request .addParameter (OAuth2ParameterNames .REDIRECT_URI , "https://example2.com" ));
239
199
}
240
200
241
201
@ Test
@@ -244,17 +204,11 @@ public void doFilterWhenAuthorizationRequestExcludesRedirectUriAndMultipleRegist
244
204
when (this .registeredClientRepository .findByClientId ((eq (registeredClient .getClientId ()))))
245
205
.thenReturn (registeredClient );
246
206
247
- MockHttpServletRequest request = createAuthorizationRequest (registeredClient );
248
- request .removeParameter (OAuth2ParameterNames .REDIRECT_URI );
249
- MockHttpServletResponse response = new MockHttpServletResponse ();
250
- FilterChain filterChain = mock (FilterChain .class );
251
-
252
- this .filter .doFilter (request , response , filterChain );
253
-
254
- verifyNoInteractions (filterChain );
255
-
256
- assertThat (response .getStatus ()).isEqualTo (HttpStatus .BAD_REQUEST .value ());
257
- assertThat (response .getErrorMessage ()).isEqualTo ("[invalid_request] OAuth 2.0 Parameter: redirect_uri" );
207
+ doFilterWhenAuthorizationRequestInvalidParameterThenError (
208
+ registeredClient ,
209
+ OAuth2ParameterNames .REDIRECT_URI ,
210
+ OAuth2ErrorCodes .INVALID_REQUEST ,
211
+ request -> request .removeParameter (OAuth2ParameterNames .REDIRECT_URI ));
258
212
}
259
213
260
214
@ Test
@@ -383,6 +337,27 @@ public void doFilterWhenAuthorizationRequestValidThenAuthorizationResponse() thr
383
337
assertThat (authorizationRequest .getAdditionalParameters ()).isEmpty ();
384
338
}
385
339
340
+ private void doFilterWhenAuthorizationRequestInvalidParameterThenError (RegisteredClient registeredClient ,
341
+ String parameterName , String errorCode ) throws Exception {
342
+ doFilterWhenAuthorizationRequestInvalidParameterThenError (registeredClient , parameterName , errorCode , request -> {});
343
+ }
344
+
345
+ private void doFilterWhenAuthorizationRequestInvalidParameterThenError (RegisteredClient registeredClient ,
346
+ String parameterName , String errorCode , Consumer <MockHttpServletRequest > requestConsumer ) throws Exception {
347
+
348
+ MockHttpServletRequest request = createAuthorizationRequest (registeredClient );
349
+ requestConsumer .accept (request );
350
+ MockHttpServletResponse response = new MockHttpServletResponse ();
351
+ FilterChain filterChain = mock (FilterChain .class );
352
+
353
+ this .filter .doFilter (request , response , filterChain );
354
+
355
+ verifyNoInteractions (filterChain );
356
+
357
+ assertThat (response .getStatus ()).isEqualTo (HttpStatus .BAD_REQUEST .value ());
358
+ assertThat (response .getErrorMessage ()).isEqualTo ("[" + errorCode + "] OAuth 2.0 Parameter: " + parameterName );
359
+ }
360
+
386
361
private static MockHttpServletRequest createAuthorizationRequest (RegisteredClient registeredClient ) {
387
362
String [] redirectUris = registeredClient .getRedirectUris ().toArray (new String [0 ]);
388
363
0 commit comments