Skip to content

Commit 725897b

Browse files
committed
Implement Proof Key for Code Exchange (PKCE), RFC 7636
- See https://tools.ietf.org/html/rfc7636
1 parent c3b2545 commit 725897b

13 files changed

+950
-55
lines changed

oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AccessTokenAuthenticationToken.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@
3838
*/
3939
public class OAuth2AccessTokenAuthenticationToken extends AbstractAuthenticationToken {
4040
private static final long serialVersionUID = SpringSecurityCoreVersion2.SERIAL_VERSION_UID;
41-
private final RegisteredClient registeredClient;
42-
private final Authentication clientPrincipal;
41+
private RegisteredClient registeredClient;
42+
private Authentication clientPrincipal;
4343
private final OAuth2AccessToken accessToken;
4444

4545
/**

oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java

+71-15
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.springframework.security.oauth2.core.OAuth2Error;
2424
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
2525
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
26+
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
2627
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
2728
import org.springframework.security.oauth2.jose.JoseHeader;
2829
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
@@ -33,15 +34,20 @@
3334
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
3435
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
3536
import org.springframework.security.oauth2.server.authorization.TokenType;
37+
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
3638
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
3739
import org.springframework.util.Assert;
3840
import org.springframework.util.StringUtils;
3941

42+
import java.nio.charset.StandardCharsets;
43+
import java.security.MessageDigest;
44+
import java.security.NoSuchAlgorithmException;
4045
import java.net.MalformedURLException;
4146
import java.net.URI;
4247
import java.net.URL;
4348
import java.time.Instant;
4449
import java.time.temporal.ChronoUnit;
50+
import java.util.Base64;
4551
import java.util.Collections;
4652

4753
/**
@@ -85,29 +91,30 @@ public Authentication authenticate(Authentication authentication) throws Authent
8591
(OAuth2AuthorizationCodeAuthenticationToken) authentication;
8692

8793
OAuth2ClientAuthenticationToken clientPrincipal = null;
94+
RegisteredClient registeredClient = null;
8895
if (OAuth2ClientAuthenticationToken.class.isAssignableFrom(authorizationCodeAuthentication.getPrincipal().getClass())) {
8996
clientPrincipal = (OAuth2ClientAuthenticationToken) authorizationCodeAuthentication.getPrincipal();
90-
}
91-
if (clientPrincipal == null || !clientPrincipal.isAuthenticated()) {
97+
registeredClient = clientPrincipal.getRegisteredClient();
98+
} else if (StringUtils.hasText(authorizationCodeAuthentication.getClientId())) {
99+
// When the principal is a string, it is the clientId, REQUIRED for public clients
100+
String clientId = authorizationCodeAuthentication.getClientId();
101+
registeredClient = this.registeredClientRepository.findByClientId(clientId);
102+
if (registeredClient == null) {
103+
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT));
104+
}
105+
} else {
92106
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT));
93107
}
94108

95-
// TODO Authenticate public client
96-
// A client MAY use the "client_id" request parameter to identify itself
97-
// when sending requests to the token endpoint.
98-
// In the "authorization_code" "grant_type" request to the token endpoint,
99-
// an unauthenticated client MUST send its "client_id" to prevent itself
100-
// from inadvertently accepting a code intended for a client with a different "client_id".
101-
// This protects the client from substitution of the authentication code.
109+
if (clientPrincipal != null && !clientPrincipal.isAuthenticated()) {
110+
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT));
111+
}
102112

103113
OAuth2Authorization authorization = this.authorizationService.findByToken(
104114
authorizationCodeAuthentication.getCode(), TokenType.AUTHORIZATION_CODE);
105115
if (authorization == null) {
106116
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
107117
}
108-
if (!clientPrincipal.getRegisteredClient().getId().equals(authorization.getRegisteredClientId())) {
109-
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
110-
}
111118

112119
OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(
113120
OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
@@ -116,6 +123,35 @@ public Authentication authenticate(Authentication authentication) throws Authent
116123
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
117124
}
118125

126+
if (!registeredClient.getClientId().equals(authorizationRequest.getClientId())) {
127+
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
128+
}
129+
130+
131+
String codeChallenge;
132+
Object codeChallengeParameter = authorizationRequest
133+
.getAdditionalParameters()
134+
.get(PkceParameterNames.CODE_CHALLENGE);
135+
136+
if (codeChallengeParameter != null) {
137+
codeChallenge = (String) codeChallengeParameter;
138+
139+
String codeChallengeMethod = (String) authorizationRequest
140+
.getAdditionalParameters()
141+
.get(PkceParameterNames.CODE_CHALLENGE_METHOD);
142+
143+
String codeVerifier = (String) authorizationCodeAuthentication
144+
.getAdditionalParameters()
145+
.get(PkceParameterNames.CODE_VERIFIER);
146+
147+
if (!pkceCodeVerifierValid(codeVerifier, codeChallenge, codeChallengeMethod)) {
148+
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
149+
}
150+
} else if (registeredClient.getClientSettings().requireProofKey()){
151+
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
152+
}
153+
154+
119155
JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build();
120156

121157
// TODO Allow configuration for issuer claim
@@ -130,7 +166,7 @@ public Authentication authenticate(Authentication authentication) throws Authent
130166
JwtClaimsSet jwtClaimsSet = JwtClaimsSet.withClaims()
131167
.issuer(issuer)
132168
.subject(authorization.getPrincipalName())
133-
.audience(Collections.singletonList(clientPrincipal.getRegisteredClient().getClientId()))
169+
.audience(Collections.singletonList(registeredClient.getClientId()))
134170
.issuedAt(issuedAt)
135171
.expiresAt(expiresAt)
136172
.notBefore(issuedAt)
@@ -148,8 +184,28 @@ public Authentication authenticate(Authentication authentication) throws Authent
148184
.build();
149185
this.authorizationService.save(authorization);
150186

151-
return new OAuth2AccessTokenAuthenticationToken(
152-
clientPrincipal.getRegisteredClient(), clientPrincipal, accessToken);
187+
return clientPrincipal != null ?
188+
new OAuth2AccessTokenAuthenticationToken(registeredClient, clientPrincipal, accessToken) :
189+
new OAuth2AccessTokenAuthenticationToken(registeredClient, new OAuth2ClientAuthenticationToken(registeredClient), accessToken);
190+
}
191+
192+
private boolean pkceCodeVerifierValid(String codeVerifier, String codeChallenge, String codeChallengeMethod) {
193+
if (codeVerifier == null) {
194+
return false;
195+
} else if (codeChallengeMethod == null || codeChallengeMethod.equals("plain")) {
196+
return codeVerifier.equals(codeChallenge);
197+
} else if ("S256".equals(codeChallengeMethod)) {
198+
try {
199+
MessageDigest md = MessageDigest.getInstance("SHA-256");
200+
byte[] digest = md.digest(codeVerifier.getBytes(StandardCharsets.US_ASCII));
201+
String encodedVerifier = Base64.getUrlEncoder().withoutPadding().encodeToString(digest);
202+
return codeChallenge.equals(encodedVerifier);
203+
} catch (NoSuchAlgorithmException e) {
204+
// Unsupported algorithm should be caught in OAuth2AuthorizationEndpointFilter
205+
}
206+
}
207+
208+
return false;
153209
}
154210

155211
@Override

oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationToken.java

+37-5
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.springframework.util.Assert;
2323

2424
import java.util.Collections;
25+
import java.util.Map;
2526

2627
/**
2728
* An {@link Authentication} implementation used for the OAuth 2.0 Authorization Code Grant.
@@ -35,26 +36,36 @@
3536
*/
3637
public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenticationToken {
3738
private static final long serialVersionUID = SpringSecurityCoreVersion2.SERIAL_VERSION_UID;
38-
private String code;
39+
private final String code;
3940
private Authentication clientPrincipal;
40-
private String clientId;
41-
private String redirectUri;
41+
private final String clientId;
42+
private final String redirectUri;
43+
private final Map<String, Object> additionalParameters;
4244

4345
/**
4446
* Constructs an {@code OAuth2AuthorizationCodeAuthenticationToken} using the provided parameters.
4547
*
4648
* @param code the authorization code
4749
* @param clientPrincipal the authenticated client principal
4850
* @param redirectUri the redirect uri
51+
* @param additionalParameters the additional parameters
4952
*/
5053
public OAuth2AuthorizationCodeAuthenticationToken(String code,
51-
Authentication clientPrincipal, @Nullable String redirectUri) {
54+
Authentication clientPrincipal, @Nullable String redirectUri,
55+
Map<String, Object> additionalParameters) {
5256
super(Collections.emptyList());
5357
Assert.hasText(code, "code cannot be empty");
5458
Assert.notNull(clientPrincipal, "clientPrincipal cannot be null");
5559
this.code = code;
5660
this.clientPrincipal = clientPrincipal;
5761
this.redirectUri = redirectUri;
62+
this.additionalParameters = Collections.unmodifiableMap(additionalParameters != null ? additionalParameters : Collections.emptyMap());
63+
64+
if (OAuth2ClientAuthenticationToken.class.isAssignableFrom(this.clientPrincipal.getClass())) {
65+
this.clientId = (String) this.clientPrincipal.getPrincipal();
66+
} else {
67+
this.clientId = null;
68+
}
5869
}
5970

6071
/**
@@ -63,15 +74,18 @@ public OAuth2AuthorizationCodeAuthenticationToken(String code,
6374
* @param code the authorization code
6475
* @param clientId the client identifier
6576
* @param redirectUri the redirect uri
77+
* @param additionalParameters the additional parameters
6678
*/
6779
public OAuth2AuthorizationCodeAuthenticationToken(String code,
68-
String clientId, @Nullable String redirectUri) {
80+
String clientId, @Nullable String redirectUri,
81+
Map<String, Object> additionalParameters) {
6982
super(Collections.emptyList());
7083
Assert.hasText(code, "code cannot be empty");
7184
Assert.hasText(clientId, "clientId cannot be empty");
7285
this.code = code;
7386
this.clientId = clientId;
7487
this.redirectUri = redirectUri;
88+
this.additionalParameters = Collections.unmodifiableMap(additionalParameters != null ? additionalParameters : Collections.emptyMap());
7589
}
7690

7791
@Override
@@ -101,4 +115,22 @@ public String getCode() {
101115
public @Nullable String getRedirectUri() {
102116
return this.redirectUri;
103117
}
118+
119+
/**
120+
* Returns the additional parameters
121+
*
122+
* @return the additional parameters
123+
*/
124+
public Map<String, Object> getAdditionalParameters() {
125+
return this.additionalParameters;
126+
}
127+
128+
/**
129+
* Returns the client id
130+
*
131+
* @return the client id
132+
*/
133+
public @Nullable String getClientId() {
134+
return this.clientId;
135+
}
104136
}

oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/RegisteredClient.java

-1
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,6 @@ public RegisteredClient build() {
367367
Assert.hasText(this.clientId, "clientId cannot be empty");
368368
Assert.notEmpty(this.authorizationGrantTypes, "authorizationGrantTypes cannot be empty");
369369
if (this.authorizationGrantTypes.contains(AuthorizationGrantType.AUTHORIZATION_CODE)) {
370-
Assert.hasText(this.clientSecret, "clientSecret cannot be empty");
371370
Assert.notEmpty(this.redirectUris, "redirectUris cannot be empty");
372371
}
373372
if (CollectionUtils.isEmpty(this.clientAuthenticationMethods)) {

oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java

+36-2
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
2929
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
3030
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
31+
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
3132
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
3233
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
3334
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
@@ -50,6 +51,7 @@
5051
import java.io.IOException;
5152
import java.util.Arrays;
5253
import java.util.Base64;
54+
import java.util.Collection;
5355
import java.util.Collections;
5456
import java.util.HashSet;
5557
import java.util.Set;
@@ -78,6 +80,7 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
7880
private final RequestMatcher authorizationEndpointMatcher;
7981
private final StringKeyGenerator codeGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder());
8082
private final RedirectStrategy redirectStrategy = new DefaultRedirectStrategy();
83+
private final String PKCE_ERROR_URI = "https://tools.ietf.org/html/rfc7636#section-4.4.1";
8184

8285
/**
8386
* Constructs an {@code OAuth2AuthorizationEndpointFilter} using the provided parameters.
@@ -174,6 +177,34 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
174177
return;
175178
}
176179

180+
// code_challenge (REQUIRED for public clients) - RFC 7636 (PKCE)
181+
String codeChallenge = parameters.getFirst(PkceParameterNames.CODE_CHALLENGE);
182+
if (StringUtils.hasText(codeChallenge)) {
183+
if (parameters.get(PkceParameterNames.CODE_CHALLENGE).size() != 1) {
184+
OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE, PKCE_ERROR_URI);
185+
sendErrorResponse(request, response, error, stateParameter, redirectUri);
186+
return;
187+
}
188+
189+
if (parameters.get(PkceParameterNames.CODE_CHALLENGE_METHOD) != null &&
190+
parameters.get(PkceParameterNames.CODE_CHALLENGE_METHOD).size() > 1) {
191+
OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE_METHOD, PKCE_ERROR_URI);
192+
sendErrorResponse(request, response, error, stateParameter, redirectUri);
193+
return;
194+
}
195+
196+
String codeChallengeMethod = parameters.getFirst(PkceParameterNames.CODE_CHALLENGE_METHOD);
197+
if (codeChallengeMethod != null && !((Collection<String>) Arrays.asList("plain", "S256")).contains(codeChallengeMethod)) {
198+
OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE_METHOD, PKCE_ERROR_URI);
199+
sendErrorResponse(request, response, error, stateParameter, redirectUri);
200+
return;
201+
}
202+
} else if (registeredClient.getClientSettings().requireProofKey()) {
203+
OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE, PKCE_ERROR_URI);
204+
sendErrorResponse(request, response, error, stateParameter, redirectUri);
205+
return;
206+
}
207+
177208
// ---------------
178209
// The request is valid - ensure the resource owner is authenticated
179210
// ---------------
@@ -245,8 +276,11 @@ private void sendErrorResponse(HttpServletRequest request, HttpServletResponse r
245276
}
246277

247278
private static OAuth2Error createError(String errorCode, String parameterName) {
248-
return new OAuth2Error(errorCode, "OAuth 2.0 Parameter: " + parameterName,
249-
"https://tools.ietf.org/html/rfc6749#section-4.1.2.1");
279+
return createError(errorCode, parameterName, "https://tools.ietf.org/html/rfc6749#section-4.1.2.1");
280+
}
281+
282+
private static OAuth2Error createError(String errorCode, String parameterName, String errorUri) {
283+
return new OAuth2Error(errorCode, "OAuth 2.0 Parameter: " + parameterName, errorUri);
250284
}
251285

252286
private static boolean isPrincipalAuthenticated(Authentication principal) {

0 commit comments

Comments
 (0)