Skip to content

Commit 4d0c375

Browse files
authored
Refresh remote JWKs on all errors (elastic#42850)
It turns out that key rotation on the OP, can manifest as both a BadJWSException and a BadJOSEException in nimbus-jose-jwt. As such we cannot depend on matching only BadJWSExceptions to determine if we should poll the remote JWKs for an update. This has the side-effect that a remote JWKs source will be polled exactly one additional time too for errors that have to do with configuration, or for errors that might be caused by not synched clocks, forged JWTs, etc. ( These will throw a BadJWTException which extends BadJOSEException also )
1 parent 910573e commit 4d0c375

File tree

2 files changed

+28
-9
lines changed

2 files changed

+28
-9
lines changed

x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/oidc/OpenIdConnectAuthenticator.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import com.nimbusds.jose.jwk.JWKSet;
1313
import com.nimbusds.jose.jwk.source.JWKSource;
1414
import com.nimbusds.jose.proc.BadJOSEException;
15-
import com.nimbusds.jose.proc.BadJWSException;
1615
import com.nimbusds.jose.proc.JWSVerificationKeySelector;
1716
import com.nimbusds.jose.proc.SecurityContext;
1817
import com.nimbusds.jose.util.IOUtils;
@@ -241,7 +240,7 @@ private void getUserClaims(@Nullable AccessToken accessToken, JWT idToken, Nonce
241240
}
242241
claimsListener.onResponse(enrichedVerifiedIdTokenClaims);
243242
}
244-
} catch (BadJWSException e) {
243+
} catch (BadJOSEException e) {
245244
// We only try to update the cached JWK set once if a remote source is used and
246245
// RSA or ECDSA is used for signatures
247246
if (shouldRetry
@@ -257,7 +256,7 @@ private void getUserClaims(@Nullable AccessToken accessToken, JWT idToken, Nonce
257256
} else {
258257
claimsListener.onFailure(new ElasticsearchSecurityException("Failed to parse or validate the ID Token", e));
259258
}
260-
} catch (com.nimbusds.oauth2.sdk.ParseException | ParseException | BadJOSEException | JOSEException e) {
259+
} catch (com.nimbusds.oauth2.sdk.ParseException | ParseException | JOSEException e) {
261260
claimsListener.onFailure(new ElasticsearchSecurityException("Failed to parse or validate the ID Token", e));
262261
}
263262
}

x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/oidc/OpenIdConnectAuthenticatorTests.java

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,11 @@ public void testImplicitFlowFailsWithExpiredToken() throws Exception {
320320
assertThat(e.getMessage(), containsString("Failed to parse or validate the ID Token"));
321321
assertThat(e.getCause(), instanceOf(BadJWTException.class));
322322
assertThat(e.getCause().getMessage(), containsString("Expired JWT"));
323-
assertThat(callsToReloadJwk, equalTo(0));
323+
if (jwk.getAlgorithm().getName().startsWith("HS")) {
324+
assertThat(callsToReloadJwk, equalTo(0));
325+
} else {
326+
assertThat(callsToReloadJwk, equalTo(1));
327+
}
324328
}
325329

326330
public void testImplicitFlowFailsNotYetIssuedToken() throws Exception {
@@ -360,7 +364,11 @@ public void testImplicitFlowFailsNotYetIssuedToken() throws Exception {
360364
assertThat(e.getMessage(), containsString("Failed to parse or validate the ID Token"));
361365
assertThat(e.getCause(), instanceOf(BadJWTException.class));
362366
assertThat(e.getCause().getMessage(), containsString("JWT issue time ahead of current time"));
363-
assertThat(callsToReloadJwk, equalTo(0));
367+
if (jwk.getAlgorithm().getName().startsWith("HS")) {
368+
assertThat(callsToReloadJwk, equalTo(0));
369+
} else {
370+
assertThat(callsToReloadJwk, equalTo(1));
371+
}
364372
}
365373

366374
public void testImplicitFlowFailsInvalidIssuer() throws Exception {
@@ -399,7 +407,11 @@ public void testImplicitFlowFailsInvalidIssuer() throws Exception {
399407
assertThat(e.getMessage(), containsString("Failed to parse or validate the ID Token"));
400408
assertThat(e.getCause(), instanceOf(BadJWTException.class));
401409
assertThat(e.getCause().getMessage(), containsString("Unexpected JWT issuer"));
402-
assertThat(callsToReloadJwk, equalTo(0));
410+
if (jwk.getAlgorithm().getName().startsWith("HS")) {
411+
assertThat(callsToReloadJwk, equalTo(0));
412+
} else {
413+
assertThat(callsToReloadJwk, equalTo(1));
414+
}
403415
}
404416

405417
public void testImplicitFlowFailsInvalidAudience() throws Exception {
@@ -438,7 +450,11 @@ public void testImplicitFlowFailsInvalidAudience() throws Exception {
438450
assertThat(e.getMessage(), containsString("Failed to parse or validate the ID Token"));
439451
assertThat(e.getCause(), instanceOf(BadJWTException.class));
440452
assertThat(e.getCause().getMessage(), containsString("Unexpected JWT audience"));
441-
assertThat(callsToReloadJwk, equalTo(0));
453+
if (jwk.getAlgorithm().getName().startsWith("HS")) {
454+
assertThat(callsToReloadJwk, equalTo(0));
455+
} else {
456+
assertThat(callsToReloadJwk, equalTo(1));
457+
}
442458
}
443459

444460
public void testAuthenticateImplicitFlowFailsWithForgedRsaIdToken() throws Exception {
@@ -611,7 +627,7 @@ public void testImplicitFlowFailsWithAlgorithmMixupAttack() throws Exception {
611627
assertThat(e.getMessage(), containsString("Failed to parse or validate the ID Token"));
612628
assertThat(e.getCause(), instanceOf(BadJOSEException.class));
613629
assertThat(e.getCause().getMessage(), containsString("Another algorithm expected, or no matching key(s) found"));
614-
assertThat(callsToReloadJwk, equalTo(0));
630+
assertThat(callsToReloadJwk, equalTo(1));
615631
}
616632

617633
public void testImplicitFlowFailsWithUnsignedJwt() throws Exception {
@@ -648,7 +664,11 @@ public void testImplicitFlowFailsWithUnsignedJwt() throws Exception {
648664
assertThat(e.getMessage(), containsString("Failed to parse or validate the ID Token"));
649665
assertThat(e.getCause(), instanceOf(BadJWTException.class));
650666
assertThat(e.getCause().getMessage(), containsString("Signed ID token expected"));
651-
assertThat(callsToReloadJwk, equalTo(0));
667+
if (jwk.getAlgorithm().getName().startsWith("HS")) {
668+
assertThat(callsToReloadJwk, equalTo(0));
669+
} else {
670+
assertThat(callsToReloadJwk, equalTo(1));
671+
}
652672
}
653673

654674
public void testJsonObjectMerging() throws Exception {

0 commit comments

Comments
 (0)