Skip to content

Commit 72eafdf

Browse files
Kehrlannjgrandja
authored andcommitted
Implement Proof Key for Code Exchange (PKCE) RFC 7636
See https://tools.ietf.org/html/rfc7636 Closes spring-projectsgh-45
1 parent f4793da commit 72eafdf

13 files changed

+1008
-55
lines changed

oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AccessTokenAuthenticationToken.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@
3838
*/
3939
public class OAuth2AccessTokenAuthenticationToken extends AbstractAuthenticationToken {
4040
private static final long serialVersionUID = SpringSecurityCoreVersion2.SERIAL_VERSION_UID;
41-
private final RegisteredClient registeredClient;
42-
private final Authentication clientPrincipal;
41+
private RegisteredClient registeredClient;
42+
private Authentication clientPrincipal;
4343
private final OAuth2AccessToken accessToken;
4444

4545
/**

oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java

+73-15
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.springframework.security.oauth2.core.OAuth2Error;
2424
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
2525
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
26+
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
2627
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
2728
import org.springframework.security.oauth2.jose.JoseHeader;
2829
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
@@ -33,15 +34,20 @@
3334
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
3435
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
3536
import org.springframework.security.oauth2.server.authorization.TokenType;
37+
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
3638
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
3739
import org.springframework.util.Assert;
3840
import org.springframework.util.StringUtils;
3941

42+
import java.nio.charset.StandardCharsets;
43+
import java.security.MessageDigest;
44+
import java.security.NoSuchAlgorithmException;
4045
import java.net.MalformedURLException;
4146
import java.net.URI;
4247
import java.net.URL;
4348
import java.time.Instant;
4449
import java.time.temporal.ChronoUnit;
50+
import java.util.Base64;
4551
import java.util.Collections;
4652

4753
/**
@@ -85,29 +91,30 @@ public Authentication authenticate(Authentication authentication) throws Authent
8591
(OAuth2AuthorizationCodeAuthenticationToken) authentication;
8692

8793
OAuth2ClientAuthenticationToken clientPrincipal = null;
94+
RegisteredClient registeredClient = null;
8895
if (OAuth2ClientAuthenticationToken.class.isAssignableFrom(authorizationCodeAuthentication.getPrincipal().getClass())) {
8996
clientPrincipal = (OAuth2ClientAuthenticationToken) authorizationCodeAuthentication.getPrincipal();
90-
}
91-
if (clientPrincipal == null || !clientPrincipal.isAuthenticated()) {
97+
registeredClient = clientPrincipal.getRegisteredClient();
98+
} else if (StringUtils.hasText(authorizationCodeAuthentication.getClientId())) {
99+
// When the principal is a string, it is the clientId, REQUIRED for public clients
100+
String clientId = authorizationCodeAuthentication.getClientId();
101+
registeredClient = this.registeredClientRepository.findByClientId(clientId);
102+
if (registeredClient == null) {
103+
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT));
104+
}
105+
} else {
92106
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT));
93107
}
94108

95-
// TODO Authenticate public client
96-
// A client MAY use the "client_id" request parameter to identify itself
97-
// when sending requests to the token endpoint.
98-
// In the "authorization_code" "grant_type" request to the token endpoint,
99-
// an unauthenticated client MUST send its "client_id" to prevent itself
100-
// from inadvertently accepting a code intended for a client with a different "client_id".
101-
// This protects the client from substitution of the authentication code.
109+
if (clientPrincipal != null && !clientPrincipal.isAuthenticated()) {
110+
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT));
111+
}
102112

103113
OAuth2Authorization authorization = this.authorizationService.findByToken(
104114
authorizationCodeAuthentication.getCode(), TokenType.AUTHORIZATION_CODE);
105115
if (authorization == null) {
106116
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
107117
}
108-
if (!clientPrincipal.getRegisteredClient().getId().equals(authorization.getRegisteredClientId())) {
109-
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
110-
}
111118

112119
OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(
113120
OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
@@ -116,6 +123,35 @@ public Authentication authenticate(Authentication authentication) throws Authent
116123
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
117124
}
118125

126+
if (!registeredClient.getClientId().equals(authorizationRequest.getClientId())) {
127+
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
128+
}
129+
130+
131+
String codeChallenge;
132+
Object codeChallengeParameter = authorizationRequest
133+
.getAdditionalParameters()
134+
.get(PkceParameterNames.CODE_CHALLENGE);
135+
136+
if (codeChallengeParameter != null) {
137+
codeChallenge = (String) codeChallengeParameter;
138+
139+
String codeChallengeMethod = (String) authorizationRequest
140+
.getAdditionalParameters()
141+
.get(PkceParameterNames.CODE_CHALLENGE_METHOD);
142+
143+
String codeVerifier = (String) authorizationCodeAuthentication
144+
.getAdditionalParameters()
145+
.get(PkceParameterNames.CODE_VERIFIER);
146+
147+
if (!pkceCodeVerifierValid(codeVerifier, codeChallenge, codeChallengeMethod)) {
148+
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
149+
}
150+
} else if (registeredClient.getClientSettings().requireProofKey()){
151+
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
152+
}
153+
154+
119155
JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build();
120156

121157
// TODO Allow configuration for issuer claim
@@ -130,7 +166,7 @@ public Authentication authenticate(Authentication authentication) throws Authent
130166
JwtClaimsSet jwtClaimsSet = JwtClaimsSet.withClaims()
131167
.issuer(issuer)
132168
.subject(authorization.getPrincipalName())
133-
.audience(Collections.singletonList(clientPrincipal.getRegisteredClient().getClientId()))
169+
.audience(Collections.singletonList(registeredClient.getClientId()))
134170
.issuedAt(issuedAt)
135171
.expiresAt(expiresAt)
136172
.notBefore(issuedAt)
@@ -148,8 +184,30 @@ public Authentication authenticate(Authentication authentication) throws Authent
148184
.build();
149185
this.authorizationService.save(authorization);
150186

151-
return new OAuth2AccessTokenAuthenticationToken(
152-
clientPrincipal.getRegisteredClient(), clientPrincipal, accessToken);
187+
return clientPrincipal != null ?
188+
new OAuth2AccessTokenAuthenticationToken(registeredClient, clientPrincipal, accessToken) :
189+
new OAuth2AccessTokenAuthenticationToken(registeredClient, new OAuth2ClientAuthenticationToken(registeredClient), accessToken);
190+
}
191+
192+
private boolean pkceCodeVerifierValid(String codeVerifier, String codeChallenge, String codeChallengeMethod) {
193+
if (codeVerifier == null) {
194+
return false;
195+
} else if (codeChallengeMethod == null || codeChallengeMethod.equals("plain")) {
196+
return codeVerifier.equals(codeChallenge);
197+
} else if ("S256".equals(codeChallengeMethod)) {
198+
try {
199+
MessageDigest md = MessageDigest.getInstance("SHA-256");
200+
byte[] digest = md.digest(codeVerifier.getBytes(StandardCharsets.US_ASCII));
201+
String encodedVerifier = Base64.getUrlEncoder().withoutPadding().encodeToString(digest);
202+
return codeChallenge.equals(encodedVerifier);
203+
} catch (NoSuchAlgorithmException e) {
204+
// It is unlikely that SHA-256 is not available on the server. If it is not available,
205+
// there will likely be bigger issues as well. We default to SERVER_ERROR.
206+
}
207+
}
208+
209+
// Unsupported algorithm should be caught in OAuth2AuthorizationEndpointFilter
210+
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR));
153211
}
154212

155213
@Override

oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationToken.java

+37-5
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.springframework.util.Assert;
2323

2424
import java.util.Collections;
25+
import java.util.Map;
2526

2627
/**
2728
* An {@link Authentication} implementation used for the OAuth 2.0 Authorization Code Grant.
@@ -35,26 +36,36 @@
3536
*/
3637
public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenticationToken {
3738
private static final long serialVersionUID = SpringSecurityCoreVersion2.SERIAL_VERSION_UID;
38-
private String code;
39+
private final String code;
3940
private Authentication clientPrincipal;
40-
private String clientId;
41-
private String redirectUri;
41+
private final String clientId;
42+
private final String redirectUri;
43+
private final Map<String, Object> additionalParameters;
4244

4345
/**
4446
* Constructs an {@code OAuth2AuthorizationCodeAuthenticationToken} using the provided parameters.
4547
*
4648
* @param code the authorization code
4749
* @param clientPrincipal the authenticated client principal
4850
* @param redirectUri the redirect uri
51+
* @param additionalParameters the additional parameters
4952
*/
5053
public OAuth2AuthorizationCodeAuthenticationToken(String code,
51-
Authentication clientPrincipal, @Nullable String redirectUri) {
54+
Authentication clientPrincipal, @Nullable String redirectUri,
55+
Map<String, Object> additionalParameters) {
5256
super(Collections.emptyList());
5357
Assert.hasText(code, "code cannot be empty");
5458
Assert.notNull(clientPrincipal, "clientPrincipal cannot be null");
5559
this.code = code;
5660
this.clientPrincipal = clientPrincipal;
5761
this.redirectUri = redirectUri;
62+
this.additionalParameters = Collections.unmodifiableMap(additionalParameters != null ? additionalParameters : Collections.emptyMap());
63+
64+
if (OAuth2ClientAuthenticationToken.class.isAssignableFrom(this.clientPrincipal.getClass())) {
65+
this.clientId = (String) this.clientPrincipal.getPrincipal();
66+
} else {
67+
this.clientId = null;
68+
}
5869
}
5970

6071
/**
@@ -63,15 +74,18 @@ public OAuth2AuthorizationCodeAuthenticationToken(String code,
6374
* @param code the authorization code
6475
* @param clientId the client identifier
6576
* @param redirectUri the redirect uri
77+
* @param additionalParameters the additional parameters
6678
*/
6779
public OAuth2AuthorizationCodeAuthenticationToken(String code,
68-
String clientId, @Nullable String redirectUri) {
80+
String clientId, @Nullable String redirectUri,
81+
Map<String, Object> additionalParameters) {
6982
super(Collections.emptyList());
7083
Assert.hasText(code, "code cannot be empty");
7184
Assert.hasText(clientId, "clientId cannot be empty");
7285
this.code = code;
7386
this.clientId = clientId;
7487
this.redirectUri = redirectUri;
88+
this.additionalParameters = Collections.unmodifiableMap(additionalParameters != null ? additionalParameters : Collections.emptyMap());
7589
}
7690

7791
@Override
@@ -101,4 +115,22 @@ public String getCode() {
101115
public @Nullable String getRedirectUri() {
102116
return this.redirectUri;
103117
}
118+
119+
/**
120+
* Returns the additional parameters
121+
*
122+
* @return the additional parameters
123+
*/
124+
public Map<String, Object> getAdditionalParameters() {
125+
return this.additionalParameters;
126+
}
127+
128+
/**
129+
* Returns the client id
130+
*
131+
* @return the client id
132+
*/
133+
public @Nullable String getClientId() {
134+
return this.clientId;
135+
}
104136
}

oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/RegisteredClient.java

-1
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,6 @@ public RegisteredClient build() {
367367
Assert.hasText(this.clientId, "clientId cannot be empty");
368368
Assert.notEmpty(this.authorizationGrantTypes, "authorizationGrantTypes cannot be empty");
369369
if (this.authorizationGrantTypes.contains(AuthorizationGrantType.AUTHORIZATION_CODE)) {
370-
Assert.hasText(this.clientSecret, "clientSecret cannot be empty");
371370
Assert.notEmpty(this.redirectUris, "redirectUris cannot be empty");
372371
}
373372
if (CollectionUtils.isEmpty(this.clientAuthenticationMethods)) {

oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java

+35-2
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
2929
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
3030
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
31+
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
3132
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
3233
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
3334
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
@@ -78,6 +79,7 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
7879
private final RequestMatcher authorizationEndpointMatcher;
7980
private final StringKeyGenerator codeGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder());
8081
private final RedirectStrategy redirectStrategy = new DefaultRedirectStrategy();
82+
private final String PKCE_ERROR_URI = "https://tools.ietf.org/html/rfc7636#section-4.4.1";
8183

8284
/**
8385
* Constructs an {@code OAuth2AuthorizationEndpointFilter} using the provided parameters.
@@ -174,6 +176,34 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
174176
return;
175177
}
176178

179+
// code_challenge (REQUIRED for public clients) - RFC 7636 (PKCE)
180+
String codeChallenge = parameters.getFirst(PkceParameterNames.CODE_CHALLENGE);
181+
if (StringUtils.hasText(codeChallenge)) {
182+
if (parameters.get(PkceParameterNames.CODE_CHALLENGE).size() != 1) {
183+
OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE, PKCE_ERROR_URI);
184+
sendErrorResponse(request, response, error, stateParameter, redirectUri);
185+
return;
186+
}
187+
188+
if (parameters.get(PkceParameterNames.CODE_CHALLENGE_METHOD) != null &&
189+
parameters.get(PkceParameterNames.CODE_CHALLENGE_METHOD).size() > 1) {
190+
OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE_METHOD, PKCE_ERROR_URI);
191+
sendErrorResponse(request, response, error, stateParameter, redirectUri);
192+
return;
193+
}
194+
195+
String codeChallengeMethod = parameters.getFirst(PkceParameterNames.CODE_CHALLENGE_METHOD);
196+
if (codeChallengeMethod != null && !Arrays.asList("plain", "S256").contains(codeChallengeMethod)) {
197+
OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE_METHOD, PKCE_ERROR_URI);
198+
sendErrorResponse(request, response, error, stateParameter, redirectUri);
199+
return;
200+
}
201+
} else if (registeredClient.getClientSettings().requireProofKey()) {
202+
OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE, PKCE_ERROR_URI);
203+
sendErrorResponse(request, response, error, stateParameter, redirectUri);
204+
return;
205+
}
206+
177207
// ---------------
178208
// The request is valid - ensure the resource owner is authenticated
179209
// ---------------
@@ -245,8 +275,11 @@ private void sendErrorResponse(HttpServletRequest request, HttpServletResponse r
245275
}
246276

247277
private static OAuth2Error createError(String errorCode, String parameterName) {
248-
return new OAuth2Error(errorCode, "OAuth 2.0 Parameter: " + parameterName,
249-
"https://tools.ietf.org/html/rfc6749#section-4.1.2.1");
278+
return createError(errorCode, parameterName, "https://tools.ietf.org/html/rfc6749#section-4.1.2.1");
279+
}
280+
281+
private static OAuth2Error createError(String errorCode, String parameterName, String errorUri) {
282+
return new OAuth2Error(errorCode, "OAuth 2.0 Parameter: " + parameterName, errorUri);
250283
}
251284

252285
private static boolean isPrincipalAuthenticated(Authentication principal) {

oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java

+30-9
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,13 @@
3030
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
3131
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
3232
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
33+
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
3334
import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter;
3435
import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter;
3536
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
3637
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationToken;
3738
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeAuthenticationToken;
39+
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken;
3840
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientCredentialsAuthenticationToken;
3941
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
4042
import org.springframework.security.web.util.matcher.RequestMatcher;
@@ -54,6 +56,7 @@
5456
import java.util.HashSet;
5557
import java.util.Map;
5658
import java.util.Set;
59+
import java.util.stream.Collectors;
5760

5861
/**
5962
* A {@code Filter} for the OAuth 2.0 Authorization Code Grant,
@@ -198,14 +201,22 @@ public Authentication convert(HttpServletRequest request) {
198201
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);
199202

200203
// client_id (REQUIRED)
201-
String clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID);
202-
Authentication clientPrincipal = null;
203-
if (StringUtils.hasText(clientId)) {
204-
if (parameters.get(OAuth2ParameterNames.CLIENT_ID).size() != 1) {
204+
Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication();
205+
String clientId = null;
206+
if (clientPrincipal == null ||
207+
!OAuth2ClientAuthenticationToken.class.isAssignableFrom(clientPrincipal.getClass())) {
208+
clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID);
209+
if (!StringUtils.hasText(clientId) ||
210+
parameters.get(OAuth2ParameterNames.CLIENT_ID).size() != 1) {
205211
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID);
206212
}
207-
} else {
208-
clientPrincipal = SecurityContextHolder.getContext().getAuthentication();
213+
214+
// code_verifier (REQUIRED for public clients)
215+
String codeVerifier = parameters.getFirst(PkceParameterNames.CODE_VERIFIER);
216+
if (!StringUtils.hasText(codeVerifier) ||
217+
parameters.get(PkceParameterNames.CODE_VERIFIER).size() != 1) {
218+
throwError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_VERIFIER);
219+
}
209220
}
210221

211222
// code (REQUIRED)
@@ -223,9 +234,19 @@ public Authentication convert(HttpServletRequest request) {
223234
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI);
224235
}
225236

226-
return clientPrincipal != null ?
227-
new OAuth2AuthorizationCodeAuthenticationToken(code, clientPrincipal, redirectUri) :
228-
new OAuth2AuthorizationCodeAuthenticationToken(code, clientId, redirectUri);
237+
Map<String, Object> additionalParameters = parameters
238+
.entrySet()
239+
.stream()
240+
.filter(e -> !e.getKey().equals(OAuth2ParameterNames.GRANT_TYPE) &&
241+
!e.getKey().equals(OAuth2ParameterNames.CLIENT_ID) &&
242+
!e.getKey().equals(OAuth2ParameterNames.CODE) &&
243+
!e.getKey().equals(OAuth2ParameterNames.REDIRECT_URI))
244+
.collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().get(0)));
245+
246+
247+
return clientId != null ?
248+
new OAuth2AuthorizationCodeAuthenticationToken(code, clientId, redirectUri, additionalParameters) :
249+
new OAuth2AuthorizationCodeAuthenticationToken(code, clientPrincipal, redirectUri, additionalParameters);
229250
}
230251
}
231252

0 commit comments

Comments
 (0)