Skip to content

Commit 92de71d

Browse files
committed
Client authentication with JWT assertion
Closes spring-projectsgh-59
1 parent 9053e31 commit 92de71d

26 files changed

+1596
-45
lines changed

oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/oidc/OidcClientMetadataClaimAccessor.java

+24
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,20 @@ default String getTokenEndpointAuthenticationMethod() {
9999
return getClaimAsString(OidcClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_METHOD);
100100
}
101101

102+
/**
103+
* Returns the {@link SignatureAlgorithm JWS} algorithm that must be used for signing the JWT used to authenticate
104+
* the Client at the Token Endpoint for the {@code private_key_jwt} and {@code client_secret_jwt} authentication
105+
* methods {@code (token_endpoint_auth_signing_alg)}
106+
*
107+
* @return the {@link SignatureAlgorithm JWS} algorithm that must be used for signing the JWT used to authenticate
108+
* the Client at the Token Endpoint for the {@code private_key_jwt} and {@code client_secret_jwt}
109+
* authentication methods {@code (token_endpoint_auth_signing_alg)}
110+
* @since 0.2.1
111+
*/
112+
default String getTokenEndpointAuthenticationSigningAlgorithm() {
113+
return getClaimAsString(OidcClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_SIGNING_ALG);
114+
}
115+
102116
/**
103117
* Returns the OAuth 2.0 {@code grant_type} values that the Client will restrict itself to using {@code (grant_types)}.
104118
*
@@ -155,4 +169,14 @@ default URL getRegistrationClientUrl() {
155169
return getClaimAsURL(OidcClientMetadataClaimNames.REGISTRATION_CLIENT_URI);
156170
}
157171

172+
/**
173+
* Returns {@code URL} for the Client's JSON Web Key Set {@code (jwks_uri)}
174+
*
175+
* @return {@code URL} for the Client's JSON Web Key Set {@code (jwks_uri)}
176+
* @since 0.2.1
177+
*/
178+
default URL getJwkSetUrl() {
179+
return getClaimAsURL(OidcClientMetadataClaimNames.JWKS_URI);
180+
}
181+
158182
}

oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/oidc/OidcClientMetadataClaimNames.java

+14
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
package org.springframework.security.oauth2.core.oidc;
1717

1818
import org.springframework.security.oauth2.jose.jws.JwsAlgorithm;
19+
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
1920

2021
/**
2122
* The names of the "claims" defined by OpenID Connect Dynamic Client Registration 1.0
@@ -95,4 +96,17 @@ public interface OidcClientMetadataClaimNames {
9596
*/
9697
String REGISTRATION_CLIENT_URI = "registration_client_uri";
9798

99+
/**
100+
* {@code jwks_uri} - {@code URL} for the Client's JSON Web Key Set
101+
* @since 0.2.1
102+
*/
103+
String JWKS_URI = "jwks_uri";
104+
105+
/**
106+
* {@code token_endpoint_auth_signing_alg} - {@link SignatureAlgorithm JWS} algorithm that must be used for signing
107+
* the JWT used to authenticate the Client at the Token Endpoint for the {@code private_key_jwt} and {@code client_secret_jwt}
108+
* authentication methods
109+
* @since 0.2.1
110+
*/
111+
String TOKEN_ENDPOINT_AUTH_SIGNING_ALG = "token_endpoint_auth_signing_alg";
98112
}

oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/oidc/OidcClientRegistration.java

+24
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,20 @@ public Builder tokenEndpointAuthenticationMethod(String tokenEndpointAuthenticat
172172
return claim(OidcClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_METHOD, tokenEndpointAuthenticationMethod);
173173
}
174174

175+
/**
176+
* Sets the {@link SignatureAlgorithm JWS} algorithm that must be used for signing the JWT used to authenticate
177+
* the Client at the Token Endpoint for the {@code private_key_jwt} and {@code client_secret_jwt} authentication
178+
* methods
179+
* @param signingAlgorithm the {@link SignatureAlgorithm JWS} algorithm that must be used for signing
180+
* the JWT used to authenticate the Client at the Token Endpoint for the {@code private_key_jwt} and
181+
* {@code client_secret_jwt} authentication methods
182+
* @return the {@link Builder} for further configuration
183+
* @since 0.2.1
184+
*/
185+
public Builder tokenEndpointAuthenticationSigningAlgorithm(String signingAlgorithm) {
186+
return claim(OidcClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_SIGNING_ALG, signingAlgorithm);
187+
}
188+
175189
/**
176190
* Add the OAuth 2.0 {@code grant_type} that the Client will restrict itself to using, OPTIONAL.
177191
*
@@ -273,6 +287,16 @@ public Builder registrationClientUrl(String registrationClientUrl) {
273287
return claim(OidcClientMetadataClaimNames.REGISTRATION_CLIENT_URI, registrationClientUrl);
274288
}
275289

290+
/**
291+
* Sets {@code URL} for the Client's JSON Web Key Set {@code (jwks_uri)}
292+
* @param jwksSetUrl {@code URL} for the Client's JSON Web Key Set {@code (jwks_uri)}
293+
* @return the {@link Builder} for further configuration
294+
* @since 0.2.1
295+
*/
296+
public Builder jwkSetUrl(String jwksSetUrl) {
297+
return claim(OidcClientMetadataClaimNames.JWKS_URI, jwksSetUrl);
298+
}
299+
276300
/**
277301
* Sets the claim.
278302
*

oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProvider.java

+52-9
Original file line numberDiff line numberDiff line change
@@ -20,24 +20,31 @@
2020
import java.security.NoSuchAlgorithmException;
2121
import java.util.Base64;
2222
import java.util.Map;
23+
import java.util.Set;
2324

25+
import org.springframework.beans.factory.annotation.Autowired;
2426
import org.springframework.security.authentication.AuthenticationProvider;
2527
import org.springframework.security.core.Authentication;
2628
import org.springframework.security.core.AuthenticationException;
2729
import org.springframework.security.crypto.factory.PasswordEncoderFactories;
2830
import org.springframework.security.crypto.password.PasswordEncoder;
2931
import org.springframework.security.oauth2.core.AuthorizationGrantType;
32+
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
3033
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
3134
import org.springframework.security.oauth2.core.OAuth2Error;
3235
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
3336
import org.springframework.security.oauth2.core.OAuth2TokenType;
3437
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
3538
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
3639
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
40+
import org.springframework.security.oauth2.jwt.JwtDecoder;
41+
import org.springframework.security.oauth2.jwt.JwtDecoderFactory;
42+
import org.springframework.security.oauth2.jwt.JwtException;
3743
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
3844
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
3945
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
4046
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
47+
import org.springframework.security.oauth2.server.authorization.config.ProviderSettings;
4148
import org.springframework.util.Assert;
4249
import org.springframework.util.StringUtils;
4350

@@ -56,9 +63,14 @@
5663
*/
5764
public final class OAuth2ClientAuthenticationProvider implements AuthenticationProvider {
5865
private static final String CLIENT_AUTHENTICATION_ERROR_URI = "https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-01#section-3.2.1";
66+
67+
static final ClientAuthenticationMethod CLIENT_ASSERTION_AUTHENTICATION_METHOD =
68+
new ClientAuthenticationMethod("urn:ietf:params:oauth:client-assertion-type:jwt-bearer");
69+
5970
private static final OAuth2TokenType AUTHORIZATION_CODE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.CODE);
6071
private final RegisteredClientRepository registeredClientRepository;
6172
private final OAuth2AuthorizationService authorizationService;
73+
private JwtDecoderFactory<RegisteredClient> decoderFactory;
6274
private PasswordEncoder passwordEncoder;
6375

6476
/**
@@ -89,6 +101,15 @@ public void setPasswordEncoder(PasswordEncoder passwordEncoder) {
89101
this.passwordEncoder = passwordEncoder;
90102
}
91103

104+
void setDecoderFactory(JwtDecoderFactory<RegisteredClient> decoderFactory) {
105+
this.decoderFactory = decoderFactory;
106+
}
107+
108+
@Autowired(required = false)
109+
protected void setProviderSettings(ProviderSettings providerSettings) {
110+
this.decoderFactory = new RegisteredClientJwtAssertionDecoderFactory(providerSettings);
111+
}
112+
92113
@Override
93114
public Authentication authenticate(Authentication authentication) throws AuthenticationException {
94115
OAuth2ClientAuthenticationToken clientAuthentication =
@@ -100,19 +121,41 @@ public Authentication authenticate(Authentication authentication) throws Authent
100121
throwInvalidClient(OAuth2ParameterNames.CLIENT_ID);
101122
}
102123

103-
if (!registeredClient.getClientAuthenticationMethods().contains(
104-
clientAuthentication.getClientAuthenticationMethod())) {
124+
ClientAuthenticationMethod clientAuthentiationMethod = clientAuthentication.getClientAuthenticationMethod();
125+
Set<ClientAuthenticationMethod> allowedAuthenticationMethods = registeredClient.getClientAuthenticationMethods();
126+
127+
if (!allowedAuthenticationMethods.contains(clientAuthentiationMethod)
128+
&& !(CLIENT_ASSERTION_AUTHENTICATION_METHOD.equals(clientAuthentiationMethod)
129+
&& (allowedAuthenticationMethods.contains(ClientAuthenticationMethod.CLIENT_SECRET_JWT)
130+
|| allowedAuthenticationMethods.contains(ClientAuthenticationMethod.PRIVATE_KEY_JWT)))) {
131+
throwInvalidClient("authentication_method");
132+
}
133+
134+
if (CLIENT_ASSERTION_AUTHENTICATION_METHOD.equals(clientAuthentiationMethod) && this.decoderFactory == null) {
135+
// Decoder not present -> authentication method not supported
105136
throwInvalidClient("authentication_method");
106137
}
107138

108139
boolean credentialsAuthenticated = false;
109140

110-
if (clientAuthentication.getCredentials() != null) {
111-
String clientSecret = clientAuthentication.getCredentials().toString();
112-
if (!this.passwordEncoder.matches(clientSecret, registeredClient.getClientSecret())) {
113-
throwInvalidClient(OAuth2ParameterNames.CLIENT_SECRET);
141+
if (CLIENT_ASSERTION_AUTHENTICATION_METHOD.equals(clientAuthentiationMethod)) {
142+
// PRIVATE_KEY_JWT or CLIENT_SECRET_JWT authentication methods
143+
try {
144+
JwtDecoder jwtDecoder = this.decoderFactory.createDecoder(registeredClient);
145+
jwtDecoder.decode(clientAuthentication.getCredentials().toString());
146+
credentialsAuthenticated = true;
147+
} catch (JwtException e) {
148+
throwInvalidClient(OAuth2ParameterNames.CLIENT_ASSERTION);
149+
}
150+
} else {
151+
// CLIENT_SECRET_BASIC or CLIENT_SECRET_POST authentication methods
152+
if (clientAuthentication.getCredentials() != null) {
153+
String clientSecret = clientAuthentication.getCredentials().toString();
154+
if (!this.passwordEncoder.matches(clientSecret, registeredClient.getClientSecret())) {
155+
throwInvalidClient(OAuth2ParameterNames.CLIENT_SECRET);
156+
}
157+
credentialsAuthenticated = true;
114158
}
115-
credentialsAuthenticated = true;
116159
}
117160

118161
boolean pkceAuthenticated = authenticatePkceIfAvailable(clientAuthentication, registeredClient);
@@ -122,7 +165,7 @@ public Authentication authenticate(Authentication authentication) throws Authent
122165
}
123166

124167
return new OAuth2ClientAuthenticationToken(registeredClient,
125-
clientAuthentication.getClientAuthenticationMethod(), clientAuthentication.getCredentials());
168+
clientAuthentiationMethod, clientAuthentication.getCredentials());
126169
}
127170

128171
@Override
@@ -193,7 +236,7 @@ private static boolean codeVerifierValid(String codeVerifier, String codeChallen
193236
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.SERVER_ERROR);
194237
}
195238

196-
private static void throwInvalidClient(String parameterName) {
239+
static void throwInvalidClient(String parameterName) {
197240
OAuth2Error error = new OAuth2Error(
198241
OAuth2ErrorCodes.INVALID_CLIENT,
199242
"Client authentication failed: " + parameterName,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
/*
2+
* Copyright 2020-2021 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.security.oauth2.server.authorization.authentication;
17+
18+
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
19+
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
20+
import org.springframework.security.oauth2.jose.jws.JwsAlgorithm;
21+
import org.springframework.security.oauth2.jose.jws.JwsAlgorithms;
22+
import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
23+
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
24+
import org.springframework.security.oauth2.jwt.JwtDecoder;
25+
import org.springframework.security.oauth2.jwt.JwtDecoderFactory;
26+
import org.springframework.security.oauth2.jwt.NimbusJwtDecoder;
27+
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
28+
import org.springframework.security.oauth2.server.authorization.config.ClientSettings;
29+
import org.springframework.security.oauth2.server.authorization.config.ProviderSettings;
30+
import org.springframework.util.Assert;
31+
32+
import javax.crypto.spec.SecretKeySpec;
33+
import java.nio.charset.StandardCharsets;
34+
import java.util.Collections;
35+
import java.util.HashMap;
36+
import java.util.Map;
37+
import java.util.Objects;
38+
import java.util.concurrent.ConcurrentHashMap;
39+
40+
/**
41+
* Creates JWT decoders for registered clients
42+
*
43+
* @author Rafal Lewczuk
44+
* @since 0.2.1
45+
*/
46+
final class RegisteredClientJwtAssertionDecoderFactory implements JwtDecoderFactory<RegisteredClient> {
47+
48+
private static final Map<String, String> JCA_ALGORITHM_MAPPINGS;
49+
50+
static {
51+
Map<String, String> mappings = new HashMap<>();
52+
mappings.put(JwsAlgorithms.HS256, "HmacSHA256");
53+
mappings.put(JwsAlgorithms.HS384, "HmacSHA384");
54+
mappings.put(JwsAlgorithms.HS512, "HmacSHA512");
55+
JCA_ALGORITHM_MAPPINGS = Collections.unmodifiableMap(mappings);
56+
}
57+
58+
private final Map<String, CachedJwtDecoder> cachedDecoders = new ConcurrentHashMap<>();
59+
private final ProviderSettings providerSettings;
60+
61+
RegisteredClientJwtAssertionDecoderFactory(ProviderSettings providerSettings) {
62+
Assert.notNull(providerSettings, "providerSettings cannot be null");
63+
this.providerSettings = providerSettings;
64+
}
65+
66+
@Override
67+
public JwtDecoder createDecoder(RegisteredClient registeredClient) {
68+
Assert.notNull(registeredClient, "registeredClient cannot be null");
69+
70+
JwsAlgorithm signingAlgorithm = registeredClient.getClientSettings().getTokenEndpointSigningAlgorithm();
71+
boolean isHmac = signingAlgorithm instanceof MacAlgorithm;
72+
73+
// No fancy .computerIfAbsent() calls as we need to check if client configuration has changed
74+
// also this is idempotent operation so we don't have to care about exact synchronization
75+
CachedJwtDecoder cachedDecoder = this.cachedDecoders.get(registeredClient.getClientId());
76+
if (cachedDecoder != null) {
77+
ClientSettings clientSettings = cachedDecoder.registeredClient.getClientSettings();
78+
ClientSettings cachedSettings = registeredClient.getClientSettings();
79+
if (isHmac) {
80+
if (Objects.equals(cachedSettings.getTokenEndpointSigningAlgorithm(), clientSettings.getTokenEndpointSigningAlgorithm())
81+
&& Objects.equals(cachedDecoder.registeredClient.getClientSecret(), registeredClient.getClientSecret())) {
82+
return cachedDecoder.jwtDecoder;
83+
}
84+
} else {
85+
if (Objects.equals(cachedSettings.getTokenEndpointSigningAlgorithm(), clientSettings.getTokenEndpointSigningAlgorithm())
86+
&& Objects.equals(cachedSettings.getJwkSetUrl(), clientSettings.getJwkSetUrl())) {
87+
return cachedDecoder.jwtDecoder;
88+
}
89+
}
90+
}
91+
92+
cachedDecoder = isHmac ? createClientSecretJwtDecoder(registeredClient, (MacAlgorithm) signingAlgorithm)
93+
: createPrivateKeyJwtDecoder(registeredClient, (SignatureAlgorithm) signingAlgorithm);
94+
95+
this.providerSettings.getTokenEndpoint();
96+
cachedDecoder.jwtDecoder.setJwtValidator(
97+
new RegisteredClientJwtAssertionValidator(registeredClient, this.providerSettings));
98+
this.cachedDecoders.put(registeredClient.getClientId(), cachedDecoder);
99+
return cachedDecoder.jwtDecoder;
100+
}
101+
102+
private CachedJwtDecoder createClientSecretJwtDecoder(RegisteredClient registeredClient, MacAlgorithm macAlgorithm) {
103+
NimbusJwtDecoder jwtDecoder = null;
104+
if (!registeredClient.getClientAuthenticationMethods().contains(ClientAuthenticationMethod.CLIENT_SECRET_JWT)) {
105+
OAuth2ClientAuthenticationProvider.throwInvalidClient(OAuth2ParameterNames.CLIENT_ASSERTION);
106+
}
107+
108+
if (registeredClient.getClientSecret() == null || registeredClient.getClientSecret().isEmpty()) {
109+
OAuth2ClientAuthenticationProvider.throwInvalidClient(OAuth2ParameterNames.CLIENT_ASSERTION);
110+
}
111+
112+
try {
113+
byte[] secret = registeredClient.getClientSecret().getBytes(StandardCharsets.UTF_8);
114+
String algorithmName = JCA_ALGORITHM_MAPPINGS.get(macAlgorithm.getName());
115+
SecretKeySpec secretKey = new SecretKeySpec(secret, algorithmName);
116+
NimbusJwtDecoder.SecretKeyJwtDecoderBuilder jwtDecoderBuilder = NimbusJwtDecoder.withSecretKey(secretKey)
117+
.macAlgorithm(macAlgorithm);
118+
jwtDecoder = jwtDecoderBuilder.build();
119+
} catch (Exception e) {
120+
OAuth2ClientAuthenticationProvider.throwInvalidClient(OAuth2ParameterNames.CLIENT_ASSERTION);
121+
}
122+
return new CachedJwtDecoder(jwtDecoder, registeredClient);
123+
}
124+
125+
private CachedJwtDecoder createPrivateKeyJwtDecoder(RegisteredClient registeredClient, SignatureAlgorithm signatureAlgorithm) {
126+
NimbusJwtDecoder jwtDecoder = null;
127+
if (!registeredClient.getClientAuthenticationMethods().contains(ClientAuthenticationMethod.PRIVATE_KEY_JWT)) {
128+
OAuth2ClientAuthenticationProvider.throwInvalidClient(OAuth2ParameterNames.CLIENT_ASSERTION);
129+
}
130+
131+
String jwksUrl = registeredClient.getClientSettings().getJwkSetUrl();
132+
if (jwksUrl == null) {
133+
OAuth2ClientAuthenticationProvider.throwInvalidClient(OAuth2ParameterNames.CLIENT_ASSERTION);
134+
}
135+
136+
try {
137+
NimbusJwtDecoder.JwkSetUriJwtDecoderBuilder jwtDecoderBuilder = NimbusJwtDecoder.withJwkSetUri(jwksUrl)
138+
.jwsAlgorithm(signatureAlgorithm);
139+
jwtDecoder = jwtDecoderBuilder.build();
140+
} catch (Exception e) {
141+
OAuth2ClientAuthenticationProvider.throwInvalidClient(OAuth2ParameterNames.CLIENT_ASSERTION);
142+
}
143+
return new CachedJwtDecoder(jwtDecoder, registeredClient);
144+
}
145+
146+
private static class CachedJwtDecoder {
147+
private final NimbusJwtDecoder jwtDecoder;
148+
private final RegisteredClient registeredClient;
149+
150+
CachedJwtDecoder(NimbusJwtDecoder jwtDecoder, RegisteredClient registeredClient) {
151+
this.jwtDecoder = jwtDecoder;
152+
this.registeredClient = registeredClient;
153+
}
154+
}
155+
}

0 commit comments

Comments
 (0)