Skip to content

Commit ae74f22

Browse files
jzheauxrwinch
authored andcommitted
Reactive Jwt Claim Set Converter Support
Exposes setClaimSetConverter on NimbusReactiveJwtDecoder, lining it up with the same support on NimbusJwtDecoder. Fixes: gh-6015
1 parent 11b6b63 commit ae74f22

File tree

2 files changed

+43
-15
lines changed

2 files changed

+43
-15
lines changed

oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java

+18-13
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import java.security.interfaces.RSAPublicKey;
1919
import java.time.Instant;
20+
import java.util.Collections;
2021
import java.util.LinkedHashMap;
2122
import java.util.List;
2223
import java.util.Map;
@@ -40,6 +41,7 @@
4041
import com.nimbusds.jwt.proc.JWTProcessor;
4142
import reactor.core.publisher.Mono;
4243

44+
import org.springframework.core.convert.converter.Converter;
4345
import org.springframework.security.oauth2.core.OAuth2TokenValidator;
4446
import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult;
4547
import org.springframework.security.oauth2.jose.jws.JwsAlgorithms;
@@ -70,6 +72,8 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
7072
private final JWKSelectorFactory jwkSelectorFactory;
7173

7274
private OAuth2TokenValidator<Jwt> jwtValidator = JwtValidators.createDefault();
75+
private Converter<Map<String, Object>, Map<String, Object>> claimSetConverter = MappedJwtClaimSetConverter
76+
.withDefaults(Collections.emptyMap());
7377

7478
public NimbusReactiveJwtDecoder(RSAPublicKey publicKey) {
7579
JWSAlgorithm algorithm = JWSAlgorithm.parse(JwsAlgorithms.RS256);
@@ -122,6 +126,16 @@ public void setJwtValidator(OAuth2TokenValidator<Jwt> jwtValidator) {
122126
this.jwtValidator = jwtValidator;
123127
}
124128

129+
/**
130+
* Use the following {@link Converter} for manipulating the JWT's claim set
131+
*
132+
* @param claimSetConverter the {@link Converter} to use
133+
*/
134+
public void setClaimSetConverter(Converter<Map<String, Object>, Map<String, Object>> claimSetConverter) {
135+
Assert.notNull(claimSetConverter, "claimSetConverter cannot be null");
136+
this.claimSetConverter = claimSetConverter;
137+
}
138+
125139
@Override
126140
public Mono<Jwt> decode(String token) throws JwtException {
127141
JWT jwt = parse(token);
@@ -164,21 +178,12 @@ private JWTClaimsSet createClaimsSet(JWT parsedToken, List<JWK> jwkList) {
164178
}
165179

166180
private Jwt createJwt(JWT parsedJwt, JWTClaimsSet jwtClaimsSet) {
167-
Instant expiresAt = null;
168-
if (jwtClaimsSet.getExpirationTime() != null) {
169-
expiresAt = jwtClaimsSet.getExpirationTime().toInstant();
170-
}
171-
Instant issuedAt = null;
172-
if (jwtClaimsSet.getIssueTime() != null) {
173-
issuedAt = jwtClaimsSet.getIssueTime().toInstant();
174-
} else if (expiresAt != null) {
175-
// Default to expiresAt - 1 second
176-
issuedAt = Instant.from(expiresAt).minusSeconds(1);
177-
}
178-
179181
Map<String, Object> headers = new LinkedHashMap<>(parsedJwt.getHeader().toJSONObject());
182+
Map<String, Object> claims = this.claimSetConverter.convert(jwtClaimsSet.getClaims());
180183

181-
return new Jwt(parsedJwt.getParsedString(), issuedAt, expiresAt, headers, jwtClaimsSet.getClaims());
184+
Instant expiresAt = (Instant) claims.get(JwtClaimNames.EXP);
185+
Instant issuedAt = (Instant) claims.get(JwtClaimNames.IAT);
186+
return new Jwt(parsedJwt.getParsedString(), issuedAt, expiresAt, headers, claims);
182187
}
183188

184189
private Jwt validateJwt(Jwt jwt) {

oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java

+25-2
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,18 @@
2020
import java.security.KeyFactory;
2121
import java.security.interfaces.RSAPublicKey;
2222
import java.security.spec.X509EncodedKeySpec;
23+
import java.time.Instant;
2324
import java.util.Base64;
24-
import java.util.Date;
25+
import java.util.Collections;
26+
import java.util.Map;
2527

2628
import okhttp3.mockwebserver.MockResponse;
2729
import okhttp3.mockwebserver.MockWebServer;
2830
import org.junit.After;
2931
import org.junit.Before;
3032
import org.junit.Test;
3133

34+
import org.springframework.core.convert.converter.Converter;
3235
import org.springframework.security.oauth2.core.OAuth2Error;
3336
import org.springframework.security.oauth2.core.OAuth2TokenValidator;
3437
import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult;
@@ -37,6 +40,7 @@
3740
import static org.assertj.core.api.Assertions.assertThatCode;
3841
import static org.mockito.ArgumentMatchers.any;
3942
import static org.mockito.Mockito.mock;
43+
import static org.mockito.Mockito.verify;
4044
import static org.mockito.Mockito.when;
4145

4246
/**
@@ -115,7 +119,7 @@ public void decodeWhenIssuedAtThenSuccess() {
115119

116120
Jwt jwt = this.decoder.decode(withIssuedAt).block();
117121

118-
assertThat(jwt.getClaims().get(JwtClaimNames.IAT)).isEqualTo(new Date(1529942448000L));
122+
assertThat(jwt.getClaims().get(JwtClaimNames.IAT)).isEqualTo(Instant.ofEpochSecond(1529942448L));
119123
}
120124

121125
@Test
@@ -177,9 +181,28 @@ public void decodeWhenUsingCustomValidatorThenValidatorIsInvoked() {
177181
.hasMessageContaining("mock-description");
178182
}
179183

184+
@Test
185+
public void decodeWhenUsingSignedJwtThenReturnsClaimsGivenByClaimSetConverter() {
186+
Converter<Map<String, Object>, Map<String, Object>> claimSetConverter = mock(Converter.class);
187+
this.decoder.setClaimSetConverter(claimSetConverter);
188+
189+
when(claimSetConverter.convert(any(Map.class))).thenReturn(Collections.singletonMap("custom", "value"));
190+
191+
Jwt jwt = this.decoder.decode(this.messageReadToken).block();
192+
assertThat(jwt.getClaims().size()).isEqualTo(1);
193+
assertThat(jwt.getClaims().get("custom")).isEqualTo("value");
194+
verify(claimSetConverter).convert(any(Map.class));
195+
}
196+
180197
@Test
181198
public void setJwtValidatorWhenGivenNullThrowsIllegalArgumentException() {
182199
assertThatCode(() -> this.decoder.setJwtValidator(null))
183200
.isInstanceOf(IllegalArgumentException.class);
184201
}
202+
203+
@Test
204+
public void setClaimSetConverterWhenNullThrowsIllegalArgumentException() {
205+
assertThatCode(() -> this.decoder.setClaimSetConverter(null))
206+
.isInstanceOf(IllegalArgumentException.class);
207+
}
185208
}

0 commit comments

Comments
 (0)