Skip to content

fix(client-sts): allow overwriting default role assumer http handler #2426

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 101 additions & 17 deletions clients/client-sts/defaultRoleAssumers.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,36 @@
// https://github.com/aws/aws-sdk-js-v3/blob/main/codegen/smithy-aws-typescript-codegen/src/main/resources/software/amazon/smithy/aws/typescript/codegen/sts-client-defaultRoleAssumers.spec.ts
import { HttpResponse } from "@aws-sdk/protocol-http";
import { Readable } from "stream";
const assumeRoleResponse = `<AssumeRoleResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">

const mockHandle = jest.fn().mockResolvedValue({
response: new HttpResponse({
statusCode: 200,
body: Readable.from([""]),
}),
});
jest.mock("@aws-sdk/node-http-handler", () => ({
NodeHttpHandler: jest.fn().mockImplementation(() => ({
destroy: () => {},
handle: mockHandle,
})),
streamCollector: jest.fn(),
}));

import { getDefaultRoleAssumer, getDefaultRoleAssumerWithWebIdentity } from "./defaultRoleAssumers";
import type { AssumeRoleCommandInput } from "./commands/AssumeRoleCommand";
import { NodeHttpHandler, streamCollector } from "@aws-sdk/node-http-handler";
import { AssumeRoleWithWebIdentityCommandInput } from "./commands/AssumeRoleWithWebIdentityCommand";
const mockConstructorInput = jest.fn();
jest.mock("./STSClient", () => ({
STSClient: function (params: any) {
mockConstructorInput(params);
//@ts-ignore
return new (jest.requireActual("./STSClient").STSClient)(params);
},
}));

describe("getDefaultRoleAssumer", () => {
const assumeRoleResponse = `<AssumeRoleResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
<AssumeRoleResult>
<AssumedRoleUser>
<AssumedRoleId>AROAZOX2IL27GNRBJHWC2:session</AssumedRoleId>
Expand All @@ -19,27 +48,15 @@ const assumeRoleResponse = `<AssumeRoleResponse xmlns="https://sts.amazonaws.com
<RequestId>12345678id</RequestId>
</ResponseMetadata>
</AssumeRoleResponse>`;
const mockHandle = jest.fn().mockResolvedValue({
response: new HttpResponse({
statusCode: 200,
body: Readable.from([""]),
}),
});
jest.mock("@aws-sdk/node-http-handler", () => ({
NodeHttpHandler: jest.fn().mockImplementation(() => ({
destroy: () => {},
handle: mockHandle,
})),
streamCollector: async () => Buffer.from(assumeRoleResponse),
}));

import { getDefaultRoleAssumer } from "./defaultRoleAssumers";
import type { AssumeRoleCommandInput } from "./commands/AssumeRoleCommand";
beforeAll(() => {
(streamCollector as jest.Mock).mockImplementation(async () => Buffer.from(assumeRoleResponse));
});

describe("getDefaultRoleAssumer", () => {
beforeEach(() => {
jest.clearAllMocks();
});

it("should use supplied source credentials", async () => {
const roleAssumer = getDefaultRoleAssumer();
const params: AssumeRoleCommandInput = {
Expand All @@ -61,4 +78,71 @@ describe("getDefaultRoleAssumer", () => {
expect.stringContaining("AWS4-HMAC-SHA256 Credential=key2/")
);
});

it("should use the STS client config", async () => {
const logger = console;
const region = "some-region";
const handler = new NodeHttpHandler();
const roleAssumer = getDefaultRoleAssumer({
region,
logger,
requestHandler: handler,
});
const params: AssumeRoleCommandInput = {
RoleArn: "arn:aws:foo",
RoleSessionName: "session",
};
const sourceCred = { accessKeyId: "key", secretAccessKey: "secrete" };
await roleAssumer(sourceCred, params);
expect(mockConstructorInput).toHaveBeenCalledTimes(1);
expect(mockConstructorInput.mock.calls[0][0]).toMatchObject({
logger,
requestHandler: handler,
region,
});
});
});

describe("getDefaultRoleAssumerWithWebIdentity", () => {
const assumeRoleResponse = `<Response xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
<AssumeRoleWithWebIdentityResult>
<Credentials>
<AccessKeyId>key</AccessKeyId>
<SecretAccessKey>secrete</SecretAccessKey>
<SessionToken>session-token</SessionToken>
<Expiration>2021-05-05T23:22:08Z</Expiration>
</Credentials>
</AssumeRoleWithWebIdentityResult>
</Response>`;

beforeAll(() => {
(streamCollector as jest.Mock).mockImplementation(async () => Buffer.from(assumeRoleResponse));
});

beforeEach(() => {
jest.clearAllMocks();
});

it("should use the STS client config", async () => {
const logger = console;
const region = "some-region";
const handler = new NodeHttpHandler();
const roleAssumerWithWebIdentity = getDefaultRoleAssumerWithWebIdentity({
region,
logger,
requestHandler: handler,
});
const params: AssumeRoleWithWebIdentityCommandInput = {
RoleArn: "arn:aws:foo",
RoleSessionName: "session",
WebIdentityToken: "token",
};
await roleAssumerWithWebIdentity(params);
expect(mockConstructorInput).toHaveBeenCalledTimes(1);
expect(mockConstructorInput.mock.calls[0][0]).toMatchObject({
logger,
requestHandler: handler,
region,
});
});
});
7 changes: 4 additions & 3 deletions clients/client-sts/defaultRoleAssumers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@ import { STSClient, STSClientConfig } from "./STSClient";
/**
* The default role assumer that used by credential providers when sts:AssumeRole API is needed.
*/
export const getDefaultRoleAssumer = (stsOptions: Pick<STSClientConfig, "logger" | "region"> = {}): RoleAssumer =>
StsGetDefaultRoleAssumer(stsOptions, STSClient);
export const getDefaultRoleAssumer = (
stsOptions: Pick<STSClientConfig, "logger" | "region" | "requestHandler"> = {}
): RoleAssumer => StsGetDefaultRoleAssumer(stsOptions, STSClient);

/**
* The default role assumer that used by credential providers when sts:AssumeRoleWithWebIdentity API is needed.
*/
export const getDefaultRoleAssumerWithWebIdentity = (
stsOptions: Pick<STSClientConfig, "logger" | "region"> = {}
stsOptions: Pick<STSClientConfig, "logger" | "region" | "requestHandler"> = {}
): RoleAssumerWithWebIdentity => StsGetDefaultRoleAssumerWithWebIdentity(stsOptions, STSClient);

/**
Expand Down
14 changes: 8 additions & 6 deletions clients/client-sts/defaultStsRoleAssumers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,21 @@ const decorateDefaultRegion = (region: string | Provider<string> | undefined): s
* @internal
*/
export const getDefaultRoleAssumer = (
stsOptions: Pick<STSClientConfig, "logger" | "region">,
stsOptions: Pick<STSClientConfig, "logger" | "region" | "requestHandler">,
stsClientCtor: new (options: STSClientConfig) => STSClient
): RoleAssumer => {
let stsClient: STSClient;
let closureSourceCreds: Credentials;
return async (sourceCreds, params) => {
closureSourceCreds = sourceCreds;
if (!stsClient) {
const { logger, region } = stsOptions;
const { logger, region, requestHandler } = stsOptions;
stsClient = new stsClientCtor({
logger,
// A hack to make sts client uses the credential in current closure.
credentialDefaultProvider: () => async () => closureSourceCreds,
region: decorateDefaultRegion(region),
region: decorateDefaultRegion(region || stsOptions.region),
...(requestHandler ? { requestHandler } : {}),
});
}
const { Credentials } = await stsClient.send(new AssumeRoleCommand(params));
Expand All @@ -76,16 +77,17 @@ export type RoleAssumerWithWebIdentity = (params: AssumeRoleWithWebIdentityComma
* @internal
*/
export const getDefaultRoleAssumerWithWebIdentity = (
stsOptions: Pick<STSClientConfig, "logger" | "region">,
stsOptions: Pick<STSClientConfig, "logger" | "region" | "requestHandler">,
stsClientCtor: new (options: STSClientConfig) => STSClient
): RoleAssumerWithWebIdentity => {
let stsClient: STSClient;
return async (params) => {
if (!stsClient) {
const { logger, region } = stsOptions;
const { logger, region, requestHandler } = stsOptions;
stsClient = new stsClientCtor({
logger,
region: decorateDefaultRegion(region),
region: decorateDefaultRegion(region || stsOptions.region),
...(requestHandler ? { requestHandler } : {}),
});
}
const { Credentials } = await stsClient.send(new AssumeRoleWithWebIdentityCommand(params));
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,35 @@
import { HttpResponse } from "@aws-sdk/protocol-http";
import { Readable } from "stream";
const assumeRoleResponse = `<AssumeRoleResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">

const mockHandle = jest.fn().mockResolvedValue({
response: new HttpResponse({
statusCode: 200,
body: Readable.from([""]),
}),
});
jest.mock("@aws-sdk/node-http-handler", () => ({
NodeHttpHandler: jest.fn().mockImplementation(() => ({
destroy: () => {},
handle: mockHandle,
})),
streamCollector: jest.fn(),
}));

import { getDefaultRoleAssumer, getDefaultRoleAssumerWithWebIdentity } from "./defaultRoleAssumers";
import type { AssumeRoleCommandInput } from "./commands/AssumeRoleCommand";
import { NodeHttpHandler, streamCollector } from "@aws-sdk/node-http-handler";
import { AssumeRoleWithWebIdentityCommandInput } from "./commands/AssumeRoleWithWebIdentityCommand";
const mockConstructorInput = jest.fn();
jest.mock("./STSClient", () => ({
STSClient: function (params: any) {
mockConstructorInput(params);
//@ts-ignore
return new (jest.requireActual("./STSClient").STSClient)(params);
},
}));

describe("getDefaultRoleAssumer", () => {
const assumeRoleResponse = `<AssumeRoleResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
<AssumeRoleResult>
<AssumedRoleUser>
<AssumedRoleId>AROAZOX2IL27GNRBJHWC2:session</AssumedRoleId>
Expand All @@ -17,27 +46,15 @@ const assumeRoleResponse = `<AssumeRoleResponse xmlns="https://sts.amazonaws.com
<RequestId>12345678id</RequestId>
</ResponseMetadata>
</AssumeRoleResponse>`;
const mockHandle = jest.fn().mockResolvedValue({
response: new HttpResponse({
statusCode: 200,
body: Readable.from([""]),
}),
});
jest.mock("@aws-sdk/node-http-handler", () => ({
NodeHttpHandler: jest.fn().mockImplementation(() => ({
destroy: () => {},
handle: mockHandle,
})),
streamCollector: async () => Buffer.from(assumeRoleResponse),
}));

import { getDefaultRoleAssumer } from "./defaultRoleAssumers";
import type { AssumeRoleCommandInput } from "./commands/AssumeRoleCommand";
beforeAll(() => {
(streamCollector as jest.Mock).mockImplementation(async () => Buffer.from(assumeRoleResponse));
});

describe("getDefaultRoleAssumer", () => {
beforeEach(() => {
jest.clearAllMocks();
});

it("should use supplied source credentials", async () => {
const roleAssumer = getDefaultRoleAssumer();
const params: AssumeRoleCommandInput = {
Expand All @@ -59,4 +76,71 @@ describe("getDefaultRoleAssumer", () => {
expect.stringContaining("AWS4-HMAC-SHA256 Credential=key2/")
);
});

it("should use the STS client config", async () => {
const logger = console;
const region = "some-region";
const handler = new NodeHttpHandler();
const roleAssumer = getDefaultRoleAssumer({
region,
logger,
requestHandler: handler,
});
const params: AssumeRoleCommandInput = {
RoleArn: "arn:aws:foo",
RoleSessionName: "session",
};
const sourceCred = { accessKeyId: "key", secretAccessKey: "secrete" };
await roleAssumer(sourceCred, params);
expect(mockConstructorInput).toHaveBeenCalledTimes(1);
expect(mockConstructorInput.mock.calls[0][0]).toMatchObject({
logger,
requestHandler: handler,
region,
});
});
});

describe("getDefaultRoleAssumerWithWebIdentity", () => {
const assumeRoleResponse = `<Response xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
<AssumeRoleWithWebIdentityResult>
<Credentials>
<AccessKeyId>key</AccessKeyId>
<SecretAccessKey>secrete</SecretAccessKey>
<SessionToken>session-token</SessionToken>
<Expiration>2021-05-05T23:22:08Z</Expiration>
</Credentials>
</AssumeRoleWithWebIdentityResult>
</Response>`;

beforeAll(() => {
(streamCollector as jest.Mock).mockImplementation(async () => Buffer.from(assumeRoleResponse));
});

beforeEach(() => {
jest.clearAllMocks();
});

it("should use the STS client config", async () => {
const logger = console;
const region = "some-region";
const handler = new NodeHttpHandler();
const roleAssumerWithWebIdentity = getDefaultRoleAssumerWithWebIdentity({
region,
logger,
requestHandler: handler,
});
const params: AssumeRoleWithWebIdentityCommandInput = {
RoleArn: "arn:aws:foo",
RoleSessionName: "session",
WebIdentityToken: "token",
};
await roleAssumerWithWebIdentity(params);
expect(mockConstructorInput).toHaveBeenCalledTimes(1);
expect(mockConstructorInput.mock.calls[0][0]).toMatchObject({
logger,
requestHandler: handler,
region,
});
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@ import { STSClient, STSClientConfig } from "./STSClient";
/**
* The default role assumer that used by credential providers when sts:AssumeRole API is needed.
*/
export const getDefaultRoleAssumer = (stsOptions: Pick<STSClientConfig, "logger" | "region"> = {}): RoleAssumer =>
StsGetDefaultRoleAssumer(stsOptions, STSClient);
export const getDefaultRoleAssumer = (
stsOptions: Pick<STSClientConfig, "logger" | "region" | "requestHandler"> = {}
): RoleAssumer => StsGetDefaultRoleAssumer(stsOptions, STSClient);

/**
* The default role assumer that used by credential providers when sts:AssumeRoleWithWebIdentity API is needed.
*/
export const getDefaultRoleAssumerWithWebIdentity = (
stsOptions: Pick<STSClientConfig, "logger" | "region"> = {}
stsOptions: Pick<STSClientConfig, "logger" | "region" | "requestHandler"> = {}
): RoleAssumerWithWebIdentity => StsGetDefaultRoleAssumerWithWebIdentity(stsOptions, STSClient);

/**
Expand Down
Loading