Skip to content

feat(credential-provider-imds): support static stability #3402

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import { httpRequest } from "./remoteProvider/httpRequest";
import { fromImdsCredentials, ImdsCredentials } from "./remoteProvider/ImdsCredentials";

const mockHttpRequest = <any>httpRequest;
jest.mock("./remoteProvider/httpRequest", () => ({ httpRequest: jest.fn() }));
jest.mock("./remoteProvider/httpRequest");

const relativeUri = process.env[ENV_CMDS_RELATIVE_URI];
const fullUri = process.env[ENV_CMDS_FULL_URI];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ import { fromImdsCredentials, isImdsCredentials } from "./remoteProvider/ImdsCre
import { providerConfigFromInit } from "./remoteProvider/RemoteProviderInit";
import { retry } from "./remoteProvider/retry";
import { getInstanceMetadataEndpoint } from "./utils/getInstanceMetadataEndpoint";
import { staticStabilityProvider } from "./utils/staticStabilityProvider";

jest.mock("./remoteProvider/httpRequest");
jest.mock("./remoteProvider/ImdsCredentials");
jest.mock("./remoteProvider/retry");
jest.mock("./remoteProvider/RemoteProviderInit");
jest.mock("./utils/getInstanceMetadataEndpoint");
jest.mock("./utils/staticStabilityProvider");

describe("fromInstanceMetadata", () => {
const hostname = "127.0.0.1";
Expand Down Expand Up @@ -39,11 +41,12 @@ describe("fromInstanceMetadata", () => {
},
};

const ONE_HOUR_IN_FUTURE = new Date(Date.now() + 60 * 60 * 1000);
const mockImdsCreds = Object.freeze({
AccessKeyId: "foo",
SecretAccessKey: "bar",
Token: "baz",
Expiration: new Date().toISOString(),
Expiration: ONE_HOUR_IN_FUTURE.toISOString(),
});

const mockCreds = Object.freeze({
Expand All @@ -54,6 +57,7 @@ describe("fromInstanceMetadata", () => {
});

beforeEach(() => {
(staticStabilityProvider as jest.Mock).mockImplementation((input) => input);
(getInstanceMetadataEndpoint as jest.Mock).mockResolvedValue({ hostname });
(isImdsCredentials as unknown as jest.Mock).mockReturnValue(true);
(providerConfigFromInit as jest.Mock).mockReturnValue({
Expand Down Expand Up @@ -192,6 +196,19 @@ describe("fromInstanceMetadata", () => {
await expect(fromInstanceMetadata()()).rejects.toEqual(tokenError);
});

it("should call staticStabilityProvider with the credential loader", async () => {
(httpRequest as jest.Mock)
.mockResolvedValueOnce(mockToken)
.mockResolvedValueOnce(mockProfile)
.mockResolvedValueOnce(JSON.stringify(mockImdsCreds));

(retry as jest.Mock).mockImplementation((fn: any) => fn());
(fromImdsCredentials as jest.Mock).mockReturnValue(mockCreds);

await fromInstanceMetadata()();
expect(staticStabilityProvider as jest.Mock).toBeCalledTimes(1);
});

describe("disables fetching of token", () => {
beforeEach(() => {
(retry as jest.Mock).mockImplementation((fn: any) => fn());
Expand Down Expand Up @@ -268,47 +285,4 @@ describe("fromInstanceMetadata", () => {
await expect(fromInstanceMetadataFunc()).resolves.toEqual(mockCreds);
await expect(fromInstanceMetadataFunc()).resolves.toEqual(mockCreds);
});

describe("re-enables fetching of token", () => {
const error401 = Object.assign(new Error("error"), { statusCode: 401 });

beforeEach(() => {
const tokenError = new Error("TimeoutError");

(httpRequest as jest.Mock)
.mockRejectedValueOnce(tokenError)
.mockResolvedValueOnce(mockProfile)
.mockResolvedValueOnce(JSON.stringify(mockImdsCreds));

(retry as jest.Mock).mockImplementation((fn: any) => fn());
(fromImdsCredentials as jest.Mock).mockReturnValue(mockCreds);
});

it("when profile error with 401", async () => {
(httpRequest as jest.Mock)
.mockRejectedValueOnce(error401)
.mockResolvedValueOnce(mockToken)
.mockResolvedValueOnce(mockProfile)
.mockResolvedValueOnce(JSON.stringify(mockImdsCreds));

const fromInstanceMetadataFunc = fromInstanceMetadata();
await expect(fromInstanceMetadataFunc()).resolves.toEqual(mockCreds);
await expect(fromInstanceMetadataFunc()).rejects.toEqual(error401);
await expect(fromInstanceMetadataFunc()).resolves.toEqual(mockCreds);
});

it("when creds error with 401", async () => {
(httpRequest as jest.Mock)
.mockResolvedValueOnce(mockProfile)
.mockRejectedValueOnce(error401)
.mockResolvedValueOnce(mockToken)
.mockResolvedValueOnce(mockProfile)
.mockResolvedValueOnce(JSON.stringify(mockImdsCreds));

const fromInstanceMetadataFunc = fromInstanceMetadata();
await expect(fromInstanceMetadataFunc()).resolves.toEqual(mockCreds);
await expect(fromInstanceMetadataFunc()).rejects.toEqual(error401);
await expect(fromInstanceMetadataFunc()).resolves.toEqual(mockCreds);
});
});
});
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import { CredentialsProviderError } from "@aws-sdk/property-provider";
import { CredentialProvider, Credentials } from "@aws-sdk/types";
import { Credentials, Provider } from "@aws-sdk/types";
import { RequestOptions } from "http";

import { httpRequest } from "./remoteProvider/httpRequest";
import { fromImdsCredentials, isImdsCredentials } from "./remoteProvider/ImdsCredentials";
import { providerConfigFromInit, RemoteProviderInit } from "./remoteProvider/RemoteProviderInit";
import { retry } from "./remoteProvider/retry";
import { InstanceMetadataCredentials } from "./types";
import { getInstanceMetadataEndpoint } from "./utils/getInstanceMetadataEndpoint";
import { staticStabilityProvider } from "./utils/staticStabilityProvider";

const IMDS_PATH = "/latest/meta-data/iam/security-credentials/";
const IMDS_TOKEN_PATH = "/latest/api/token";
Expand All @@ -15,7 +17,10 @@ const IMDS_TOKEN_PATH = "/latest/api/token";
* Creates a credential provider that will source credentials from the EC2
* Instance Metadata Service
*/
export const fromInstanceMetadata = (init: RemoteProviderInit = {}): CredentialProvider => {
export const fromInstanceMetadata = (init: RemoteProviderInit = {}): Provider<InstanceMetadataCredentials> =>
staticStabilityProvider(getInstanceImdsProvider(init));

const getInstanceImdsProvider = (init: RemoteProviderInit) => {
// when set to true, metadata service will not fetch token
let disableFetchToken = false;
const { timeout, maxRetries } = providerConfigFromInit(init);
Expand Down
1 change: 1 addition & 0 deletions packages/credential-provider-imds/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
export * from "./fromContainerMetadata";
export * from "./fromInstanceMetadata";
export * from "./remoteProvider/RemoteProviderInit";
export * from "./types";
export { httpRequest } from "./remoteProvider/httpRequest";
export { getInstanceMetadataEndpoint } from "./utils/getInstanceMetadataEndpoint";
5 changes: 5 additions & 0 deletions packages/credential-provider-imds/src/types.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import { Credentials } from "@aws-sdk/types";

export interface InstanceMetadataCredentials extends Credentials {
readonly originalExpiration?: Date;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import { getExtendedInstanceMetadataCredentials } from "./getExtendedInstanceMetadataCredentials";

describe("getExtendedInstanceMetadataCredentials()", () => {
let nowMock: jest.SpyInstance;
const staticSecret = {
accessKeyId: "key",
secretAccessKey: "secret",
};

beforeEach(() => {
jest.spyOn(global.console, "warn").mockImplementation(() => {});
jest.spyOn(global.Math, "random");
nowMock = jest.spyOn(Date, "now").mockReturnValueOnce(new Date("2022-02-22T00:00:00Z").getTime());
});

afterEach(() => {
nowMock.mockRestore();
});

it("should extend the expiration random time(~15 mins) from now", () => {
const anyDate: Date = "any date" as unknown as Date;
(Math.random as jest.Mock).mockReturnValue(0.5);
expect(getExtendedInstanceMetadataCredentials({ ...staticSecret, expiration: anyDate })).toEqual({
...staticSecret,
originalExpiration: anyDate,
expiration: new Date("2022-02-22T00:17:30Z"),
});
expect(Math.random).toBeCalledTimes(1);
});

it("should print warning message when extending the credentials", () => {
const anyDate: Date = "any date" as unknown as Date;
getExtendedInstanceMetadataCredentials({ ...staticSecret, expiration: anyDate });
// TODO: fill the doc link
expect(console.warn).toBeCalledWith(expect.stringContaining("Attempting credential expiration extension"));
});
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import { InstanceMetadataCredentials } from "../types";

const STATIC_STABILITY_REFRESH_INTERVAL_SECONDS = 15 * 60;
const STATIC_STABILITY_REFRESH_INTERVAL_JITTER_WINDOW_SECONDS = 5 * 60;
// TODO
const STATIC_STABILITY_DOC_URL = "https://docs.aws.amazon.com/sdkref/latest/guide/feature-static-credentials.html";

export const getExtendedInstanceMetadataCredentials = (
credentials: InstanceMetadataCredentials
): InstanceMetadataCredentials => {
const refreshInterval =
STATIC_STABILITY_REFRESH_INTERVAL_SECONDS +
Math.floor(Math.random() * STATIC_STABILITY_REFRESH_INTERVAL_JITTER_WINDOW_SECONDS);
const newExpiration = new Date(Date.now() + refreshInterval * 1000);
// ToDo: Call warn function on logger from configuration
console.warn(
"Attempting credential expiration extension due to a credential service availability issue. A refresh of these " +
"credentials will be attempted after ${new Date(newExpiration)}.\nFor more information, please visit: " +
STATIC_STABILITY_DOC_URL
);
const originalExpiration = credentials.originalExpiration ?? credentials.expiration;
return {
...credentials,
...(originalExpiration ? { originalExpiration } : {}),
expiration: newExpiration,
};
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import { getExtendedInstanceMetadataCredentials } from "./getExtendedInstanceMetadataCredentials";
import { staticStabilityProvider } from "./staticStabilityProvider";

jest.mock("./getExtendedInstanceMetadataCredentials");

describe("staticStabilityProvider", () => {
const ONE_HOUR_IN_FUTURE = new Date(Date.now() + 60 * 60 * 1000);
const mockCreds = {
accessKeyId: "key",
secretAccessKey: "secret",
sessionToken: "settion",
expiration: ONE_HOUR_IN_FUTURE,
};

beforeEach(() => {
(getExtendedInstanceMetadataCredentials as jest.Mock).mockImplementation(
(() => {
let extensionCount = 0;
return (input) => {
extensionCount++;
return {
...input,
expiration: `Extending expiration count: ${extensionCount}`,
};
};
})()
);
jest.spyOn(global.console, "warn").mockImplementation(() => {});
});

afterEach(() => {
jest.resetAllMocks();
});

it("should refresh credentials if provider is functional", async () => {
const provider = jest.fn();
const stableProvider = staticStabilityProvider(provider);
const repeat = 3;
for (let i = 0; i < repeat; i++) {
const newCreds = { ...mockCreds, accessKeyId: String(i + 1) };
provider.mockReset().mockResolvedValue(newCreds);
expect(await stableProvider()).toEqual(newCreds);
}
});

it("should throw if cannot load credentials at 1st load", async () => {
const provider = jest.fn().mockRejectedValue("Error");
try {
await staticStabilityProvider(provider)();
fail("This provider should throw");
} catch (e) {
expect(getExtendedInstanceMetadataCredentials).not.toBeCalled();
expect(provider).toBeCalledTimes(1);
expect(e).toEqual("Error");
}
});

it("should extend expired credentials if refresh fails", async () => {
const provider = jest.fn().mockResolvedValueOnce(mockCreds).mockRejectedValue("Error");
const stableProvider = staticStabilityProvider(provider);
expect(await stableProvider()).toEqual(mockCreds);
const repeat = 3;
for (let i = 0; i < repeat; i++) {
const newCreds = await stableProvider();
expect(newCreds).toMatchObject({ ...mockCreds, expiration: expect.stringContaining(`count: ${i + 1}`) });
expect(console.warn).toHaveBeenLastCalledWith(
expect.stringContaining("Credential renew failed:"),
expect.anything()
);
}
expect(getExtendedInstanceMetadataCredentials).toBeCalledTimes(repeat);
expect(console.warn).toBeCalledTimes(repeat);
});

it("should extend expired credentials if loaded expired credentials", async () => {
const ONE_HOUR_AGO = new Date(Date.now() - 60 * 60 * 1000);
const provider = jest.fn().mockResolvedValue({ ...mockCreds, expiration: ONE_HOUR_AGO });
const stableProvider = staticStabilityProvider(provider);
const repeat = 3;
for (let i = 0; i < repeat; i++) {
const newCreds = await stableProvider();
expect(newCreds).toMatchObject({ ...mockCreds, expiration: expect.stringContaining(`count: ${i + 1}`) });
}
expect(getExtendedInstanceMetadataCredentials).toBeCalledTimes(repeat);
expect(console.warn).not.toBeCalled();
});
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import { Credentials, Provider } from "@aws-sdk/types";

import { InstanceMetadataCredentials } from "../types";
import { getExtendedInstanceMetadataCredentials } from "./getExtendedInstanceMetadataCredentials";

/**
* IMDS credential supports static stability feature. When used, the expiration
* of recently issued credentials is extended. The server side allows using
* the recently expired credentials. This mitigates impact when clients using
* refreshable credentials are unable to retrieve updates.
*
* @param provider Credential provider
* @returns A credential provider that supports static stability
*/
export const staticStabilityProvider = (
provider: Provider<InstanceMetadataCredentials>
): Provider<InstanceMetadataCredentials> => {
let pastCredentials: InstanceMetadataCredentials;
return async () => {
let credentials: InstanceMetadataCredentials;
try {
credentials = await provider();
if (credentials.expiration && credentials.expiration.getTime() < Date.now()) {
credentials = getExtendedInstanceMetadataCredentials(credentials);
}
} catch (e) {
if (pastCredentials) {
// ToDo: Call warn function on logger from configuration
console.warn("Credential renew failed: ", e);
credentials = getExtendedInstanceMetadataCredentials(pastCredentials);
} else {
throw e;
}
}
pastCredentials = credentials;
return credentials;
};
};