Skip to content

Commit 81a74bd

Browse files
committed
Client authentication with JWT assertion
Closes spring-projectsgh-59
1 parent 9053e31 commit 81a74bd

25 files changed

+1516
-35
lines changed

oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/oidc/OidcClientMetadataClaimAccessor.java

+24
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,20 @@ default String getTokenEndpointAuthenticationMethod() {
9999
return getClaimAsString(OidcClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_METHOD);
100100
}
101101

102+
/**
103+
* Returns the {@link SignatureAlgorithm JWS} algorithm that must be used for signing the JWT used to authenticate
104+
* the Client at the Token Endpoint for the {@code private_key_jwt} and {@code client_secret_jwt} authentication
105+
* methods {@code (token_endpoint_auth_signing_alg)}
106+
*
107+
* @return the {@link SignatureAlgorithm JWS} algorithm that must be used for signing the JWT used to authenticate
108+
* the Client at the Token Endpoint for the {@code private_key_jwt} and {@code client_secret_jwt}
109+
* authentication methods {@code (token_endpoint_auth_signing_alg)}
110+
* @since 0.2.1
111+
*/
112+
default String getTokenEndpointAuthenticationSigningAlgorithm() {
113+
return getClaimAsString(OidcClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_SIGNING_ALG);
114+
}
115+
102116
/**
103117
* Returns the OAuth 2.0 {@code grant_type} values that the Client will restrict itself to using {@code (grant_types)}.
104118
*
@@ -155,4 +169,14 @@ default URL getRegistrationClientUrl() {
155169
return getClaimAsURL(OidcClientMetadataClaimNames.REGISTRATION_CLIENT_URI);
156170
}
157171

172+
/**
173+
* Returns {@code URL} for the Client's JSON Web Key Set {@code (jwks_uri)}
174+
*
175+
* @return {@code URL} for the Client's JSON Web Key Set {@code (jwks_uri)}
176+
* @since 0.2.1
177+
*/
178+
default URL getJwkSetUrl() {
179+
return getClaimAsURL(OidcClientMetadataClaimNames.JWKS_URI);
180+
}
181+
158182
}

oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/oidc/OidcClientMetadataClaimNames.java

+14
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
package org.springframework.security.oauth2.core.oidc;
1717

1818
import org.springframework.security.oauth2.jose.jws.JwsAlgorithm;
19+
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
1920

2021
/**
2122
* The names of the "claims" defined by OpenID Connect Dynamic Client Registration 1.0
@@ -95,4 +96,17 @@ public interface OidcClientMetadataClaimNames {
9596
*/
9697
String REGISTRATION_CLIENT_URI = "registration_client_uri";
9798

99+
/**
100+
* {@code jwks_uri} - {@code URL} for the Client's JSON Web Key Set
101+
* @since 0.2.1
102+
*/
103+
String JWKS_URI = "jwks_uri";
104+
105+
/**
106+
* {@code token_endpoint_auth_signing_alg} - {@link SignatureAlgorithm JWS} algorithm that must be used for signing
107+
* the JWT used to authenticate the Client at the Token Endpoint for the {@code private_key_jwt} and {@code client_secret_jwt}
108+
* authentication methods
109+
* @since 0.2.1
110+
*/
111+
String TOKEN_ENDPOINT_AUTH_SIGNING_ALG = "token_endpoint_auth_signing_alg";
98112
}

oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/oidc/OidcClientRegistration.java

+24
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,20 @@ public Builder tokenEndpointAuthenticationMethod(String tokenEndpointAuthenticat
172172
return claim(OidcClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_METHOD, tokenEndpointAuthenticationMethod);
173173
}
174174

175+
/**
176+
* Sets the {@link SignatureAlgorithm JWS} algorithm that must be used for signing the JWT used to authenticate
177+
* the Client at the Token Endpoint for the {@code private_key_jwt} and {@code client_secret_jwt} authentication
178+
* methods
179+
* @param signingAlgorithm the {@link SignatureAlgorithm JWS} algorithm that must be used for signing
180+
* the JWT used to authenticate the Client at the Token Endpoint for the {@code private_key_jwt} and
181+
* {@code client_secret_jwt} authentication methods
182+
* @return the {@link Builder} for further configuration
183+
* @since 0.2.1
184+
*/
185+
public Builder tokenEndpointAuthenticationSigningAlgorithm(String signingAlgorithm) {
186+
return claim(OidcClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_SIGNING_ALG, signingAlgorithm);
187+
}
188+
175189
/**
176190
* Add the OAuth 2.0 {@code grant_type} that the Client will restrict itself to using, OPTIONAL.
177191
*
@@ -273,6 +287,16 @@ public Builder registrationClientUrl(String registrationClientUrl) {
273287
return claim(OidcClientMetadataClaimNames.REGISTRATION_CLIENT_URI, registrationClientUrl);
274288
}
275289

290+
/**
291+
* Sets {@code URL} for the Client's JSON Web Key Set {@code (jwks_uri)}
292+
* @param jwksSetUrl {@code URL} for the Client's JSON Web Key Set {@code (jwks_uri)}
293+
* @return the {@link Builder} for further configuration
294+
* @since 0.2.1
295+
*/
296+
public Builder jwkSetUrl(String jwksSetUrl) {
297+
return claim(OidcClientMetadataClaimNames.JWKS_URI, jwksSetUrl);
298+
}
299+
276300
/**
277301
* Sets the claim.
278302
*

oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/oidc/http/converter/OidcClientRegistrationHttpMessageConverter.java

+2
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,8 @@ private MapOidcClientRegistrationConverter() {
150150
claimConverters.put(OidcClientMetadataClaimNames.RESPONSE_TYPES, collectionStringConverter);
151151
claimConverters.put(OidcClientMetadataClaimNames.SCOPE, MapOidcClientRegistrationConverter::convertScope);
152152
claimConverters.put(OidcClientMetadataClaimNames.ID_TOKEN_SIGNED_RESPONSE_ALG, stringConverter);
153+
claimConverters.put(OidcClientMetadataClaimNames.JWKS_URI, stringConverter);
154+
claimConverters.put(OidcClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_SIGNING_ALG, stringConverter);
153155
this.claimTypeConverter = new ClaimTypeConverter(claimConverters);
154156
}
155157

oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProvider.java

+73
Original file line numberDiff line numberDiff line change
@@ -20,24 +20,33 @@
2020
import java.security.NoSuchAlgorithmException;
2121
import java.util.Base64;
2222
import java.util.Map;
23+
import java.util.Set;
2324

25+
import org.springframework.beans.factory.annotation.Autowired;
2426
import org.springframework.security.authentication.AuthenticationProvider;
2527
import org.springframework.security.core.Authentication;
2628
import org.springframework.security.core.AuthenticationException;
2729
import org.springframework.security.crypto.factory.PasswordEncoderFactories;
2830
import org.springframework.security.crypto.password.PasswordEncoder;
2931
import org.springframework.security.oauth2.core.AuthorizationGrantType;
32+
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
3033
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
3134
import org.springframework.security.oauth2.core.OAuth2Error;
3235
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
3336
import org.springframework.security.oauth2.core.OAuth2TokenType;
3437
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
3538
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
3639
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
40+
import org.springframework.security.oauth2.jose.jws.JwsAlgorithm;
41+
import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
42+
import org.springframework.security.oauth2.jwt.JwtDecoder;
43+
import org.springframework.security.oauth2.jwt.JwtDecoderFactory;
44+
import org.springframework.security.oauth2.jwt.JwtException;
3745
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
3846
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
3947
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
4048
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
49+
import org.springframework.security.oauth2.server.authorization.config.ProviderSettings;
4150
import org.springframework.util.Assert;
4251
import org.springframework.util.StringUtils;
4352

@@ -56,9 +65,14 @@
5665
*/
5766
public final class OAuth2ClientAuthenticationProvider implements AuthenticationProvider {
5867
private static final String CLIENT_AUTHENTICATION_ERROR_URI = "https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-01#section-3.2.1";
68+
69+
private static final ClientAuthenticationMethod JWT_CLIENT_ASSERTION_AUTHENTICATION_METHOD =
70+
new ClientAuthenticationMethod("urn:ietf:params:oauth:client-assertion-type:jwt-bearer");
71+
5972
private static final OAuth2TokenType AUTHORIZATION_CODE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.CODE);
6073
private final RegisteredClientRepository registeredClientRepository;
6174
private final OAuth2AuthorizationService authorizationService;
75+
private JwtDecoderFactory<RegisteredClient> decoderFactory;
6276
private PasswordEncoder passwordEncoder;
6377

6478
/**
@@ -89,11 +103,29 @@ public void setPasswordEncoder(PasswordEncoder passwordEncoder) {
89103
this.passwordEncoder = passwordEncoder;
90104
}
91105

106+
void setDecoderFactory(JwtDecoderFactory<RegisteredClient> decoderFactory) {
107+
this.decoderFactory = decoderFactory;
108+
}
109+
110+
@Autowired(required = false)
111+
protected void setProviderSettings(ProviderSettings providerSettings) {
112+
this.decoderFactory = new RegisteredClientJwtAssertionDecoderFactory(providerSettings);
113+
}
114+
92115
@Override
93116
public Authentication authenticate(Authentication authentication) throws AuthenticationException {
94117
OAuth2ClientAuthenticationToken clientAuthentication =
95118
(OAuth2ClientAuthenticationToken) authentication;
96119

120+
return JWT_CLIENT_ASSERTION_AUTHENTICATION_METHOD.equals(clientAuthentication.getClientAuthenticationMethod()) ?
121+
authenticateClientAssertion(authentication) :
122+
authenticationClientCredentials(authentication);
123+
}
124+
125+
private Authentication authenticationClientCredentials(Authentication authentication) throws AuthenticationException {
126+
OAuth2ClientAuthenticationToken clientAuthentication =
127+
(OAuth2ClientAuthenticationToken) authentication;
128+
97129
String clientId = clientAuthentication.getPrincipal().toString();
98130
RegisteredClient registeredClient = this.registeredClientRepository.findByClientId(clientId);
99131
if (registeredClient == null) {
@@ -125,6 +157,47 @@ public Authentication authenticate(Authentication authentication) throws Authent
125157
clientAuthentication.getClientAuthenticationMethod(), clientAuthentication.getCredentials());
126158
}
127159

160+
private Authentication authenticateClientAssertion(Authentication authentication) throws AuthenticationException {
161+
OAuth2ClientAuthenticationToken clientAuthentication =
162+
(OAuth2ClientAuthenticationToken) authentication;
163+
164+
String clientId = clientAuthentication.getPrincipal().toString();
165+
RegisteredClient registeredClient = this.registeredClientRepository.findByClientId(clientId);
166+
if (registeredClient == null) {
167+
throwInvalidClient(OAuth2ParameterNames.CLIENT_ID);
168+
}
169+
170+
Set<ClientAuthenticationMethod> allowedAuthenticationMethods = registeredClient.getClientAuthenticationMethods();
171+
172+
if (!allowedAuthenticationMethods.contains(ClientAuthenticationMethod.CLIENT_SECRET_JWT) &&
173+
!allowedAuthenticationMethods.contains(ClientAuthenticationMethod.PRIVATE_KEY_JWT)) {
174+
throwInvalidClient("authentication_method");
175+
}
176+
177+
boolean credentialsAuthenticated = false;
178+
179+
try {
180+
JwtDecoder jwtDecoder = this.decoderFactory.createDecoder(registeredClient);
181+
jwtDecoder.decode(clientAuthentication.getCredentials().toString());
182+
credentialsAuthenticated = true;
183+
} catch (JwtException e) {
184+
throwInvalidClient(OAuth2ParameterNames.CLIENT_ASSERTION);
185+
}
186+
187+
boolean pkceAuthenticated = authenticatePkceIfAvailable(clientAuthentication, registeredClient);
188+
credentialsAuthenticated = credentialsAuthenticated || pkceAuthenticated;
189+
if (!credentialsAuthenticated) {
190+
throwInvalidClient("credentials");
191+
}
192+
193+
JwsAlgorithm tokenEndpointSigningAlgorithm = registeredClient.getClientSettings().getTokenEndpointSigningAlgorithm();
194+
ClientAuthenticationMethod clientAuthentiationMethod = tokenEndpointSigningAlgorithm instanceof MacAlgorithm ?
195+
ClientAuthenticationMethod.CLIENT_SECRET_JWT : ClientAuthenticationMethod.PRIVATE_KEY_JWT;
196+
197+
return new OAuth2ClientAuthenticationToken(registeredClient,
198+
clientAuthentiationMethod, clientAuthentication.getCredentials());
199+
}
200+
128201
@Override
129202
public boolean supports(Class<?> authentication) {
130203
return OAuth2ClientAuthenticationToken.class.isAssignableFrom(authentication);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
/*
2+
* Copyright 2020-2021 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.security.oauth2.server.authorization.authentication;
17+
18+
import org.springframework.security.oauth2.core.DelegatingOAuth2TokenValidator;
19+
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
20+
import org.springframework.security.oauth2.core.OAuth2Error;
21+
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
22+
import org.springframework.security.oauth2.core.OAuth2TokenValidator;
23+
import org.springframework.security.oauth2.jose.jws.JwsAlgorithm;
24+
import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
25+
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
26+
import org.springframework.security.oauth2.jwt.Jwt;
27+
import org.springframework.security.oauth2.jwt.JwtClaimValidator;
28+
import org.springframework.security.oauth2.jwt.JwtDecoder;
29+
import org.springframework.security.oauth2.jwt.JwtDecoderFactory;
30+
import org.springframework.security.oauth2.jwt.JwtTimestampValidator;
31+
import org.springframework.security.oauth2.jwt.NimbusJwtDecoder;
32+
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
33+
import org.springframework.security.oauth2.server.authorization.config.ProviderSettings;
34+
import org.springframework.util.Assert;
35+
import org.springframework.util.StringUtils;
36+
37+
import javax.crypto.spec.SecretKeySpec;
38+
import java.nio.charset.StandardCharsets;
39+
import java.util.Collections;
40+
import java.util.HashMap;
41+
import java.util.List;
42+
import java.util.Map;
43+
import java.util.Objects;
44+
import java.util.concurrent.ConcurrentHashMap;
45+
import java.util.function.Function;
46+
47+
/**
48+
* Creates JWT decoders for registered clients
49+
*
50+
* @author Rafal Lewczuk
51+
* @since 0.2.1
52+
*/
53+
final class RegisteredClientJwtAssertionDecoderFactory implements JwtDecoderFactory<RegisteredClient> {
54+
55+
private static final String CLIENT_ASSERTION_ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc7523#section-3";
56+
57+
private static final Map<JwsAlgorithm, String> JCA_ALGORITHM_MAPPINGS;
58+
59+
private final String tokenEndpointUri;
60+
61+
static {
62+
Map<JwsAlgorithm, String> mappings = new HashMap<>();
63+
mappings.put(MacAlgorithm.HS256, "HmacSHA256");
64+
mappings.put(MacAlgorithm.HS384, "HmacSHA384");
65+
mappings.put(MacAlgorithm.HS512, "HmacSHA512");
66+
JCA_ALGORITHM_MAPPINGS = Collections.unmodifiableMap(mappings);
67+
}
68+
69+
private final Function<RegisteredClient, JwsAlgorithm> jwsAlgorithmResolver =
70+
rc -> rc.getClientSettings().getTokenEndpointSigningAlgorithm();
71+
72+
private final Map<String, CachedJwtDecoder> cachedDecoders = new ConcurrentHashMap<>();
73+
74+
RegisteredClientJwtAssertionDecoderFactory(ProviderSettings providerSettings) {
75+
Assert.notNull(providerSettings, "providerSettings cannot be null");
76+
this.tokenEndpointUri = providerSettings.getIssuer() + providerSettings.getTokenEndpoint();
77+
}
78+
79+
@Override
80+
public JwtDecoder createDecoder(RegisteredClient registeredClient) {
81+
Assert.notNull(registeredClient, "registeredClient cannot be null");
82+
83+
CachedJwtDecoder cachedDecoder = this.cachedDecoders.get(registeredClient.getClientId());
84+
if (cachedDecoder != null &&
85+
Objects.equals(registeredClient.getClientSettings().getTokenEndpointSigningAlgorithm(),
86+
cachedDecoder.registeredClient.getClientSettings().getTokenEndpointSigningAlgorithm()) &&
87+
Objects.equals(cachedDecoder.registeredClient.getClientSecret(), registeredClient.getClientSecret()) &&
88+
Objects.equals(registeredClient.getClientSettings().getJwkSetUrl(),
89+
cachedDecoder.registeredClient.getClientSettings().getJwkSetUrl())) {
90+
return cachedDecoder.jwtDecoder;
91+
}
92+
93+
cachedDecoder = new CachedJwtDecoder(buildDecoder(registeredClient), registeredClient);
94+
cachedDecoder.jwtDecoder.setJwtValidator(createTokenValidator(registeredClient));
95+
this.cachedDecoders.put(registeredClient.getClientId(), cachedDecoder);
96+
return cachedDecoder.jwtDecoder;
97+
}
98+
99+
private NimbusJwtDecoder buildDecoder(RegisteredClient registeredClient) {
100+
JwsAlgorithm jwsAlgorithm = this.jwsAlgorithmResolver.apply(registeredClient);
101+
102+
if (jwsAlgorithm != null && SignatureAlgorithm.class.isAssignableFrom(jwsAlgorithm.getClass())) {
103+
String jwkSetUrl = registeredClient.getClientSettings().getJwkSetUrl();
104+
if (!StringUtils.hasText(jwkSetUrl)) {
105+
OAuth2Error oauth2Error = new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT,
106+
"misconfigured client", CLIENT_ASSERTION_ERROR_URI);
107+
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
108+
}
109+
return NimbusJwtDecoder.withJwkSetUri(jwkSetUrl).jwsAlgorithm((SignatureAlgorithm) jwsAlgorithm).build();
110+
}
111+
112+
if (jwsAlgorithm != null && MacAlgorithm.class.isAssignableFrom(jwsAlgorithm.getClass())) {
113+
String clientSecret = registeredClient.getClientSecret();
114+
if (!StringUtils.hasText(clientSecret)) {
115+
OAuth2Error oauth2Error = new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT,
116+
"misconfigured client", CLIENT_ASSERTION_ERROR_URI);
117+
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
118+
}
119+
SecretKeySpec secretKeySpec = new SecretKeySpec(clientSecret.getBytes(StandardCharsets.UTF_8),
120+
JCA_ALGORITHM_MAPPINGS.get(jwsAlgorithm));
121+
return NimbusJwtDecoder.withSecretKey(secretKeySpec).macAlgorithm((MacAlgorithm) jwsAlgorithm).build();
122+
}
123+
124+
OAuth2Error oauth2Error = new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT,
125+
"misconfigured client", CLIENT_ASSERTION_ERROR_URI);
126+
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
127+
}
128+
129+
private OAuth2TokenValidator<Jwt> createTokenValidator(RegisteredClient registeredClient) {
130+
String clientId = registeredClient.getClientId();
131+
return new DelegatingOAuth2TokenValidator<>(
132+
new JwtClaimValidator<String>("iss", clientId::equals), // RFC 7523 section 3 (iss)
133+
new JwtClaimValidator<String>("sub", clientId::equals), // RFC 7523 section 3 (sub)
134+
new JwtClaimValidator<List<String>>("aud", l -> l.contains(tokenEndpointUri)), // RFC 7523 section 3 (aud)
135+
new JwtClaimValidator<>("exp", Objects::nonNull), // RFC 7523 section 3 (exp != null)
136+
new JwtTimestampValidator() // RFC 7523 section 3 (exp, nbf)
137+
);
138+
}
139+
140+
private static class CachedJwtDecoder {
141+
private final NimbusJwtDecoder jwtDecoder;
142+
private final RegisteredClient registeredClient;
143+
144+
CachedJwtDecoder(NimbusJwtDecoder jwtDecoder, RegisteredClient registeredClient) {
145+
this.jwtDecoder = jwtDecoder;
146+
this.registeredClient = registeredClient;
147+
}
148+
}
149+
}

0 commit comments

Comments
 (0)