Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(NODE-5940): cache the AWS credentials provider in the MONGODB-AWS auth logic #4000

Merged
merged 2 commits into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/cmap/auth/auth_provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ export class AuthContext {
}
}

/**
* Provider used during authentication.
* @internal
*/
export abstract class AuthProvider {
/**
* Prepare the handshake document before the initial handshake.
Expand Down
123 changes: 62 additions & 61 deletions src/cmap/auth/mongodb_aws.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
import * as crypto from 'crypto';
import * as process from 'process';
import { promisify } from 'util';

import type { Binary, BSONSerializeOptions } from '../../bson';
import * as BSON from '../../bson';
import { aws4, getAwsCredentialProvider } from '../../deps';
import { aws4, type AWSCredentials, getAwsCredentialProvider } from '../../deps';
import {
MongoAWSError,
MongoCompatibilityError,
MongoMissingCredentialsError,
MongoRuntimeError
} from '../../error';
import { ByteUtils, maxWireVersion, ns, request } from '../../utils';
import { ByteUtils, maxWireVersion, ns, randomBytes, request } from '../../utils';
import { type AuthContext, AuthProvider } from './auth_provider';
import { MongoCredentials } from './mongo_credentials';
import { AuthMechanism } from './providers';
Expand Down Expand Up @@ -57,12 +55,40 @@ interface AWSSaslContinuePayload {
}

export class MongoDBAWS extends AuthProvider {
static credentialProvider: ReturnType<typeof getAwsCredentialProvider> | null = null;
randomBytesAsync: (size: number) => Promise<Buffer>;
static credentialProvider: ReturnType<typeof getAwsCredentialProvider>;
provider?: () => Promise<AWSCredentials>;

constructor() {
super();
this.randomBytesAsync = promisify(crypto.randomBytes);
MongoDBAWS.credentialProvider ??= getAwsCredentialProvider();

let { AWS_STS_REGIONAL_ENDPOINTS = '', AWS_REGION = '' } = process.env;
AWS_STS_REGIONAL_ENDPOINTS = AWS_STS_REGIONAL_ENDPOINTS.toLowerCase();
AWS_REGION = AWS_REGION.toLowerCase();

/** The option setting should work only for users who have explicit settings in their environment, the driver should not encode "defaults" */
const awsRegionSettingsExist =
AWS_REGION.length !== 0 && AWS_STS_REGIONAL_ENDPOINTS.length !== 0;

/**
* If AWS_STS_REGIONAL_ENDPOINTS is set to regional, users are opting into the new behavior of respecting the region settings
*
* If AWS_STS_REGIONAL_ENDPOINTS is set to legacy, then "old" regions need to keep using the global setting.
* Technically the SDK gets this wrong, it reaches out to 'sts.us-east-1.amazonaws.com' when it should be 'sts.amazonaws.com'.
* That is not our bug to fix here. We leave that up to the SDK.
*/
const useRegionalSts =
AWS_STS_REGIONAL_ENDPOINTS === 'regional' ||
(AWS_STS_REGIONAL_ENDPOINTS === 'legacy' && !LEGACY_REGIONS.has(AWS_REGION));

if ('fromNodeProviderChain' in MongoDBAWS.credentialProvider) {
this.provider =
awsRegionSettingsExist && useRegionalSts
? MongoDBAWS.credentialProvider.fromNodeProviderChain({
clientConfig: { region: AWS_REGION }
})
: MongoDBAWS.credentialProvider.fromNodeProviderChain();
}
}

override async auth(authContext: AuthContext): Promise<void> {
Expand All @@ -83,7 +109,7 @@ export class MongoDBAWS extends AuthProvider {
}

if (!authContext.credentials.username) {
authContext.credentials = await makeTempCredentials(authContext.credentials);
authContext.credentials = await makeTempCredentials(authContext.credentials, this.provider);
}

const { credentials } = authContext;
Expand All @@ -101,7 +127,7 @@ export class MongoDBAWS extends AuthProvider {
: undefined;

const db = credentials.source;
const nonce = await this.randomBytesAsync(32);
const nonce = await randomBytes(32);

const saslStart = {
saslStart: 1,
Expand Down Expand Up @@ -181,7 +207,10 @@ interface AWSTempCredentials {
Expiration?: Date;
}

async function makeTempCredentials(credentials: MongoCredentials): Promise<MongoCredentials> {
async function makeTempCredentials(
credentials: MongoCredentials,
provider?: () => Promise<AWSCredentials>
): Promise<MongoCredentials> {
function makeMongoCredentialsFromAWSTemp(creds: AWSTempCredentials) {
if (!creds.AccessKeyId || !creds.SecretAccessKey || !creds.Token) {
throw new MongoMissingCredentialsError('Could not obtain temporary MONGODB-AWS credentials');
Expand All @@ -198,11 +227,31 @@ async function makeTempCredentials(credentials: MongoCredentials): Promise<Mongo
});
}

MongoDBAWS.credentialProvider ??= getAwsCredentialProvider();

// Check if the AWS credential provider from the SDK is present. If not,
// use the old method.
if ('kModuleError' in MongoDBAWS.credentialProvider) {
if (provider && !('kModuleError' in MongoDBAWS.credentialProvider)) {
/*
* Creates a credential provider that will attempt to find credentials from the
* following sources (listed in order of precedence):
*
* - Environment variables exposed via process.env
* - SSO credentials from token cache
* - Web identity token credentials
* - Shared credentials and config ini files
* - The EC2/ECS Instance Metadata Service
*/
try {
const creds = await provider();
return makeMongoCredentialsFromAWSTemp({
AccessKeyId: creds.accessKeyId,
SecretAccessKey: creds.secretAccessKey,
Token: creds.sessionToken,
Expiration: creds.expiration
});
} catch (error) {
throw new MongoAWSError(error.message);
}
} else {
// If the environment variable AWS_CONTAINER_CREDENTIALS_RELATIVE_URI
// is set then drivers MUST assume that it was set by an AWS ECS agent
if (process.env.AWS_CONTAINER_CREDENTIALS_RELATIVE_URI) {
Expand Down Expand Up @@ -232,54 +281,6 @@ async function makeTempCredentials(credentials: MongoCredentials): Promise<Mongo
});

return makeMongoCredentialsFromAWSTemp(creds);
} else {
let { AWS_STS_REGIONAL_ENDPOINTS = '', AWS_REGION = '' } = process.env;
AWS_STS_REGIONAL_ENDPOINTS = AWS_STS_REGIONAL_ENDPOINTS.toLowerCase();
AWS_REGION = AWS_REGION.toLowerCase();

/** The option setting should work only for users who have explicit settings in their environment, the driver should not encode "defaults" */
const awsRegionSettingsExist =
AWS_REGION.length !== 0 && AWS_STS_REGIONAL_ENDPOINTS.length !== 0;

/**
* If AWS_STS_REGIONAL_ENDPOINTS is set to regional, users are opting into the new behavior of respecting the region settings
*
* If AWS_STS_REGIONAL_ENDPOINTS is set to legacy, then "old" regions need to keep using the global setting.
* Technically the SDK gets this wrong, it reaches out to 'sts.us-east-1.amazonaws.com' when it should be 'sts.amazonaws.com'.
* That is not our bug to fix here. We leave that up to the SDK.
*/
const useRegionalSts =
AWS_STS_REGIONAL_ENDPOINTS === 'regional' ||
(AWS_STS_REGIONAL_ENDPOINTS === 'legacy' && !LEGACY_REGIONS.has(AWS_REGION));

const provider =
awsRegionSettingsExist && useRegionalSts
? MongoDBAWS.credentialProvider.fromNodeProviderChain({
clientConfig: { region: AWS_REGION }
})
: MongoDBAWS.credentialProvider.fromNodeProviderChain();

/*
* Creates a credential provider that will attempt to find credentials from the
* following sources (listed in order of precedence):
*
* - Environment variables exposed via process.env
* - SSO credentials from token cache
* - Web identity token credentials
* - Shared credentials and config ini files
* - The EC2/ECS Instance Metadata Service
*/
try {
const creds = await provider();
return makeMongoCredentialsFromAWSTemp({
AccessKeyId: creds.accessKeyId,
SecretAccessKey: creds.secretAccessKey,
Token: creds.sessionToken,
Expiration: creds.expiration
});
} catch (error) {
throw new MongoAWSError(error.message);
}
}
}

Expand Down
7 changes: 2 additions & 5 deletions src/cmap/auth/scram.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import * as crypto from 'crypto';
import { promisify } from 'util';

import { Binary, type Document } from '../../bson';
import { saslprep } from '../../deps';
Expand All @@ -8,7 +7,7 @@ import {
MongoMissingCredentialsError,
MongoRuntimeError
} from '../../error';
import { emitWarning, ns } from '../../utils';
import { emitWarning, ns, randomBytes } from '../../utils';
import type { HandshakeDocument } from '../connect';
import { type AuthContext, AuthProvider } from './auth_provider';
import type { MongoCredentials } from './mongo_credentials';
Expand All @@ -18,11 +17,9 @@ type CryptoMethod = 'sha1' | 'sha256';

class ScramSHA extends AuthProvider {
cryptoMethod: CryptoMethod;
randomBytesAsync: (size: number) => Promise<Buffer>;
constructor(cryptoMethod: CryptoMethod) {
super();
this.cryptoMethod = cryptoMethod || 'sha1';
this.randomBytesAsync = promisify(crypto.randomBytes);
}

override async prepare(
Expand All @@ -41,7 +38,7 @@ class ScramSHA extends AuthProvider {
emitWarning('Warning: no saslprep library specified. Passwords will not be sanitized');
}

const nonce = await this.randomBytesAsync(24);
const nonce = await randomBytes(24);
// store the nonce for later use
authContext.nonce = nonce;

Expand Down
33 changes: 11 additions & 22 deletions src/cmap/connect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,8 @@ import {
needsRetryableWriteLabel
} from '../error';
import { type Callback, HostAddress, ns } from '../utils';
import { AuthContext, type AuthProvider } from './auth/auth_provider';
import { GSSAPI } from './auth/gssapi';
import { MongoCR } from './auth/mongocr';
import { MongoDBAWS } from './auth/mongodb_aws';
import { Plain } from './auth/plain';
import { AuthContext } from './auth/auth_provider';
import { AuthMechanism } from './auth/providers';
import { ScramSHA1, ScramSHA256 } from './auth/scram';
import { X509 } from './auth/x509';
import {
type CommandOptions,
Connection,
Expand All @@ -39,17 +33,6 @@ import {
MIN_SUPPORTED_WIRE_VERSION
} from './wire_protocol/constants';

/** @internal */
export const AUTH_PROVIDERS = new Map<AuthMechanism | string, AuthProvider>([
[AuthMechanism.MONGODB_AWS, new MongoDBAWS()],
[AuthMechanism.MONGODB_CR, new MongoCR()],
[AuthMechanism.MONGODB_GSSAPI, new GSSAPI()],
[AuthMechanism.MONGODB_PLAIN, new Plain()],
[AuthMechanism.MONGODB_SCRAM_SHA1, new ScramSHA1()],
[AuthMechanism.MONGODB_SCRAM_SHA256, new ScramSHA256()],
[AuthMechanism.MONGODB_X509, new X509()]
]);

/** @public */
export type Stream = Socket | TLSSocket;

Expand Down Expand Up @@ -110,7 +93,7 @@ async function performInitialHandshake(
if (credentials) {
if (
!(credentials.mechanism === AuthMechanism.MONGODB_DEFAULT) &&
!AUTH_PROVIDERS.get(credentials.mechanism)
!options.authProviders.getOrCreateProvider(credentials.mechanism)
) {
throw new MongoInvalidArgumentError(`AuthMechanism '${credentials.mechanism}' not supported`);
}
Expand Down Expand Up @@ -165,7 +148,7 @@ async function performInitialHandshake(
authContext.response = response;

const resolvedCredentials = credentials.resolveAuthMechanism(response);
const provider = AUTH_PROVIDERS.get(resolvedCredentials.mechanism);
const provider = options.authProviders.getOrCreateProvider(resolvedCredentials.mechanism);
if (!provider) {
throw new MongoInvalidArgumentError(
`No AuthProvider for ${resolvedCredentials.mechanism} defined.`
Expand All @@ -186,6 +169,10 @@ async function performInitialHandshake(
}
}

/**
* HandshakeDocument used during authentication.
* @internal
*/
export interface HandshakeDocument extends Document {
/**
* @deprecated Use hello instead
Expand Down Expand Up @@ -227,7 +214,9 @@ export async function prepareHandshakeDocument(
if (credentials.mechanism === AuthMechanism.MONGODB_DEFAULT && credentials.username) {
handshakeDoc.saslSupportedMechs = `${credentials.source}.${credentials.username}`;

const provider = AUTH_PROVIDERS.get(AuthMechanism.MONGODB_SCRAM_SHA256);
const provider = authContext.options.authProviders.getOrCreateProvider(
AuthMechanism.MONGODB_SCRAM_SHA256
);
if (!provider) {
// This auth mechanism is always present.
throw new MongoInvalidArgumentError(
Expand All @@ -236,7 +225,7 @@ export async function prepareHandshakeDocument(
}
return provider.prepare(handshakeDoc, authContext);
}
const provider = AUTH_PROVIDERS.get(credentials.mechanism);
const provider = authContext.options.authProviders.getOrCreateProvider(credentials.mechanism);
if (!provider) {
throw new MongoInvalidArgumentError(`No AuthProvider for ${credentials.mechanism} defined.`);
}
Expand Down
3 changes: 3 additions & 0 deletions src/cmap/connection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import {
MongoWriteConcernError
} from '../error';
import type { ServerApi, SupportedNodeConnectionOptions } from '../mongo_client';
import { type MongoClientAuthProviders } from '../mongo_client_auth_providers';
import { type CancellationToken, TypedEventEmitter } from '../mongo_types';
import type { ReadPreferenceLike } from '../read_preference';
import { applySession, type ClientSession, updateSessionFromResponse } from '../sessions';
Expand Down Expand Up @@ -120,6 +121,8 @@ export interface ConnectionOptions
/** @internal */
connectionType?: typeof Connection;
credentials?: MongoCredentials;
/** @internal */
authProviders: MongoClientAuthProviders;
connectTimeoutMS?: number;
tls: boolean;
/** @deprecated - Will not be able to turn off in the future. */
Expand Down
9 changes: 6 additions & 3 deletions src/cmap/connection_pool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import {
import { CancellationToken, TypedEventEmitter } from '../mongo_types';
import type { Server } from '../sdam/server';
import { type Callback, eachAsync, List, makeCounter } from '../utils';
import { AUTH_PROVIDERS, connect } from './connect';
import { connect } from './connect';
import { Connection, type ConnectionEvents, type ConnectionOptions } from './connection';
import {
ConnectionCheckedInEvent,
Expand Down Expand Up @@ -620,7 +620,9 @@ export class ConnectionPool extends TypedEventEmitter<ConnectionPoolEvents> {
);
}
const resolvedCredentials = credentials.resolveAuthMechanism(connection.hello || undefined);
const provider = AUTH_PROVIDERS.get(resolvedCredentials.mechanism);
const provider = this[kServer].topology.client.s.authProviders.getOrCreateProvider(
resolvedCredentials.mechanism
);
if (!provider) {
return callback(
new MongoMissingCredentialsError(
Expand Down Expand Up @@ -697,7 +699,8 @@ export class ConnectionPool extends TypedEventEmitter<ConnectionPoolEvents> {
...this.options,
id: this[kConnectionCounter].next().value,
generation: this[kGeneration],
cancellationToken: this[kCancellationToken]
cancellationToken: this[kCancellationToken],
authProviders: this[kServer].topology.client.s.authProviders
};

this[kPending]++;
Expand Down
4 changes: 3 additions & 1 deletion src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ export type {
ResumeToken,
UpdateDescription
} from './change_stream';
export type { AuthContext } from './cmap/auth/auth_provider';
export type { AuthContext, AuthProvider } from './cmap/auth/auth_provider';
export type {
AuthMechanismProperties,
MongoCredentials,
Expand All @@ -217,6 +217,7 @@ export type {
Response,
WriteProtocolMessageType
} from './cmap/commands';
export type { HandshakeDocument } from './cmap/connect';
export type { LEGAL_TCP_SOCKET_OPTIONS, LEGAL_TLS_SOCKET_OPTIONS, Stream } from './cmap/connect';
export type {
CommandOptions,
Expand Down Expand Up @@ -304,6 +305,7 @@ export type {
SupportedTLSSocketOptions,
WithSessionCallback
} from './mongo_client';
export { MongoClientAuthProviders } from './mongo_client_auth_providers';
export type {
Log,
LogConvertible,
Expand Down
Loading