Skip to content

Commit aa985fa

Browse files
committed
Ignore endpoint override if awsShouldUseFips is passed in. Add test.
1 parent 30e5339 commit aa985fa

File tree

2 files changed

+70
-18
lines changed

2 files changed

+70
-18
lines changed

src/main/java/software/amazon/msk/auth/iam/internals/MSKCredentialProvider.java

+22-12
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ public class MSKCredentialProvider implements AwsCredentialsProvider, AutoClosea
9090
private static final String AWS_ROLE_SESSION_TOKEN = "awsRoleSessionToken";
9191
private static final String AWS_STS_REGION = "awsStsRegion";
9292
private static final String AWS_DEBUG_CREDS_KEY = "awsDebugCreds";
93+
private static final String AWS_SHOULD_USE_FIPS = "awsShouldUseFips";
9394
private static final String AWS_MAX_RETRIES = "awsMaxRetries";
9495
private static final String AWS_MAX_BACK_OFF_TIME_MS = "awsMaxBackOffTimeMs";
9596
private static final String GLOBAL_REGION = "aws-global";
@@ -264,6 +265,10 @@ public Boolean shouldDebugCreds() {
264265
return Optional.ofNullable(optionsMap.get(AWS_DEBUG_CREDS_KEY)).map(d -> d.equals("true")).orElse(false);
265266
}
266267

268+
public Boolean shouldUseFips() {
269+
return Optional.ofNullable(optionsMap.get(AWS_SHOULD_USE_FIPS)).map(d -> d.equals("true")).orElse(false);
270+
}
271+
267272
public String getStsRegion() {
268273
return Optional.ofNullable((String) optionsMap.get(AWS_STS_REGION))
269274
.orElse(GLOBAL_REGION);
@@ -295,9 +300,10 @@ public URI buildEndpointConfiguration(Region stsRegion) {
295300
}
296301
}
297302

298-
private StsClientBuilder getStsClientBuilder(Region stsRegion) {
303+
private StsClientBuilder getStsClientBuilder(Region stsRegion, Boolean shouldUseFips) {
299304
StsClientBuilder builder = StsClient.builder().region(stsRegion);
300-
if (stsRegion != Region.AWS_GLOBAL) {
305+
if (stsRegion != Region.AWS_GLOBAL && !shouldUseFips) {
306+
log.info("Using STS Endpoint override");
301307
builder.endpointOverride(buildEndpointConfiguration(stsRegion));
302308
}
303309
return builder;
@@ -327,6 +333,7 @@ private Optional<StsAssumeRoleCredentialsProvider> getStsRoleProvider() {
327333
String sessionName = Optional.ofNullable((String) optionsMap.get(AWS_ROLE_SESSION_KEY))
328334
.orElse("aws-msk-iam-auth");
329335
String stsRegion = getStsRegion();
336+
Boolean shouldUseFIPs = shouldUseFips();
330337

331338
String accessKey = (String) optionsMap.getOrDefault(AWS_ROLE_ACCESS_KEY_ID, null);
332339
String secretKey = (String) optionsMap.getOrDefault(AWS_ROLE_SECRET_ACCESS_KEY, null);
@@ -337,25 +344,26 @@ private Optional<StsAssumeRoleCredentialsProvider> getStsRoleProvider() {
337344
sessionToken != null
338345
? AwsSessionCredentials.create(accessKey, secretKey, sessionToken)
339346
: AwsBasicCredentials.create(accessKey, secretKey));
340-
return createSTSRoleCredentialProvider((String) p, sessionName, stsRegion, credentials);
347+
return createSTSRoleCredentialProvider((String) p, sessionName, stsRegion, credentials, shouldUseFIPs);
341348
}
342349
else if (externalId != null) {
343-
return createSTSRoleCredentialProvider((String) p, externalId, sessionName, stsRegion);
350+
return createSTSRoleCredentialProvider((String) p, externalId, sessionName, stsRegion, shouldUseFIPs);
344351
}
345352

346-
return createSTSRoleCredentialProvider((String) p, sessionName, stsRegion);
353+
return createSTSRoleCredentialProvider((String) p, sessionName, stsRegion, shouldUseFIPs);
347354
});
348355
}
349356

350357
StsAssumeRoleCredentialsProvider createSTSRoleCredentialProvider(
351358
String roleArn,
352359
String sessionName,
353-
String stsRegion) {
360+
String stsRegion,
361+
Boolean shouldUseFips) {
354362
AssumeRoleRequest roleRequest = AssumeRoleRequest.builder()
355363
.roleArn(roleArn)
356364
.roleSessionName(sessionName)
357365
.build();
358-
StsClient stsClient = getStsClientBuilder(Region.of(stsRegion))
366+
StsClient stsClient = getStsClientBuilder(Region.of(stsRegion), shouldUseFips)
359367
.build();
360368
return StsAssumeRoleCredentialsProvider.builder()
361369
.stsClient(stsClient)
@@ -367,12 +375,13 @@ StsAssumeRoleCredentialsProvider createSTSRoleCredentialProvider(
367375
StsAssumeRoleCredentialsProvider createSTSRoleCredentialProvider(
368376
String roleArn,
369377
String sessionName, String stsRegion,
370-
AwsCredentialsProvider credentials) {
378+
AwsCredentialsProvider credentials,
379+
Boolean shouldUseFips) {
371380
AssumeRoleRequest roleRequest = AssumeRoleRequest.builder()
372381
.roleArn(roleArn)
373382
.roleSessionName(sessionName)
374383
.build();
375-
StsClient stsClient = getStsClientBuilder(Region.of(stsRegion))
384+
StsClient stsClient = getStsClientBuilder(Region.of(stsRegion), shouldUseFips)
376385
.credentialsProvider(credentials)
377386
.build();
378387
return StsAssumeRoleCredentialsProvider.builder()
@@ -386,17 +395,18 @@ StsAssumeRoleCredentialsProvider createSTSRoleCredentialProvider(
386395
String roleArn,
387396
String externalId,
388397
String sessionName,
389-
String stsRegion) {
398+
String stsRegion,
399+
Boolean shouldUseFips) {
390400
AssumeRoleRequest roleRequest = AssumeRoleRequest.builder()
391401
.externalId(externalId)
392402
.roleArn(roleArn)
393403
.roleSessionName(sessionName)
394404
.build();
395405
return StsAssumeRoleCredentialsProvider.builder()
396-
.stsClient(getStsClientBuilder(Region.of(stsRegion)).build())
406+
.stsClient(getStsClientBuilder(Region.of(stsRegion), shouldUseFips).build())
397407
.refreshRequest(roleRequest)
398408
.asyncCredentialUpdateEnabled(true)
399409
.build();
400410
}
401411
}
402-
}
412+
}

src/test/java/software/amazon/msk/auth/iam/internals/MSKCredentialProviderTest.java

+48-6
Original file line numberDiff line numberDiff line change
@@ -353,9 +353,10 @@ public void testAwsRoleArnSessionNameAndStsRegion() {
353353

354354
MSKCredentialProvider.ProviderBuilder providerBuilder = new MSKCredentialProvider.ProviderBuilder(optionsMap) {
355355
StsAssumeRoleCredentialsProvider createSTSRoleCredentialProvider(String roleArn,
356-
String sessionName, String stsRegion) {
356+
String sessionName, String stsRegion, Boolean shouldUseFips) {
357357
assertEquals(TEST_ROLE_ARN, roleArn);
358358
assertEquals(TEST_ROLE_SESSION_NAME, sessionName);
359+
assertEquals(false, shouldUseFips);
359360
assertEquals("eu-west-1", stsRegion);
360361
URI endpointConfiguration = buildEndpointConfiguration(Region.of(stsRegion));
361362
assertEquals("https://sts.eu-west-1.amazonaws.com", endpointConfiguration.toString());
@@ -372,6 +373,41 @@ StsAssumeRoleCredentialsProvider createSTSRoleCredentialProvider(String roleArn,
372373
Mockito.verify(mockStsRoleProvider, times(1)).close();
373374
}
374375

376+
@Test
377+
public void testAwsRoleArnSessionNameAndStsRegionAndShouldUseFIPs() {
378+
StsAssumeRoleCredentialsProvider mockStsRoleProvider = Mockito
379+
.mock(StsAssumeRoleCredentialsProvider.class);
380+
Mockito.when(mockStsRoleProvider.resolveIdentity())
381+
.thenAnswer(i -> CompletableFuture.completedFuture(AwsSessionCredentials.create(ACCESS_KEY_VALUE, SECRET_KEY_VALUE, SESSION_TOKEN)));
382+
383+
Map<String, String> optionsMap = new HashMap<>();
384+
optionsMap.put(AWS_ROLE_ARN, TEST_ROLE_ARN);
385+
optionsMap.put("awsRoleSessionName", TEST_ROLE_SESSION_NAME);
386+
optionsMap.put("awsStsRegion", "eu-west-1");
387+
optionsMap.put("awsShouldUseFips", "true");
388+
389+
MSKCredentialProvider.ProviderBuilder providerBuilder = new MSKCredentialProvider.ProviderBuilder(optionsMap) {
390+
StsAssumeRoleCredentialsProvider createSTSRoleCredentialProvider(String roleArn,
391+
String sessionName, String stsRegion, Boolean shouldUseFips) {
392+
assertEquals(TEST_ROLE_ARN, roleArn);
393+
assertEquals(TEST_ROLE_SESSION_NAME, sessionName);
394+
assertEquals("eu-west-1", stsRegion);
395+
assertEquals(true, shouldUseFips);
396+
URI endpointConfiguration = buildEndpointConfiguration(Region.of(stsRegion));
397+
assertEquals("https://sts.eu-west-1.amazonaws.com", endpointConfiguration.toString());
398+
return mockStsRoleProvider;
399+
}
400+
};
401+
MSKCredentialProvider provider = new MSKCredentialProvider(providerBuilder);
402+
assertFalse(provider.getShouldDebugCreds());
403+
404+
AwsCredentials credentials = provider.resolveCredentials();
405+
validateBasicSessionCredentials(credentials);
406+
407+
provider.close();
408+
Mockito.verify(mockStsRoleProvider, times(1)).close();
409+
}
410+
375411
@Test
376412
public void testAwsRoleArnSessionNameStsRegionAndExternalId() {
377413
StsAssumeRoleCredentialsProvider mockStsRoleProvider = Mockito
@@ -389,11 +425,13 @@ public void testAwsRoleArnSessionNameStsRegionAndExternalId() {
389425
StsAssumeRoleCredentialsProvider createSTSRoleCredentialProvider(String roleArn,
390426
String externalId,
391427
String sessionName,
392-
String stsRegion) {
428+
String stsRegion,
429+
Boolean shouldUseFips) {
393430
assertEquals(TEST_ROLE_ARN, roleArn);
394431
assertEquals(TEST_ROLE_EXTERNAL_ID, externalId);
395432
assertEquals(TEST_ROLE_SESSION_NAME, sessionName);
396433
assertEquals("eu-west-1", stsRegion);
434+
assertEquals(false, shouldUseFips);
397435
URI endpointConfiguration = buildEndpointConfiguration(Region.of(stsRegion));
398436
assertEquals("https://sts.eu-west-1.amazonaws.com", endpointConfiguration.toString());
399437
return mockStsRoleProvider;
@@ -429,9 +467,10 @@ ProfileCredentialsProvider createEnhancedProfileCredentialsProvider(String profi
429467
}
430468

431469
StsAssumeRoleCredentialsProvider createSTSRoleCredentialProvider(String roleArn,
432-
String sessionName, String stsRegion) {
470+
String sessionName, String stsRegion, Boolean shouldUseFips) {
433471
assertEquals(TEST_ROLE_ARN, roleArn);
434472
assertEquals("aws-msk-iam-auth", sessionName);
473+
assertEquals(false, shouldUseFips);
435474
return mockStsRoleProvider;
436475
}
437476
};
@@ -649,9 +688,10 @@ private MSKCredentialProvider.ProviderBuilder getProviderBuilder(StsAssumeRoleCr
649688
Map<String, String> optionsMap, String s) {
650689
return new MSKCredentialProvider.ProviderBuilder(optionsMap) {
651690
StsAssumeRoleCredentialsProvider createSTSRoleCredentialProvider(String roleArn,
652-
String sessionName, String stsRegion) {
691+
String sessionName, String stsRegion, Boolean shouldUseFips) {
653692
assertEquals(TEST_ROLE_ARN, roleArn);
654693
assertEquals(s, sessionName);
694+
assertEquals(false, shouldUseFips);
655695
return mockStsRoleProvider;
656696
}
657697
};
@@ -662,9 +702,11 @@ private MSKCredentialProvider.ProviderBuilder getProviderBuilderWithCredentials(
662702
return new MSKCredentialProvider.ProviderBuilder(optionsMap) {
663703
StsAssumeRoleCredentialsProvider createSTSRoleCredentialProvider(String roleArn,
664704
String sessionName, String stsRegion,
665-
AwsCredentialsProvider credentials) {
705+
AwsCredentialsProvider credentials,
706+
Boolean shouldUseFips) {
666707
assertEquals(TEST_ROLE_ARN, roleArn);
667708
assertEquals(s, sessionName);
709+
assertEquals(false, shouldUseFips);
668710
return mockStsRoleProvider;
669711
}
670712
};
@@ -740,4 +782,4 @@ private URL getProfileResourceURL() {
740782
return getClass().getClassLoader().getResource("profile_config_file");
741783
}
742784

743-
}
785+
}

0 commit comments

Comments
 (0)