diff --git a/src/cmap/auth/auth_provider.ts b/src/cmap/auth/auth_provider.ts index 37a47889b91..e40c791ea5d 100644 --- a/src/cmap/auth/auth_provider.ts +++ b/src/cmap/auth/auth_provider.ts @@ -34,6 +34,10 @@ export class AuthContext { } } +/** + * Provider used during authentication. + * @internal + */ export abstract class AuthProvider { /** * Prepare the handshake document before the initial handshake. diff --git a/src/cmap/auth/mongodb_aws.ts b/src/cmap/auth/mongodb_aws.ts index e2b4604084d..b6676656ccf 100644 --- a/src/cmap/auth/mongodb_aws.ts +++ b/src/cmap/auth/mongodb_aws.ts @@ -4,7 +4,7 @@ import { promisify } from 'util'; import type { Binary, BSONSerializeOptions } from '../../bson'; import * as BSON from '../../bson'; -import { aws4, getAwsCredentialProvider } from '../../deps'; +import { aws4, type AWSCredentials, getAwsCredentialProvider } from '../../deps'; import { MongoAWSError, MongoCompatibilityError, @@ -57,12 +57,42 @@ interface AWSSaslContinuePayload { } export class MongoDBAWS extends AuthProvider { - static credentialProvider: ReturnType | null = null; + static credentialProvider: ReturnType; + provider?: () => Promise; randomBytesAsync: (size: number) => Promise; constructor() { super(); this.randomBytesAsync = promisify(crypto.randomBytes); + MongoDBAWS.credentialProvider ??= getAwsCredentialProvider(); + + let { AWS_STS_REGIONAL_ENDPOINTS = '', AWS_REGION = '' } = process.env; + AWS_STS_REGIONAL_ENDPOINTS = AWS_STS_REGIONAL_ENDPOINTS.toLowerCase(); + AWS_REGION = AWS_REGION.toLowerCase(); + + /** The option setting should work only for users who have explicit settings in their environment, the driver should not encode "defaults" */ + const awsRegionSettingsExist = + AWS_REGION.length !== 0 && AWS_STS_REGIONAL_ENDPOINTS.length !== 0; + + /** + * If AWS_STS_REGIONAL_ENDPOINTS is set to regional, users are opting into the new behavior of respecting the region settings + * + * If AWS_STS_REGIONAL_ENDPOINTS is set to legacy, then "old" regions need to keep using the global setting. + * Technically the SDK gets this wrong, it reaches out to 'sts.us-east-1.amazonaws.com' when it should be 'sts.amazonaws.com'. + * That is not our bug to fix here. We leave that up to the SDK. + */ + const useRegionalSts = + AWS_STS_REGIONAL_ENDPOINTS === 'regional' || + (AWS_STS_REGIONAL_ENDPOINTS === 'legacy' && !LEGACY_REGIONS.has(AWS_REGION)); + + if ('fromNodeProviderChain' in MongoDBAWS.credentialProvider) { + this.provider = + awsRegionSettingsExist && useRegionalSts + ? MongoDBAWS.credentialProvider.fromNodeProviderChain({ + clientConfig: { region: AWS_REGION } + }) + : MongoDBAWS.credentialProvider.fromNodeProviderChain(); + } } override async auth(authContext: AuthContext): Promise { @@ -83,7 +113,7 @@ export class MongoDBAWS extends AuthProvider { } if (!authContext.credentials.username) { - authContext.credentials = await makeTempCredentials(authContext.credentials); + authContext.credentials = await makeTempCredentials(authContext.credentials, this.provider); } const { credentials } = authContext; @@ -181,7 +211,10 @@ interface AWSTempCredentials { Expiration?: Date; } -async function makeTempCredentials(credentials: MongoCredentials): Promise { +async function makeTempCredentials( + credentials: MongoCredentials, + provider?: () => Promise +): Promise { function makeMongoCredentialsFromAWSTemp(creds: AWSTempCredentials) { if (!creds.AccessKeyId || !creds.SecretAccessKey || !creds.Token) { throw new MongoMissingCredentialsError('Could not obtain temporary MONGODB-AWS credentials'); @@ -198,11 +231,31 @@ async function makeTempCredentials(credentials: MongoCredentials): Promise([ - [AuthMechanism.MONGODB_AWS, new MongoDBAWS()], - [AuthMechanism.MONGODB_CR, new MongoCR()], - [AuthMechanism.MONGODB_GSSAPI, new GSSAPI()], - [AuthMechanism.MONGODB_OIDC, new MongoDBOIDC()], - [AuthMechanism.MONGODB_PLAIN, new Plain()], - [AuthMechanism.MONGODB_SCRAM_SHA1, new ScramSHA1()], - [AuthMechanism.MONGODB_SCRAM_SHA256, new ScramSHA256()], - [AuthMechanism.MONGODB_X509, new X509()] -]); - /** @public */ export type Stream = Socket | TLSSocket; @@ -111,7 +93,7 @@ export async function performInitialHandshake( if (credentials) { if ( !(credentials.mechanism === AuthMechanism.MONGODB_DEFAULT) && - !AUTH_PROVIDERS.get(credentials.mechanism) + !options.authProviders.getOrCreateProvider(credentials.mechanism) ) { throw new MongoInvalidArgumentError(`AuthMechanism '${credentials.mechanism}' not supported`); } @@ -120,7 +102,7 @@ export async function performInitialHandshake( const authContext = new AuthContext(conn, credentials, options); conn.authContext = authContext; - const handshakeDoc = await prepareHandshakeDocument(authContext); + const handshakeDoc = await prepareHandshakeDocument(authContext, options.authProviders); // @ts-expect-error: TODO(NODE-5141): The options need to be filtered properly, Connection options differ from Command options const handshakeOptions: CommandOptions = { ...options }; @@ -166,7 +148,7 @@ export async function performInitialHandshake( authContext.response = response; const resolvedCredentials = credentials.resolveAuthMechanism(response); - const provider = AUTH_PROVIDERS.get(resolvedCredentials.mechanism); + const provider = options.authProviders.getOrCreateProvider(resolvedCredentials.mechanism); if (!provider) { throw new MongoInvalidArgumentError( `No AuthProvider for ${resolvedCredentials.mechanism} defined.` @@ -191,6 +173,10 @@ export async function performInitialHandshake( conn.established = true; } +/** + * HandshakeDocument used during authentication. + * @internal + */ export interface HandshakeDocument extends Document { /** * @deprecated Use hello instead @@ -210,7 +196,8 @@ export interface HandshakeDocument extends Document { * This function is only exposed for testing purposes. */ export async function prepareHandshakeDocument( - authContext: AuthContext + authContext: AuthContext, + authProviders: MongoClientAuthProviders ): Promise { const options = authContext.options; const compressors = options.compressors ? options.compressors : []; @@ -232,7 +219,7 @@ export async function prepareHandshakeDocument( if (credentials.mechanism === AuthMechanism.MONGODB_DEFAULT && credentials.username) { handshakeDoc.saslSupportedMechs = `${credentials.source}.${credentials.username}`; - const provider = AUTH_PROVIDERS.get(AuthMechanism.MONGODB_SCRAM_SHA256); + const provider = authProviders.getOrCreateProvider(AuthMechanism.MONGODB_SCRAM_SHA256); if (!provider) { // This auth mechanism is always present. throw new MongoInvalidArgumentError( @@ -241,7 +228,7 @@ export async function prepareHandshakeDocument( } return provider.prepare(handshakeDoc, authContext); } - const provider = AUTH_PROVIDERS.get(credentials.mechanism); + const provider = authProviders.getOrCreateProvider(credentials.mechanism); if (!provider) { throw new MongoInvalidArgumentError(`No AuthProvider for ${credentials.mechanism} defined.`); } diff --git a/src/cmap/connection.ts b/src/cmap/connection.ts index 7b277794edc..e33b4f835f6 100644 --- a/src/cmap/connection.ts +++ b/src/cmap/connection.ts @@ -24,6 +24,7 @@ import { MongoWriteConcernError } from '../error'; import type { ServerApi, SupportedNodeConnectionOptions } from '../mongo_client'; +import { type MongoClientAuthProviders } from '../mongo_client_auth_providers'; import { MongoLoggableComponent, type MongoLogger, SeverityLevel } from '../mongo_logger'; import { type CancellationToken, TypedEventEmitter } from '../mongo_types'; import type { ReadPreferenceLike } from '../read_preference'; @@ -109,6 +110,8 @@ export interface ConnectionOptions /** @internal */ connectionType?: any; credentials?: MongoCredentials; + /** @internal */ + authProviders: MongoClientAuthProviders; connectTimeoutMS?: number; tls: boolean; noDelay?: boolean; diff --git a/src/cmap/connection_pool.ts b/src/cmap/connection_pool.ts index 4fe5249738f..b5e0818061c 100644 --- a/src/cmap/connection_pool.ts +++ b/src/cmap/connection_pool.ts @@ -28,7 +28,7 @@ import { import { CancellationToken, TypedEventEmitter } from '../mongo_types'; import type { Server } from '../sdam/server'; import { type Callback, eachAsync, List, makeCounter, TimeoutController } from '../utils'; -import { AUTH_PROVIDERS, connect } from './connect'; +import { connect } from './connect'; import { Connection, type ConnectionEvents, type ConnectionOptions } from './connection'; import { ConnectionCheckedInEvent, @@ -622,7 +622,9 @@ export class ConnectionPool extends TypedEventEmitter { ); } const resolvedCredentials = credentials.resolveAuthMechanism(connection.hello); - const provider = AUTH_PROVIDERS.get(resolvedCredentials.mechanism); + const provider = this[kServer].topology.client.s.authProviders.getOrCreateProvider( + resolvedCredentials.mechanism + ); if (!provider) { return callback( new MongoMissingCredentialsError( @@ -700,7 +702,8 @@ export class ConnectionPool extends TypedEventEmitter { id: this[kConnectionCounter].next().value, generation: this[kGeneration], cancellationToken: this[kCancellationToken], - mongoLogger: this.mongoLogger + mongoLogger: this.mongoLogger, + authProviders: this[kServer].topology.client.s.authProviders }; this[kPending]++; diff --git a/src/index.ts b/src/index.ts index 6366e746655..aae568dd79e 100644 --- a/src/index.ts +++ b/src/index.ts @@ -243,7 +243,7 @@ export type { CSFLEKMSTlsOptions, StateMachineExecutable } from './client-side-encryption/state_machine'; -export type { AuthContext } from './cmap/auth/auth_provider'; +export type { AuthContext, AuthProvider } from './cmap/auth/auth_provider'; export type { AuthMechanismProperties, MongoCredentials, @@ -268,6 +268,7 @@ export type { OpResponseOptions, WriteProtocolMessageType } from './cmap/commands'; +export type { HandshakeDocument } from './cmap/connect'; export type { LEGAL_TCP_SOCKET_OPTIONS, LEGAL_TLS_SOCKET_OPTIONS, Stream } from './cmap/connect'; export type { CommandOptions, @@ -365,6 +366,7 @@ export type { SupportedTLSSocketOptions, WithSessionCallback } from './mongo_client'; +export { MongoClientAuthProviders } from './mongo_client_auth_providers'; export type { Log, LogComponentSeveritiesClientOptions, diff --git a/src/mongo_client.ts b/src/mongo_client.ts index 476a912aece..be039944a4f 100644 --- a/src/mongo_client.ts +++ b/src/mongo_client.ts @@ -21,6 +21,7 @@ import { MONGO_CLIENT_EVENTS } from './constants'; import { Db, type DbOptions } from './db'; import type { Encrypter } from './encrypter'; import { MongoInvalidArgumentError } from './error'; +import { MongoClientAuthProviders } from './mongo_client_auth_providers'; import { type LogComponentSeveritiesClientOptions, type MongoDBLogWritable, @@ -297,6 +298,7 @@ export interface MongoClientPrivate { bsonOptions: BSONSerializeOptions; namespace: MongoDBNamespace; hasBeenClosed: boolean; + authProviders: MongoClientAuthProviders; /** * We keep a reference to the sessions that are acquired from the pool. * - used to track and close all sessions in client.close() (which is non-standard behavior) @@ -319,6 +321,7 @@ export type MongoClientEvents = Pick { hasBeenClosed: false, sessionPool: new ServerSessionPool(this), activeSessions: new Set(), + authProviders: new MongoClientAuthProviders(), get options() { return client[kOptions]; @@ -829,10 +833,10 @@ export interface MongoOptions proxyUsername?: string; proxyPassword?: string; serverMonitoringMode: ServerMonitoringMode; - /** @internal */ connectionType?: typeof Connection; - + /** @internal */ + authProviders: MongoClientAuthProviders; /** @internal */ encrypter: Encrypter; /** @internal */ diff --git a/src/mongo_client_auth_providers.ts b/src/mongo_client_auth_providers.ts new file mode 100644 index 00000000000..557783c4e17 --- /dev/null +++ b/src/mongo_client_auth_providers.ts @@ -0,0 +1,54 @@ +import { type AuthProvider } from './cmap/auth/auth_provider'; +import { GSSAPI } from './cmap/auth/gssapi'; +import { MongoCR } from './cmap/auth/mongocr'; +import { MongoDBAWS } from './cmap/auth/mongodb_aws'; +import { MongoDBOIDC } from './cmap/auth/mongodb_oidc'; +import { Plain } from './cmap/auth/plain'; +import { AuthMechanism } from './cmap/auth/providers'; +import { ScramSHA1, ScramSHA256 } from './cmap/auth/scram'; +import { X509 } from './cmap/auth/x509'; +import { MongoInvalidArgumentError } from './error'; + +/** @internal */ +const AUTH_PROVIDERS = new Map AuthProvider>([ + [AuthMechanism.MONGODB_AWS, () => new MongoDBAWS()], + [AuthMechanism.MONGODB_CR, () => new MongoCR()], + [AuthMechanism.MONGODB_GSSAPI, () => new GSSAPI()], + [AuthMechanism.MONGODB_OIDC, () => new MongoDBOIDC()], + [AuthMechanism.MONGODB_PLAIN, () => new Plain()], + [AuthMechanism.MONGODB_SCRAM_SHA1, () => new ScramSHA1()], + [AuthMechanism.MONGODB_SCRAM_SHA256, () => new ScramSHA256()], + [AuthMechanism.MONGODB_X509, () => new X509()] +]); + +/** + * Create a set of providers per client + * to avoid sharing the provider's cache between different clients. + * @internal + */ +export class MongoClientAuthProviders { + private existingProviders: Map = new Map(); + + /** + * Get or create an authentication provider based on the provided mechanism. + * We don't want to create all providers at once, as some providers may not be used. + * @param name - The name of the provider to get or create. + * @returns The provider. + * @throws MongoInvalidArgumentError if the mechanism is not supported. + * @internal + */ + getOrCreateProvider(name: AuthMechanism | string): AuthProvider { + const authProvider = this.existingProviders.get(name); + if (authProvider) { + return authProvider; + } + + const provider = AUTH_PROVIDERS.get(name)?.(); + if (!provider) { + throw new MongoInvalidArgumentError(`authMechanism ${name} not supported`); + } + + this.existingProviders.set(name, provider); + return provider; + } +} diff --git a/test/integration/auth/mongodb_aws.test.ts b/test/integration/auth/mongodb_aws.test.ts index e075b67775e..635880c04a8 100644 --- a/test/integration/auth/mongodb_aws.test.ts +++ b/test/integration/auth/mongodb_aws.test.ts @@ -67,6 +67,20 @@ describe('MONGODB-AWS', function () { .that.equals(''); }); + it('should store a MongoDBAWS provider instance per client', async function () { + client = this.configuration.newClient(process.env.MONGODB_URI); + + await client + .db('aws') + .collection('aws_test') + .estimatedDocumentCount() + .catch(error => error); + + expect(client).to.have.nested.property('s.authProviders'); + const provider = client.s.authProviders.getOrCreateProvider('MONGODB-AWS'); + expect(provider).to.be.instanceOf(MongoDBAWS); + }); + describe('EC2 with missing credentials', () => { let client; @@ -144,7 +158,6 @@ describe('MONGODB-AWS', function () { }, calledWith: [] }, - { ctx: 'when AWS_STS_REGIONAL_ENDPOINTS is set to regional and region is legacy', title: 'uses the region from the environment', @@ -163,7 +176,6 @@ describe('MONGODB-AWS', function () { }, calledWith: [{ clientConfig: { region: 'sa-east-1' } }] }, - { ctx: 'when AWS_STS_REGIONAL_ENDPOINTS is set to legacy and region is legacy', title: 'uses the region from the environment', @@ -190,6 +202,7 @@ describe('MONGODB-AWS', function () { let storedEnv; let calledArguments; let shouldSkip = false; + let numberOfFromNodeProviderChainCalls; const envCheck = () => { const { AWS_WEB_IDENTITY_TOKEN_FILE = '' } = process.env; @@ -204,8 +217,6 @@ describe('MONGODB-AWS', function () { return this.skip(); } - client = this.configuration.newClient(process.env.MONGODB_URI); - storedEnv = process.env; if (test.env.AWS_STS_REGIONAL_ENDPOINTS === undefined) { delete process.env.AWS_STS_REGIONAL_ENDPOINTS; @@ -218,13 +229,17 @@ describe('MONGODB-AWS', function () { process.env.AWS_REGION = test.env.AWS_REGION; } - calledArguments = []; + numberOfFromNodeProviderChainCalls = 0; + MongoDBAWS.credentialProvider = { fromNodeProviderChain(...args) { calledArguments = args; + numberOfFromNodeProviderChainCalls += 1; return credentialProvider.fromNodeProviderChain(...args); } }; + + client = this.configuration.newClient(process.env.MONGODB_URI); }); afterEach(() => { @@ -253,6 +268,18 @@ describe('MONGODB-AWS', function () { expect(calledArguments).to.deep.equal(test.calledWith); }); + + it('fromNodeProviderChain called once', async function () { + await client.close(); + await client.connect(); + await client + .db('aws') + .collection('aws_test') + .estimatedDocumentCount() + .catch(error => error); + + expect(numberOfFromNodeProviderChainCalls).to.be.eql(1); + }); }); } }); diff --git a/test/integration/connection-monitoring-and-pooling/connection.test.ts b/test/integration/connection-monitoring-and-pooling/connection.test.ts index e702795b40b..a1e8f1f9571 100644 --- a/test/integration/connection-monitoring-and-pooling/connection.test.ts +++ b/test/integration/connection-monitoring-and-pooling/connection.test.ts @@ -8,6 +8,7 @@ import { LEGACY_HELLO_COMMAND, makeClientMetadata, MongoClient, + MongoClientAuthProviders, MongoServerError, ns, ServerHeartbeatStartedEvent, @@ -23,7 +24,8 @@ const commonConnectOptions = { tls: false, loadBalanced: false, // Will be overridden by configuration options - hostAddress: HostAddress.fromString('127.0.0.1:1') + hostAddress: HostAddress.fromString('127.0.0.1:1'), + authProviders: new MongoClientAuthProviders() }; describe('Connection', function () { diff --git a/test/unit/cmap/connect.test.ts b/test/unit/cmap/connect.test.ts index 7f69f54d174..7697c124fbe 100644 --- a/test/unit/cmap/connect.test.ts +++ b/test/unit/cmap/connect.test.ts @@ -9,6 +9,7 @@ import { HostAddress, isHello, LEGACY_HELLO_COMMAND, + MongoClientAuthProviders, MongoCredentials, MongoNetworkError, prepareHandshakeDocument @@ -44,7 +45,8 @@ describe('Connect Tests', function () { source: 'admin', mechanism: 'PLAIN', mechanismProperties: {} - }) + }), + authProviders: new MongoClientAuthProviders() }; }); diff --git a/test/unit/cmap/connection.test.ts b/test/unit/cmap/connection.test.ts index ec902fffd3d..9127d68c99c 100644 --- a/test/unit/cmap/connection.test.ts +++ b/test/unit/cmap/connection.test.ts @@ -1,7 +1,14 @@ import { expect } from 'chai'; import * as sinon from 'sinon'; -import { connect, Connection, isHello, MongoNetworkTimeoutError, ns } from '../../mongodb'; +import { + connect, + Connection, + isHello, + MongoClientAuthProviders, + MongoNetworkTimeoutError, + ns +} from '../../mongodb'; import * as mock from '../../tools/mongodb-mock/index'; import { getSymbolFrom } from '../../tools/utils'; @@ -32,7 +39,8 @@ describe('new Connection()', function () { const options = { ...connectionOptionsDefaults, connectionType: Connection, - hostAddress: server.hostAddress() + hostAddress: server.hostAddress(), + authProviders: new MongoClientAuthProviders() }; const conn = await connect(options); @@ -54,7 +62,8 @@ describe('new Connection()', function () { const options = { ...connectionOptionsDefaults, connectionType: Connection, - hostAddress: server.hostAddress() + hostAddress: server.hostAddress(), + authProviders: new MongoClientAuthProviders() }; const conn = await connect(options); @@ -76,7 +85,8 @@ describe('new Connection()', function () { const options = { hostAddress: server.hostAddress(), - ...connectionOptionsDefaults + ...connectionOptionsDefaults, + authProviders: new MongoClientAuthProviders() }; const conn = await connect(options); @@ -101,7 +111,8 @@ describe('new Connection()', function () { const options = { ...connectionOptionsDefaults, - hostAddress: server.hostAddress() + hostAddress: server.hostAddress(), + authProviders: new MongoClientAuthProviders() }; const connection = await connect(options); @@ -119,7 +130,8 @@ describe('new Connection()', function () { const options = { ...connectionOptionsDefaults, hostAddress: server.hostAddress(), - socketTimeoutMS: 50 + socketTimeoutMS: 50, + authProviders: new MongoClientAuthProviders() }; const error = await connect(options).catch(error => error); diff --git a/test/unit/cmap/connection_pool.test.js b/test/unit/cmap/connection_pool.test.js index cdbf00bf67f..43177b72962 100644 --- a/test/unit/cmap/connection_pool.test.js +++ b/test/unit/cmap/connection_pool.test.js @@ -11,6 +11,7 @@ const { ns, isHello } = require('../../mongodb'); const { LEGACY_HELLO_COMMAND } = require('../../mongodb'); const { createTimerSandbox } = require('../timer_sandbox'); const { topologyWithPlaceholderClient } = require('../../tools/utils'); +const { MongoClientAuthProviders } = require('../../mongodb'); describe('Connection Pool', function () { let mockMongod; @@ -20,6 +21,9 @@ describe('Connection Pool', function () { mongoLogger: { debug: () => null, willLog: () => null + }, + s: { + authProviders: new MongoClientAuthProviders() } } } diff --git a/test/unit/index.test.ts b/test/unit/index.test.ts index 6e36a54fa79..508f3d85c2a 100644 --- a/test/unit/index.test.ts +++ b/test/unit/index.test.ts @@ -71,6 +71,7 @@ const EXPECTED_EXPORTS = [ 'MongoBulkWriteError', 'MongoChangeStreamError', 'MongoClient', + 'MongoClientAuthProviders', 'MongoCompatibilityError', 'MongoCryptAzureKMSRequestError', 'MongoCryptCreateDataKeyError', diff --git a/test/unit/mongo_client.test.js b/test/unit/mongo_client.test.js index c9c3c9923fc..04ef86a5038 100644 --- a/test/unit/mongo_client.test.js +++ b/test/unit/mongo_client.test.js @@ -11,14 +11,20 @@ const { ReadConcern } = require('../mongodb'); const { WriteConcern } = require('../mongodb'); const { ReadPreference } = require('../mongodb'); const { MongoCredentials } = require('../mongodb'); -const { MongoClient, MongoParseError, ServerApiVersion, MongoAPIError } = require('../mongodb'); +const { + MongoClient, + MongoParseError, + ServerApiVersion, + MongoAPIError, + MongoInvalidArgumentError +} = require('../mongodb'); const { MongoLogger } = require('../mongodb'); // eslint-disable-next-line no-restricted-modules const { SeverityLevel, MongoLoggableComponent } = require('../../src/mongo_logger'); const sinon = require('sinon'); const { Writable } = require('stream'); -describe('MongoOptions', function () { +describe('MongoClient', function () { it('MongoClient should always freeze public options', function () { const client = new MongoClient('mongodb://localhost:27017'); expect(client.options).to.be.frozen; @@ -1182,4 +1188,15 @@ describe('MongoOptions', function () { }); }); }); + + context('getAuthProvider', function () { + it('throws MongoInvalidArgumentError if provided authMechanism is not supported', function () { + const client = new MongoClient('mongodb://localhost:27017'); + try { + client.s.authProviders.getOrCreateProvider('NOT_SUPPORTED'); + } catch (error) { + expect(error).to.be.an.instanceof(MongoInvalidArgumentError); + } + }); + }); });