19
19
import java .security .MessageDigest ;
20
20
import java .security .NoSuchAlgorithmException ;
21
21
import java .util .Base64 ;
22
+ import java .util .Collections ;
23
+ import java .util .HashMap ;
24
+ import java .util .List ;
22
25
import java .util .Map ;
26
+ import java .util .Objects ;
27
+ import java .util .Set ;
28
+ import java .util .concurrent .ConcurrentHashMap ;
29
+ import java .util .function .Function ;
23
30
31
+ import org .springframework .beans .factory .annotation .Autowired ;
24
32
import org .springframework .security .authentication .AuthenticationProvider ;
25
33
import org .springframework .security .core .Authentication ;
26
34
import org .springframework .security .core .AuthenticationException ;
27
35
import org .springframework .security .crypto .factory .PasswordEncoderFactories ;
28
36
import org .springframework .security .crypto .password .PasswordEncoder ;
29
37
import org .springframework .security .oauth2 .core .AuthorizationGrantType ;
38
+ import org .springframework .security .oauth2 .core .ClientAuthenticationMethod ;
39
+ import org .springframework .security .oauth2 .core .DelegatingOAuth2TokenValidator ;
30
40
import org .springframework .security .oauth2 .core .OAuth2AuthenticationException ;
31
41
import org .springframework .security .oauth2 .core .OAuth2Error ;
32
42
import org .springframework .security .oauth2 .core .OAuth2ErrorCodes ;
33
43
import org .springframework .security .oauth2 .core .OAuth2TokenType ;
44
+ import org .springframework .security .oauth2 .core .OAuth2TokenValidator ;
34
45
import org .springframework .security .oauth2 .core .endpoint .OAuth2AuthorizationRequest ;
35
46
import org .springframework .security .oauth2 .core .endpoint .OAuth2ParameterNames ;
36
47
import org .springframework .security .oauth2 .core .endpoint .PkceParameterNames ;
48
+ import org .springframework .security .oauth2 .jose .jws .JwsAlgorithm ;
49
+ import org .springframework .security .oauth2 .jose .jws .MacAlgorithm ;
50
+ import org .springframework .security .oauth2 .jose .jws .SignatureAlgorithm ;
51
+ import org .springframework .security .oauth2 .jwt .Jwt ;
52
+ import org .springframework .security .oauth2 .jwt .JwtClaimValidator ;
53
+ import org .springframework .security .oauth2 .jwt .JwtDecoder ;
54
+ import org .springframework .security .oauth2 .jwt .JwtDecoderFactory ;
55
+ import org .springframework .security .oauth2 .jwt .JwtException ;
56
+ import org .springframework .security .oauth2 .jwt .JwtTimestampValidator ;
57
+ import org .springframework .security .oauth2 .jwt .NimbusJwtDecoder ;
37
58
import org .springframework .security .oauth2 .server .authorization .OAuth2Authorization ;
38
59
import org .springframework .security .oauth2 .server .authorization .OAuth2AuthorizationService ;
39
60
import org .springframework .security .oauth2 .server .authorization .client .RegisteredClient ;
40
61
import org .springframework .security .oauth2 .server .authorization .client .RegisteredClientRepository ;
62
+ import org .springframework .security .oauth2 .server .authorization .config .ProviderSettings ;
41
63
import org .springframework .util .Assert ;
42
64
import org .springframework .util .StringUtils ;
43
65
66
+ import javax .crypto .spec .SecretKeySpec ;
67
+
44
68
/**
45
69
* An {@link AuthenticationProvider} implementation used for authenticating an OAuth 2.0 Client.
46
70
*
47
71
* @author Joe Grandja
48
72
* @author Patryk Kostrzewa
49
73
* @author Daniel Garnier-Moiroux
74
+ * @author Rafal Lewczuk
50
75
* @since 0.0.1
51
76
* @see AuthenticationProvider
52
77
* @see OAuth2ClientAuthenticationToken
56
81
*/
57
82
public final class OAuth2ClientAuthenticationProvider implements AuthenticationProvider {
58
83
private static final String CLIENT_AUTHENTICATION_ERROR_URI = "https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-01#section-3.2.1" ;
84
+
85
+ private static final ClientAuthenticationMethod JWT_CLIENT_ASSERTION_AUTHENTICATION_METHOD =
86
+ new ClientAuthenticationMethod ("urn:ietf:params:oauth:client-assertion-type:jwt-bearer" );
87
+
59
88
private static final OAuth2TokenType AUTHORIZATION_CODE_TOKEN_TYPE = new OAuth2TokenType (OAuth2ParameterNames .CODE );
60
89
private final RegisteredClientRepository registeredClientRepository ;
61
90
private final OAuth2AuthorizationService authorizationService ;
91
+ private JwtDecoderFactory <RegisteredClient > jwtDecoderFactory ;
92
+ private ProviderSettings providerSettings ;
62
93
private PasswordEncoder passwordEncoder ;
63
94
64
95
/**
@@ -74,6 +105,7 @@ public OAuth2ClientAuthenticationProvider(RegisteredClientRepository registeredC
74
105
this .registeredClientRepository = registeredClientRepository ;
75
106
this .authorizationService = authorizationService ;
76
107
this .passwordEncoder = PasswordEncoderFactories .createDelegatingPasswordEncoder ();
108
+ this .jwtDecoderFactory = new RegisteredClientJwtAssertionDecoderFactory ();
77
109
}
78
110
79
111
/**
@@ -89,11 +121,25 @@ public void setPasswordEncoder(PasswordEncoder passwordEncoder) {
89
121
this .passwordEncoder = passwordEncoder ;
90
122
}
91
123
124
+ @ Autowired
125
+ protected void setProviderSettings (ProviderSettings providerSettings ) {
126
+ this .providerSettings = providerSettings ;
127
+ }
128
+
92
129
@ Override
93
130
public Authentication authenticate (Authentication authentication ) throws AuthenticationException {
94
131
OAuth2ClientAuthenticationToken clientAuthentication =
95
132
(OAuth2ClientAuthenticationToken ) authentication ;
96
133
134
+ return JWT_CLIENT_ASSERTION_AUTHENTICATION_METHOD .equals (clientAuthentication .getClientAuthenticationMethod ()) ?
135
+ authenticateClientAssertion (authentication ) :
136
+ authenticationClientCredentials (authentication );
137
+ }
138
+
139
+ private Authentication authenticationClientCredentials (Authentication authentication ) throws AuthenticationException {
140
+ OAuth2ClientAuthenticationToken clientAuthentication =
141
+ (OAuth2ClientAuthenticationToken ) authentication ;
142
+
97
143
String clientId = clientAuthentication .getPrincipal ().toString ();
98
144
RegisteredClient registeredClient = this .registeredClientRepository .findByClientId (clientId );
99
145
if (registeredClient == null ) {
@@ -125,6 +171,64 @@ public Authentication authenticate(Authentication authentication) throws Authent
125
171
clientAuthentication .getClientAuthenticationMethod (), clientAuthentication .getCredentials ());
126
172
}
127
173
174
+ private Authentication authenticateClientAssertion (Authentication authentication ) throws AuthenticationException {
175
+ OAuth2ClientAuthenticationToken clientAuthentication =
176
+ (OAuth2ClientAuthenticationToken ) authentication ;
177
+
178
+ String clientId = clientAuthentication .getPrincipal ().toString ();
179
+ RegisteredClient registeredClient = this .registeredClientRepository .findByClientId (clientId );
180
+ if (registeredClient == null ) {
181
+ throwInvalidClient (OAuth2ParameterNames .CLIENT_ID );
182
+ }
183
+
184
+ Set <ClientAuthenticationMethod > allowedAuthenticationMethods = registeredClient .getClientAuthenticationMethods ();
185
+
186
+ if (!allowedAuthenticationMethods .contains (ClientAuthenticationMethod .CLIENT_SECRET_JWT ) &&
187
+ !allowedAuthenticationMethods .contains (ClientAuthenticationMethod .PRIVATE_KEY_JWT )) {
188
+ throwInvalidClient ("authentication_method" );
189
+ }
190
+
191
+ boolean credentialsAuthenticated = false ;
192
+
193
+ try {
194
+ JwtDecoder jwtDecoder = this .jwtDecoderFactory .createDecoder (registeredClient );
195
+ Jwt jwt = jwtDecoder .decode (clientAuthentication .getCredentials ().toString ());
196
+ List <String > aud = jwt .getClaimAsStringList ("aud" );
197
+ String issuer = getIssuerUri (clientAuthentication .getRequestUri ());
198
+ if (aud == null || !aud .contains (issuer )) {
199
+ throwInvalidClient (OAuth2ParameterNames .CLIENT_ASSERTION );
200
+ }
201
+ credentialsAuthenticated = true ;
202
+ } catch (JwtException e ) {
203
+ throwInvalidClient (OAuth2ParameterNames .CLIENT_ASSERTION );
204
+ }
205
+
206
+ boolean pkceAuthenticated = authenticatePkceIfAvailable (clientAuthentication , registeredClient );
207
+ credentialsAuthenticated = credentialsAuthenticated || pkceAuthenticated ;
208
+ if (!credentialsAuthenticated ) {
209
+ throwInvalidClient ("credentials" );
210
+ }
211
+
212
+ JwsAlgorithm tokenEndpointSigningAlgorithm = registeredClient .getClientSettings ().getTokenEndpointSigningAlgorithm ();
213
+ ClientAuthenticationMethod clientAuthentiationMethod = tokenEndpointSigningAlgorithm instanceof MacAlgorithm ?
214
+ ClientAuthenticationMethod .CLIENT_SECRET_JWT : ClientAuthenticationMethod .PRIVATE_KEY_JWT ;
215
+
216
+ return new OAuth2ClientAuthenticationToken (registeredClient ,
217
+ clientAuthentiationMethod , clientAuthentication .getCredentials ());
218
+ }
219
+
220
+ private String getIssuerUri (String requestUri ) throws AuthenticationException {
221
+ if (requestUri .endsWith (providerSettings .getTokenEndpoint ())) {
222
+ return providerSettings .getIssuer () + providerSettings .getTokenEndpoint ();
223
+ } else if (requestUri .endsWith (providerSettings .getTokenIntrospectionEndpoint ())) {
224
+ return providerSettings .getIssuer () + providerSettings .getTokenIntrospectionEndpoint ();
225
+ } else if (requestUri .endsWith (providerSettings .getTokenRevocationEndpoint ())) {
226
+ return providerSettings .getIssuer () + providerSettings .getTokenRevocationEndpoint ();
227
+ }
228
+ throwInvalidClient (OAuth2ParameterNames .CLIENT_ASSERTION );
229
+ return null ;
230
+ }
231
+
128
232
@ Override
129
233
public boolean supports (Class <?> authentication ) {
130
234
return OAuth2ClientAuthenticationToken .class .isAssignableFrom (authentication );
@@ -201,4 +305,92 @@ private static void throwInvalidClient(String parameterName) {
201
305
throw new OAuth2AuthenticationException (error );
202
306
}
203
307
308
+ private static class CachedJwtDecoder {
309
+ private final NimbusJwtDecoder jwtDecoder ;
310
+ private final RegisteredClient registeredClient ;
311
+
312
+ CachedJwtDecoder (NimbusJwtDecoder jwtDecoder , RegisteredClient registeredClient ) {
313
+ this .jwtDecoder = jwtDecoder ;
314
+ this .registeredClient = registeredClient ;
315
+ }
316
+ }
317
+
318
+ private static class RegisteredClientJwtAssertionDecoderFactory implements JwtDecoderFactory <RegisteredClient > {
319
+
320
+ private static final String CLIENT_ASSERTION_ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc7523#section-3" ;
321
+
322
+ private static final Map <JwsAlgorithm , String > JCA_ALGORITHM_MAPPINGS ;
323
+
324
+ static {
325
+ Map <JwsAlgorithm , String > mappings = new HashMap <>();
326
+ mappings .put (MacAlgorithm .HS256 , "HmacSHA256" );
327
+ mappings .put (MacAlgorithm .HS384 , "HmacSHA384" );
328
+ mappings .put (MacAlgorithm .HS512 , "HmacSHA512" );
329
+ JCA_ALGORITHM_MAPPINGS = Collections .unmodifiableMap (mappings );
330
+ }
331
+
332
+ private final Function <RegisteredClient , JwsAlgorithm > jwsAlgorithmResolver =
333
+ rc -> rc .getClientSettings ().getTokenEndpointSigningAlgorithm ();
334
+
335
+ private final Map <String , CachedJwtDecoder > cachedDecoders = new ConcurrentHashMap <>();
336
+
337
+ @ Override
338
+ public JwtDecoder createDecoder (RegisteredClient registeredClient ) {
339
+ Assert .notNull (registeredClient , "registeredClient cannot be null" );
340
+
341
+ CachedJwtDecoder cachedDecoder = this .cachedDecoders .get (registeredClient .getClientId ());
342
+ if (cachedDecoder != null && registeredClient .equals (cachedDecoder .registeredClient )) {
343
+ return cachedDecoder .jwtDecoder ;
344
+ }
345
+
346
+ cachedDecoder = new CachedJwtDecoder (buildDecoder (registeredClient ), registeredClient );
347
+ cachedDecoder .jwtDecoder .setJwtValidator (createTokenValidator (registeredClient ));
348
+ this .cachedDecoders .put (registeredClient .getClientId (), cachedDecoder );
349
+ return cachedDecoder .jwtDecoder ;
350
+ }
351
+
352
+ private NimbusJwtDecoder buildDecoder (RegisteredClient registeredClient ) {
353
+ JwsAlgorithm jwsAlgorithm = this .jwsAlgorithmResolver .apply (registeredClient );
354
+
355
+ if (jwsAlgorithm != null && SignatureAlgorithm .class .isAssignableFrom (jwsAlgorithm .getClass ())) {
356
+ String jwkSetUrl = registeredClient .getClientSettings ().getJwkSetUrl ();
357
+ if (!StringUtils .hasText (jwkSetUrl )) {
358
+ OAuth2Error oauth2Error = new OAuth2Error (OAuth2ErrorCodes .INVALID_CLIENT ,
359
+ "misconfigured client" , CLIENT_ASSERTION_ERROR_URI );
360
+ throw new OAuth2AuthenticationException (oauth2Error , oauth2Error .toString ());
361
+ }
362
+ return NimbusJwtDecoder .withJwkSetUri (jwkSetUrl ).jwsAlgorithm ((SignatureAlgorithm ) jwsAlgorithm ).build ();
363
+ }
364
+
365
+ if (jwsAlgorithm != null && MacAlgorithm .class .isAssignableFrom (jwsAlgorithm .getClass ())) {
366
+ String clientSecret = registeredClient .getClientSecret ();
367
+ if (!StringUtils .hasText (clientSecret )) {
368
+ OAuth2Error oauth2Error = new OAuth2Error (OAuth2ErrorCodes .INVALID_CLIENT ,
369
+ "misconfigured client" , CLIENT_ASSERTION_ERROR_URI );
370
+ throw new OAuth2AuthenticationException (oauth2Error , oauth2Error .toString ());
371
+ }
372
+ SecretKeySpec secretKeySpec = new SecretKeySpec (clientSecret .getBytes (StandardCharsets .UTF_8 ),
373
+ JCA_ALGORITHM_MAPPINGS .get (jwsAlgorithm ));
374
+ return NimbusJwtDecoder .withSecretKey (secretKeySpec ).macAlgorithm ((MacAlgorithm ) jwsAlgorithm ).build ();
375
+ }
376
+
377
+ OAuth2Error oauth2Error = new OAuth2Error (OAuth2ErrorCodes .INVALID_CLIENT ,
378
+ "misconfigured client" , CLIENT_ASSERTION_ERROR_URI );
379
+ throw new OAuth2AuthenticationException (oauth2Error , oauth2Error .toString ());
380
+ }
381
+
382
+ private OAuth2TokenValidator <Jwt > createTokenValidator (RegisteredClient registeredClient ) {
383
+ String clientId = registeredClient .getClientId ();
384
+ return new DelegatingOAuth2TokenValidator <>(
385
+ new JwtClaimValidator <String >("iss" , clientId ::equals ), // RFC 7523 section 3 (iss)
386
+ new JwtClaimValidator <String >("sub" , clientId ::equals ), // RFC 7523 section 3 (sub)
387
+ new JwtClaimValidator <>("exp" , Objects ::nonNull ), // RFC 7523 section 3 (exp != null)
388
+ new JwtTimestampValidator () // RFC 7523 section 3 (exp, nbf)
389
+ );
390
+ // The `aud` claim is not verified here
391
+
392
+ // TODO RFC 7523 section 3 #7: JWT may contain "jti" claim that provides unique identified for the token (OPTIONAL)
393
+ }
394
+ }
395
+
204
396
}
0 commit comments