|
17 | 17 |
|
18 | 18 | import java.security.interfaces.RSAPublicKey;
|
19 | 19 | import java.time.Instant;
|
| 20 | +import java.util.Collections; |
20 | 21 | import java.util.LinkedHashMap;
|
21 | 22 | import java.util.List;
|
22 | 23 | import java.util.Map;
|
|
40 | 41 | import com.nimbusds.jwt.proc.JWTProcessor;
|
41 | 42 | import reactor.core.publisher.Mono;
|
42 | 43 |
|
| 44 | +import org.springframework.core.convert.converter.Converter; |
43 | 45 | import org.springframework.security.oauth2.core.OAuth2TokenValidator;
|
44 | 46 | import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult;
|
45 | 47 | import org.springframework.security.oauth2.jose.jws.JwsAlgorithms;
|
@@ -70,6 +72,8 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
|
70 | 72 | private final JWKSelectorFactory jwkSelectorFactory;
|
71 | 73 |
|
72 | 74 | private OAuth2TokenValidator<Jwt> jwtValidator = JwtValidators.createDefault();
|
| 75 | + private Converter<Map<String, Object>, Map<String, Object>> claimSetConverter = MappedJwtClaimSetConverter |
| 76 | + .withDefaults(Collections.emptyMap()); |
73 | 77 |
|
74 | 78 | public NimbusReactiveJwtDecoder(RSAPublicKey publicKey) {
|
75 | 79 | JWSAlgorithm algorithm = JWSAlgorithm.parse(JwsAlgorithms.RS256);
|
@@ -122,6 +126,16 @@ public void setJwtValidator(OAuth2TokenValidator<Jwt> jwtValidator) {
|
122 | 126 | this.jwtValidator = jwtValidator;
|
123 | 127 | }
|
124 | 128 |
|
| 129 | + /** |
| 130 | + * Use the following {@link Converter} for manipulating the JWT's claim set |
| 131 | + * |
| 132 | + * @param claimSetConverter the {@link Converter} to use |
| 133 | + */ |
| 134 | + public void setClaimSetConverter(Converter<Map<String, Object>, Map<String, Object>> claimSetConverter) { |
| 135 | + Assert.notNull(claimSetConverter, "claimSetConverter cannot be null"); |
| 136 | + this.claimSetConverter = claimSetConverter; |
| 137 | + } |
| 138 | + |
125 | 139 | @Override
|
126 | 140 | public Mono<Jwt> decode(String token) throws JwtException {
|
127 | 141 | JWT jwt = parse(token);
|
@@ -164,21 +178,12 @@ private JWTClaimsSet createClaimsSet(JWT parsedToken, List<JWK> jwkList) {
|
164 | 178 | }
|
165 | 179 |
|
166 | 180 | private Jwt createJwt(JWT parsedJwt, JWTClaimsSet jwtClaimsSet) {
|
167 |
| - Instant expiresAt = null; |
168 |
| - if (jwtClaimsSet.getExpirationTime() != null) { |
169 |
| - expiresAt = jwtClaimsSet.getExpirationTime().toInstant(); |
170 |
| - } |
171 |
| - Instant issuedAt = null; |
172 |
| - if (jwtClaimsSet.getIssueTime() != null) { |
173 |
| - issuedAt = jwtClaimsSet.getIssueTime().toInstant(); |
174 |
| - } else if (expiresAt != null) { |
175 |
| - // Default to expiresAt - 1 second |
176 |
| - issuedAt = Instant.from(expiresAt).minusSeconds(1); |
177 |
| - } |
178 |
| - |
179 | 181 | Map<String, Object> headers = new LinkedHashMap<>(parsedJwt.getHeader().toJSONObject());
|
| 182 | + Map<String, Object> claims = this.claimSetConverter.convert(jwtClaimsSet.getClaims()); |
180 | 183 |
|
181 |
| - return new Jwt(parsedJwt.getParsedString(), issuedAt, expiresAt, headers, jwtClaimsSet.getClaims()); |
| 184 | + Instant expiresAt = (Instant) claims.get(JwtClaimNames.EXP); |
| 185 | + Instant issuedAt = (Instant) claims.get(JwtClaimNames.IAT); |
| 186 | + return new Jwt(parsedJwt.getParsedString(), issuedAt, expiresAt, headers, claims); |
182 | 187 | }
|
183 | 188 |
|
184 | 189 | private Jwt validateJwt(Jwt jwt) {
|
|
0 commit comments