Skip to content

Commit e0a37e5

Browse files
feat(NODE-5939): Implement 6.x: cache the AWS credentials provider in the MONGODB-AWS auth logic (#3991)
Co-authored-by: Durran Jordan <[email protected]>
1 parent 38742c2 commit e0a37e5

15 files changed

+229
-102
lines changed

src/cmap/auth/auth_provider.ts

+4
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ export class AuthContext {
3434
}
3535
}
3636

37+
/**
38+
* Provider used during authentication.
39+
* @internal
40+
*/
3741
export abstract class AuthProvider {
3842
/**
3943
* Prepare the handshake document before the initial handshake.

src/cmap/auth/mongodb_aws.ts

+60-55
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import { promisify } from 'util';
44

55
import type { Binary, BSONSerializeOptions } from '../../bson';
66
import * as BSON from '../../bson';
7-
import { aws4, getAwsCredentialProvider } from '../../deps';
7+
import { aws4, type AWSCredentials, getAwsCredentialProvider } from '../../deps';
88
import {
99
MongoAWSError,
1010
MongoCompatibilityError,
@@ -57,12 +57,42 @@ interface AWSSaslContinuePayload {
5757
}
5858

5959
export class MongoDBAWS extends AuthProvider {
60-
static credentialProvider: ReturnType<typeof getAwsCredentialProvider> | null = null;
60+
static credentialProvider: ReturnType<typeof getAwsCredentialProvider>;
61+
provider?: () => Promise<AWSCredentials>;
6162
randomBytesAsync: (size: number) => Promise<Buffer>;
6263

6364
constructor() {
6465
super();
6566
this.randomBytesAsync = promisify(crypto.randomBytes);
67+
MongoDBAWS.credentialProvider ??= getAwsCredentialProvider();
68+
69+
let { AWS_STS_REGIONAL_ENDPOINTS = '', AWS_REGION = '' } = process.env;
70+
AWS_STS_REGIONAL_ENDPOINTS = AWS_STS_REGIONAL_ENDPOINTS.toLowerCase();
71+
AWS_REGION = AWS_REGION.toLowerCase();
72+
73+
/** The option setting should work only for users who have explicit settings in their environment, the driver should not encode "defaults" */
74+
const awsRegionSettingsExist =
75+
AWS_REGION.length !== 0 && AWS_STS_REGIONAL_ENDPOINTS.length !== 0;
76+
77+
/**
78+
* If AWS_STS_REGIONAL_ENDPOINTS is set to regional, users are opting into the new behavior of respecting the region settings
79+
*
80+
* If AWS_STS_REGIONAL_ENDPOINTS is set to legacy, then "old" regions need to keep using the global setting.
81+
* Technically the SDK gets this wrong, it reaches out to 'sts.us-east-1.amazonaws.com' when it should be 'sts.amazonaws.com'.
82+
* That is not our bug to fix here. We leave that up to the SDK.
83+
*/
84+
const useRegionalSts =
85+
AWS_STS_REGIONAL_ENDPOINTS === 'regional' ||
86+
(AWS_STS_REGIONAL_ENDPOINTS === 'legacy' && !LEGACY_REGIONS.has(AWS_REGION));
87+
88+
if ('fromNodeProviderChain' in MongoDBAWS.credentialProvider) {
89+
this.provider =
90+
awsRegionSettingsExist && useRegionalSts
91+
? MongoDBAWS.credentialProvider.fromNodeProviderChain({
92+
clientConfig: { region: AWS_REGION }
93+
})
94+
: MongoDBAWS.credentialProvider.fromNodeProviderChain();
95+
}
6696
}
6797

6898
override async auth(authContext: AuthContext): Promise<void> {
@@ -83,7 +113,7 @@ export class MongoDBAWS extends AuthProvider {
83113
}
84114

85115
if (!authContext.credentials.username) {
86-
authContext.credentials = await makeTempCredentials(authContext.credentials);
116+
authContext.credentials = await makeTempCredentials(authContext.credentials, this.provider);
87117
}
88118

89119
const { credentials } = authContext;
@@ -181,7 +211,10 @@ interface AWSTempCredentials {
181211
Expiration?: Date;
182212
}
183213

184-
async function makeTempCredentials(credentials: MongoCredentials): Promise<MongoCredentials> {
214+
async function makeTempCredentials(
215+
credentials: MongoCredentials,
216+
provider?: () => Promise<AWSCredentials>
217+
): Promise<MongoCredentials> {
185218
function makeMongoCredentialsFromAWSTemp(creds: AWSTempCredentials) {
186219
if (!creds.AccessKeyId || !creds.SecretAccessKey || !creds.Token) {
187220
throw new MongoMissingCredentialsError('Could not obtain temporary MONGODB-AWS credentials');
@@ -198,11 +231,31 @@ async function makeTempCredentials(credentials: MongoCredentials): Promise<Mongo
198231
});
199232
}
200233

201-
MongoDBAWS.credentialProvider ??= getAwsCredentialProvider();
202-
203234
// Check if the AWS credential provider from the SDK is present. If not,
204235
// use the old method.
205-
if ('kModuleError' in MongoDBAWS.credentialProvider) {
236+
if (provider && !('kModuleError' in MongoDBAWS.credentialProvider)) {
237+
/*
238+
* Creates a credential provider that will attempt to find credentials from the
239+
* following sources (listed in order of precedence):
240+
*
241+
* - Environment variables exposed via process.env
242+
* - SSO credentials from token cache
243+
* - Web identity token credentials
244+
* - Shared credentials and config ini files
245+
* - The EC2/ECS Instance Metadata Service
246+
*/
247+
try {
248+
const creds = await provider();
249+
return makeMongoCredentialsFromAWSTemp({
250+
AccessKeyId: creds.accessKeyId,
251+
SecretAccessKey: creds.secretAccessKey,
252+
Token: creds.sessionToken,
253+
Expiration: creds.expiration
254+
});
255+
} catch (error) {
256+
throw new MongoAWSError(error.message);
257+
}
258+
} else {
206259
// If the environment variable AWS_CONTAINER_CREDENTIALS_RELATIVE_URI
207260
// is set then drivers MUST assume that it was set by an AWS ECS agent
208261
if (process.env.AWS_CONTAINER_CREDENTIALS_RELATIVE_URI) {
@@ -232,54 +285,6 @@ async function makeTempCredentials(credentials: MongoCredentials): Promise<Mongo
232285
});
233286

234287
return makeMongoCredentialsFromAWSTemp(creds);
235-
} else {
236-
let { AWS_STS_REGIONAL_ENDPOINTS = '', AWS_REGION = '' } = process.env;
237-
AWS_STS_REGIONAL_ENDPOINTS = AWS_STS_REGIONAL_ENDPOINTS.toLowerCase();
238-
AWS_REGION = AWS_REGION.toLowerCase();
239-
240-
/** The option setting should work only for users who have explicit settings in their environment, the driver should not encode "defaults" */
241-
const awsRegionSettingsExist =
242-
AWS_REGION.length !== 0 && AWS_STS_REGIONAL_ENDPOINTS.length !== 0;
243-
244-
/**
245-
* If AWS_STS_REGIONAL_ENDPOINTS is set to regional, users are opting into the new behavior of respecting the region settings
246-
*
247-
* If AWS_STS_REGIONAL_ENDPOINTS is set to legacy, then "old" regions need to keep using the global setting.
248-
* Technically the SDK gets this wrong, it reaches out to 'sts.us-east-1.amazonaws.com' when it should be 'sts.amazonaws.com'.
249-
* That is not our bug to fix here. We leave that up to the SDK.
250-
*/
251-
const useRegionalSts =
252-
AWS_STS_REGIONAL_ENDPOINTS === 'regional' ||
253-
(AWS_STS_REGIONAL_ENDPOINTS === 'legacy' && !LEGACY_REGIONS.has(AWS_REGION));
254-
255-
const provider =
256-
awsRegionSettingsExist && useRegionalSts
257-
? MongoDBAWS.credentialProvider.fromNodeProviderChain({
258-
clientConfig: { region: AWS_REGION }
259-
})
260-
: MongoDBAWS.credentialProvider.fromNodeProviderChain();
261-
262-
/*
263-
* Creates a credential provider that will attempt to find credentials from the
264-
* following sources (listed in order of precedence):
265-
*
266-
* - Environment variables exposed via process.env
267-
* - SSO credentials from token cache
268-
* - Web identity token credentials
269-
* - Shared credentials and config ini files
270-
* - The EC2/ECS Instance Metadata Service
271-
*/
272-
try {
273-
const creds = await provider();
274-
return makeMongoCredentialsFromAWSTemp({
275-
AccessKeyId: creds.accessKeyId,
276-
SecretAccessKey: creds.secretAccessKey,
277-
Token: creds.sessionToken,
278-
Expiration: creds.expiration
279-
});
280-
} catch (error) {
281-
throw new MongoAWSError(error.message);
282-
}
283288
}
284289
}
285290

src/cmap/connect.ts

+13-26
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,10 @@ import {
1616
MongoRuntimeError,
1717
needsRetryableWriteLabel
1818
} from '../error';
19+
import { type MongoClientAuthProviders } from '../mongo_client_auth_providers';
1920
import { HostAddress, ns, promiseWithResolvers } from '../utils';
20-
import { AuthContext, type AuthProvider } from './auth/auth_provider';
21-
import { GSSAPI } from './auth/gssapi';
22-
import { MongoCR } from './auth/mongocr';
23-
import { MongoDBAWS } from './auth/mongodb_aws';
24-
import { MongoDBOIDC } from './auth/mongodb_oidc';
25-
import { Plain } from './auth/plain';
21+
import { AuthContext } from './auth/auth_provider';
2622
import { AuthMechanism } from './auth/providers';
27-
import { ScramSHA1, ScramSHA256 } from './auth/scram';
28-
import { X509 } from './auth/x509';
2923
import {
3024
type CommandOptions,
3125
Connection,
@@ -40,18 +34,6 @@ import {
4034
MIN_SUPPORTED_WIRE_VERSION
4135
} from './wire_protocol/constants';
4236

43-
/** @internal */
44-
export const AUTH_PROVIDERS = new Map<AuthMechanism | string, AuthProvider>([
45-
[AuthMechanism.MONGODB_AWS, new MongoDBAWS()],
46-
[AuthMechanism.MONGODB_CR, new MongoCR()],
47-
[AuthMechanism.MONGODB_GSSAPI, new GSSAPI()],
48-
[AuthMechanism.MONGODB_OIDC, new MongoDBOIDC()],
49-
[AuthMechanism.MONGODB_PLAIN, new Plain()],
50-
[AuthMechanism.MONGODB_SCRAM_SHA1, new ScramSHA1()],
51-
[AuthMechanism.MONGODB_SCRAM_SHA256, new ScramSHA256()],
52-
[AuthMechanism.MONGODB_X509, new X509()]
53-
]);
54-
5537
/** @public */
5638
export type Stream = Socket | TLSSocket;
5739

@@ -111,7 +93,7 @@ export async function performInitialHandshake(
11193
if (credentials) {
11294
if (
11395
!(credentials.mechanism === AuthMechanism.MONGODB_DEFAULT) &&
114-
!AUTH_PROVIDERS.get(credentials.mechanism)
96+
!options.authProviders.getOrCreateProvider(credentials.mechanism)
11597
) {
11698
throw new MongoInvalidArgumentError(`AuthMechanism '${credentials.mechanism}' not supported`);
11799
}
@@ -120,7 +102,7 @@ export async function performInitialHandshake(
120102
const authContext = new AuthContext(conn, credentials, options);
121103
conn.authContext = authContext;
122104

123-
const handshakeDoc = await prepareHandshakeDocument(authContext);
105+
const handshakeDoc = await prepareHandshakeDocument(authContext, options.authProviders);
124106

125107
// @ts-expect-error: TODO(NODE-5141): The options need to be filtered properly, Connection options differ from Command options
126108
const handshakeOptions: CommandOptions = { ...options };
@@ -166,7 +148,7 @@ export async function performInitialHandshake(
166148
authContext.response = response;
167149

168150
const resolvedCredentials = credentials.resolveAuthMechanism(response);
169-
const provider = AUTH_PROVIDERS.get(resolvedCredentials.mechanism);
151+
const provider = options.authProviders.getOrCreateProvider(resolvedCredentials.mechanism);
170152
if (!provider) {
171153
throw new MongoInvalidArgumentError(
172154
`No AuthProvider for ${resolvedCredentials.mechanism} defined.`
@@ -191,6 +173,10 @@ export async function performInitialHandshake(
191173
conn.established = true;
192174
}
193175

176+
/**
177+
* HandshakeDocument used during authentication.
178+
* @internal
179+
*/
194180
export interface HandshakeDocument extends Document {
195181
/**
196182
* @deprecated Use hello instead
@@ -210,7 +196,8 @@ export interface HandshakeDocument extends Document {
210196
* This function is only exposed for testing purposes.
211197
*/
212198
export async function prepareHandshakeDocument(
213-
authContext: AuthContext
199+
authContext: AuthContext,
200+
authProviders: MongoClientAuthProviders
214201
): Promise<HandshakeDocument> {
215202
const options = authContext.options;
216203
const compressors = options.compressors ? options.compressors : [];
@@ -232,7 +219,7 @@ export async function prepareHandshakeDocument(
232219
if (credentials.mechanism === AuthMechanism.MONGODB_DEFAULT && credentials.username) {
233220
handshakeDoc.saslSupportedMechs = `${credentials.source}.${credentials.username}`;
234221

235-
const provider = AUTH_PROVIDERS.get(AuthMechanism.MONGODB_SCRAM_SHA256);
222+
const provider = authProviders.getOrCreateProvider(AuthMechanism.MONGODB_SCRAM_SHA256);
236223
if (!provider) {
237224
// This auth mechanism is always present.
238225
throw new MongoInvalidArgumentError(
@@ -241,7 +228,7 @@ export async function prepareHandshakeDocument(
241228
}
242229
return provider.prepare(handshakeDoc, authContext);
243230
}
244-
const provider = AUTH_PROVIDERS.get(credentials.mechanism);
231+
const provider = authProviders.getOrCreateProvider(credentials.mechanism);
245232
if (!provider) {
246233
throw new MongoInvalidArgumentError(`No AuthProvider for ${credentials.mechanism} defined.`);
247234
}

src/cmap/connection.ts

+3
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import {
2424
MongoWriteConcernError
2525
} from '../error';
2626
import type { ServerApi, SupportedNodeConnectionOptions } from '../mongo_client';
27+
import { type MongoClientAuthProviders } from '../mongo_client_auth_providers';
2728
import { MongoLoggableComponent, type MongoLogger, SeverityLevel } from '../mongo_logger';
2829
import { type CancellationToken, TypedEventEmitter } from '../mongo_types';
2930
import type { ReadPreferenceLike } from '../read_preference';
@@ -109,6 +110,8 @@ export interface ConnectionOptions
109110
/** @internal */
110111
connectionType?: any;
111112
credentials?: MongoCredentials;
113+
/** @internal */
114+
authProviders: MongoClientAuthProviders;
112115
connectTimeoutMS?: number;
113116
tls: boolean;
114117
noDelay?: boolean;

src/cmap/connection_pool.ts

+6-3
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import {
2828
import { CancellationToken, TypedEventEmitter } from '../mongo_types';
2929
import type { Server } from '../sdam/server';
3030
import { type Callback, eachAsync, List, makeCounter, TimeoutController } from '../utils';
31-
import { AUTH_PROVIDERS, connect } from './connect';
31+
import { connect } from './connect';
3232
import { Connection, type ConnectionEvents, type ConnectionOptions } from './connection';
3333
import {
3434
ConnectionCheckedInEvent,
@@ -622,7 +622,9 @@ export class ConnectionPool extends TypedEventEmitter<ConnectionPoolEvents> {
622622
);
623623
}
624624
const resolvedCredentials = credentials.resolveAuthMechanism(connection.hello);
625-
const provider = AUTH_PROVIDERS.get(resolvedCredentials.mechanism);
625+
const provider = this[kServer].topology.client.s.authProviders.getOrCreateProvider(
626+
resolvedCredentials.mechanism
627+
);
626628
if (!provider) {
627629
return callback(
628630
new MongoMissingCredentialsError(
@@ -700,7 +702,8 @@ export class ConnectionPool extends TypedEventEmitter<ConnectionPoolEvents> {
700702
id: this[kConnectionCounter].next().value,
701703
generation: this[kGeneration],
702704
cancellationToken: this[kCancellationToken],
703-
mongoLogger: this.mongoLogger
705+
mongoLogger: this.mongoLogger,
706+
authProviders: this[kServer].topology.client.s.authProviders
704707
};
705708

706709
this[kPending]++;

src/index.ts

+3-1
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ export type {
243243
CSFLEKMSTlsOptions,
244244
StateMachineExecutable
245245
} from './client-side-encryption/state_machine';
246-
export type { AuthContext } from './cmap/auth/auth_provider';
246+
export type { AuthContext, AuthProvider } from './cmap/auth/auth_provider';
247247
export type {
248248
AuthMechanismProperties,
249249
MongoCredentials,
@@ -268,6 +268,7 @@ export type {
268268
OpResponseOptions,
269269
WriteProtocolMessageType
270270
} from './cmap/commands';
271+
export type { HandshakeDocument } from './cmap/connect';
271272
export type { LEGAL_TCP_SOCKET_OPTIONS, LEGAL_TLS_SOCKET_OPTIONS, Stream } from './cmap/connect';
272273
export type {
273274
CommandOptions,
@@ -365,6 +366,7 @@ export type {
365366
SupportedTLSSocketOptions,
366367
WithSessionCallback
367368
} from './mongo_client';
369+
export { MongoClientAuthProviders } from './mongo_client_auth_providers';
368370
export type {
369371
Log,
370372
LogComponentSeveritiesClientOptions,

src/mongo_client.ts

+6-2
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import { MONGO_CLIENT_EVENTS } from './constants';
2121
import { Db, type DbOptions } from './db';
2222
import type { Encrypter } from './encrypter';
2323
import { MongoInvalidArgumentError } from './error';
24+
import { MongoClientAuthProviders } from './mongo_client_auth_providers';
2425
import {
2526
type LogComponentSeveritiesClientOptions,
2627
type MongoDBLogWritable,
@@ -297,6 +298,7 @@ export interface MongoClientPrivate {
297298
bsonOptions: BSONSerializeOptions;
298299
namespace: MongoDBNamespace;
299300
hasBeenClosed: boolean;
301+
authProviders: MongoClientAuthProviders;
300302
/**
301303
* We keep a reference to the sessions that are acquired from the pool.
302304
* - used to track and close all sessions in client.close() (which is non-standard behavior)
@@ -319,6 +321,7 @@ export type MongoClientEvents = Pick<TopologyEvents, (typeof MONGO_CLIENT_EVENTS
319321
};
320322

321323
/** @internal */
324+
322325
const kOptions = Symbol('options');
323326

324327
/**
@@ -379,6 +382,7 @@ export class MongoClient extends TypedEventEmitter<MongoClientEvents> {
379382
hasBeenClosed: false,
380383
sessionPool: new ServerSessionPool(this),
381384
activeSessions: new Set(),
385+
authProviders: new MongoClientAuthProviders(),
382386

383387
get options() {
384388
return client[kOptions];
@@ -829,10 +833,10 @@ export interface MongoOptions
829833
proxyUsername?: string;
830834
proxyPassword?: string;
831835
serverMonitoringMode: ServerMonitoringMode;
832-
833836
/** @internal */
834837
connectionType?: typeof Connection;
835-
838+
/** @internal */
839+
authProviders: MongoClientAuthProviders;
836840
/** @internal */
837841
encrypter: Encrypter;
838842
/** @internal */

0 commit comments

Comments
 (0)