Skip to content

Commit 502fa24

Browse files
committed
Polish gh-787
1 parent 4466cbe commit 502fa24

File tree

2 files changed

+20
-12
lines changed
  • oauth2-authorization-server/src

2 files changed

+20
-12
lines changed

oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/JwtGenerator.java

+9-5
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames;
2828
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
2929
import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
30+
import org.springframework.security.oauth2.jose.jws.JwsAlgorithm;
3031
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
3132
import org.springframework.security.oauth2.jwt.JwsHeader;
3233
import org.springframework.security.oauth2.jwt.Jwt;
@@ -89,14 +90,15 @@ public Jwt generate(OAuth2TokenContext context) {
8990

9091
Instant issuedAt = Instant.now();
9192
Instant expiresAt;
92-
JwsHeader.Builder headersBuilder;
93+
JwsAlgorithm jwsAlgorithm = SignatureAlgorithm.RS256;
9394
if (OidcParameterNames.ID_TOKEN.equals(context.getTokenType().getValue())) {
9495
// TODO Allow configuration for ID Token time-to-live
9596
expiresAt = issuedAt.plus(30, ChronoUnit.MINUTES);
96-
headersBuilder = JwsHeader.with(registeredClient.getTokenSettings().getIdTokenSignatureAlgorithm());
97+
if (registeredClient.getTokenSettings().getIdTokenSignatureAlgorithm() != null) {
98+
jwsAlgorithm = registeredClient.getTokenSettings().getIdTokenSignatureAlgorithm();
99+
}
97100
} else {
98101
expiresAt = issuedAt.plus(registeredClient.getTokenSettings().getAccessTokenTimeToLive());
99-
headersBuilder = JwsHeader.with(SignatureAlgorithm.RS256);
100102
}
101103

102104
// @formatter:off
@@ -128,9 +130,11 @@ public Jwt generate(OAuth2TokenContext context) {
128130
}
129131
// @formatter:on
130132

133+
JwsHeader.Builder jwsHeaderBuilder = JwsHeader.with(jwsAlgorithm);
134+
131135
if (this.jwtCustomizer != null) {
132136
// @formatter:off
133-
JwtEncodingContext.Builder jwtContextBuilder = JwtEncodingContext.with(headersBuilder, claimsBuilder)
137+
JwtEncodingContext.Builder jwtContextBuilder = JwtEncodingContext.with(jwsHeaderBuilder, claimsBuilder)
134138
.registeredClient(context.getRegisteredClient())
135139
.principal(context.getPrincipal())
136140
.authorizationServerContext(context.getAuthorizationServerContext())
@@ -149,7 +153,7 @@ public Jwt generate(OAuth2TokenContext context) {
149153
this.jwtCustomizer.customize(jwtContext);
150154
}
151155

152-
JwsHeader jwsHeader = headersBuilder.build();
156+
JwsHeader jwsHeader = jwsHeaderBuilder.build();
153157
JwtClaimsSet claims = claimsBuilder.build();
154158

155159
Jwt jwt = this.jwtEncoder.encode(JwtEncoderParameters.from(jwsHeader, claims));

oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/JwtGeneratorTests.java

+11-7
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,10 @@ public void generateWhenAccessTokenTypeThenReturnJwt() {
152152

153153
@Test
154154
public void generateWhenIdTokenTypeThenReturnJwt() {
155-
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().scope(OidcScopes.OPENID).build();
155+
RegisteredClient registeredClient = TestRegisteredClients.registeredClient()
156+
.scope(OidcScopes.OPENID)
157+
.tokenSettings(TokenSettings.builder().idTokenSignatureAlgorithm(SignatureAlgorithm.ES256).build())
158+
.build();
156159
Map<String, Object> authenticationRequestAdditionalParameters = new HashMap<>();
157160
authenticationRequestAdditionalParameters.put(OidcParameterNames.NONCE, "nonce");
158161
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(
@@ -201,27 +204,28 @@ private void assertGeneratedTokenType(OAuth2TokenContext tokenContext) {
201204
ArgumentCaptor<JwtEncoderParameters> jwtEncoderParametersCaptor = ArgumentCaptor.forClass(JwtEncoderParameters.class);
202205
verify(this.jwtEncoder).encode(jwtEncoderParametersCaptor.capture());
203206

207+
JwsHeader jwsHeader = jwtEncoderParametersCaptor.getValue().getJwsHeader();
208+
if (OidcParameterNames.ID_TOKEN.equals(tokenContext.getTokenType().getValue())) {
209+
assertThat(jwsHeader.getAlgorithm()).isEqualTo(tokenContext.getRegisteredClient().getTokenSettings().getIdTokenSignatureAlgorithm());
210+
} else {
211+
assertThat(jwsHeader.getAlgorithm()).isEqualTo(SignatureAlgorithm.RS256);
212+
}
213+
204214
JwtClaimsSet jwtClaimsSet = jwtEncoderParametersCaptor.getValue().getClaims();
205215
assertThat(jwtClaimsSet.getIssuer().toExternalForm()).isEqualTo(tokenContext.getAuthorizationServerContext().getIssuer());
206216
assertThat(jwtClaimsSet.getSubject()).isEqualTo(tokenContext.getAuthorization().getPrincipalName());
207217
assertThat(jwtClaimsSet.getAudience()).containsExactly(tokenContext.getRegisteredClient().getClientId());
208218

209219
Instant issuedAt = Instant.now();
210220
Instant expiresAt;
211-
JwsHeader.Builder headersBuilder;
212221
if (tokenContext.getTokenType().equals(OAuth2TokenType.ACCESS_TOKEN)) {
213222
expiresAt = issuedAt.plus(tokenContext.getRegisteredClient().getTokenSettings().getAccessTokenTimeToLive());
214-
headersBuilder = JwsHeader.with(SignatureAlgorithm.RS256);
215223
} else {
216224
expiresAt = issuedAt.plus(30, ChronoUnit.MINUTES);
217-
headersBuilder = JwsHeader.with(tokenContext.getRegisteredClient().getTokenSettings().getIdTokenSignatureAlgorithm());
218225
}
219226
assertThat(jwtClaimsSet.getIssuedAt()).isBetween(issuedAt.minusSeconds(1), issuedAt.plusSeconds(1));
220227
assertThat(jwtClaimsSet.getExpiresAt()).isBetween(expiresAt.minusSeconds(1), expiresAt.plusSeconds(1));
221228

222-
JwsHeader jwsHeader = jwtEncoderParametersCaptor.getValue().getJwsHeader();
223-
assertThat(jwsHeader.getAlgorithm()).isEqualTo(headersBuilder.build().getAlgorithm());
224-
225229
if (tokenContext.getTokenType().equals(OAuth2TokenType.ACCESS_TOKEN)) {
226230
assertThat(jwtClaimsSet.getNotBefore()).isBetween(issuedAt.minusSeconds(1), issuedAt.plusSeconds(1));
227231

0 commit comments

Comments
 (0)