From 0db3da4636b6a4e409648f7b93fb896a9abedf99 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Thu, 6 Feb 2025 13:33:57 +0000 Subject: [PATCH 01/25] Install pkce-challenge library --- package-lock.json | 14 ++++++++++++-- package.json | 1 + 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/package-lock.json b/package-lock.json index 687a7c0c..f09bdc2c 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,16 +1,17 @@ { "name": "@modelcontextprotocol/sdk", - "version": "1.3.2", + "version": "1.4.1", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@modelcontextprotocol/sdk", - "version": "1.3.2", + "version": "1.4.1", "license": "MIT", "dependencies": { "content-type": "^1.0.5", "eventsource": "^3.0.2", + "pkce-challenge": "^4.1.0", "raw-body": "^3.0.0", "zod": "^3.23.8", "zod-to-json-schema": "^3.24.1" @@ -5058,6 +5059,15 @@ "node": ">= 6" } }, + "node_modules/pkce-challenge": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/pkce-challenge/-/pkce-challenge-4.1.0.tgz", + "integrity": "sha512-ZBmhE1C9LcPoH9XZSdwiPtbPHZROwAnMy+kIFQVrnMCxY4Cudlz3gBOpzilgc0jOgRaiT3sIWfpMomW2ar2orQ==", + "license": "MIT", + "engines": { + "node": ">=16.20.0" + } + }, "node_modules/pkg-dir": { "version": "4.2.0", "resolved": "https://registry.npmjs.org/pkg-dir/-/pkg-dir-4.2.0.tgz", diff --git a/package.json b/package.json index 77c268ed..66d28112 100644 --- a/package.json +++ b/package.json @@ -48,6 +48,7 @@ "dependencies": { "content-type": "^1.0.5", "eventsource": "^3.0.2", + "pkce-challenge": "^4.1.0", "raw-body": "^3.0.0", "zod": "^3.23.8", "zod-to-json-schema": "^3.24.1" From 0bc3a74ac32c5d6a1b2607e4f5dbcbbd72a1810e Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Thu, 6 Feb 2025 13:38:21 +0000 Subject: [PATCH 02/25] OAuth metadata discovery --- src/client/auth/auth.ts | 58 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 src/client/auth/auth.ts diff --git a/src/client/auth/auth.ts b/src/client/auth/auth.ts new file mode 100644 index 00000000..842abe9d --- /dev/null +++ b/src/client/auth/auth.ts @@ -0,0 +1,58 @@ +import { z } from "zod"; + +export const OAuthMetadataSchema = z + .object({ + issuer: z.string(), + authorization_endpoint: z.string(), + token_endpoint: z.string(), + registration_endpoint: z.string().optional(), + scopes_supported: z.array(z.string()).optional(), + response_types_supported: z.array(z.string()), + response_modes_supported: z.array(z.string()).optional(), + grant_types_supported: z.array(z.string()).optional(), + token_endpoint_auth_methods_supported: z.array(z.string()).optional(), + token_endpoint_auth_signing_alg_values_supported: z + .array(z.string()) + .optional(), + service_documentation: z.string().optional(), + revocation_endpoint: z.string().optional(), + revocation_endpoint_auth_methods_supported: z.array(z.string()).optional(), + revocation_endpoint_auth_signing_alg_values_supported: z + .array(z.string()) + .optional(), + introspection_endpoint: z.string().optional(), + introspection_endpoint_auth_methods_supported: z + .array(z.string()) + .optional(), + introspection_endpoint_auth_signing_alg_values_supported: z + .array(z.string()) + .optional(), + code_challenge_methods_supported: z.array(z.string()).optional(), + }) + .passthrough(); + +export type OAuthMetadata = z.infer; + +/** + * Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata. + * + * If the server returns a 404 for the well-known endpoint, this function will + * return `undefined`. Any other errors will be thrown as exceptions. + */ +export async function discoverOAuthMetadata( + serverUrl: string | URL, +): Promise { + const url = new URL("/.well-known/oauth-authorization-server", serverUrl); + const response = await fetch(url); + if (response.status === 404) { + return undefined; + } + + if (!response.ok) { + throw new Error( + `HTTP ${response.status} trying to load well-known OAuth metadata`, + ); + } + + return OAuthMetadataSchema.parse(await response.json()); +} From c39c8bf62b7286437571a6f2a93a39888cb63fed Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Thu, 6 Feb 2025 13:44:41 +0000 Subject: [PATCH 03/25] WIP start authorization flow --- src/client/auth/auth.ts | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/src/client/auth/auth.ts b/src/client/auth/auth.ts index 842abe9d..4758b1f6 100644 --- a/src/client/auth/auth.ts +++ b/src/client/auth/auth.ts @@ -1,3 +1,4 @@ +import pkceChallenge from "pkce-challenge"; import { z } from "zod"; export const OAuthMetadataSchema = z @@ -56,3 +57,28 @@ export async function discoverOAuthMetadata( return OAuthMetadataSchema.parse(await response.json()); } + +export async function startAuthorization( + serverUrl: string | URL, + { + metadata, + redirectUrl, + }: { metadata: OAuthMetadata; redirectUrl: string | URL }, +): Promise<{ authorizationUrl: URL; codeVerifier: string }> { + // Generate PKCE challenge + const challenge = await pkceChallenge(); + const codeVerifier = challenge.code_verifier; + const codeChallenge = challenge.code_challenge; + + const authorizationUrl = metadata?.authorization_endpoint + ? new URL(metadata?.authorization_endpoint) + : new URL("/authorize", serverUrl); + + // TODO: Validate that these parameters are listed as supported in the metadata, if present. + authorizationUrl.searchParams.set("response_type", "code"); + authorizationUrl.searchParams.set("code_challenge", codeChallenge); + authorizationUrl.searchParams.set("code_challenge_method", "S256"); + authorizationUrl.searchParams.set("redirect_uri", String(redirectUrl)); + + return { authorizationUrl, codeVerifier }; +} From 8a26e2d9e82b9dee7d27239738a159d80c70292a Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Thu, 6 Feb 2025 14:39:50 +0000 Subject: [PATCH 04/25] Validate auth flow is supported --- src/client/auth/auth.ts | 35 +++++++++++++++++++++++++++++------ 1 file changed, 29 insertions(+), 6 deletions(-) diff --git a/src/client/auth/auth.ts b/src/client/auth/auth.ts index 4758b1f6..32b68b76 100644 --- a/src/client/auth/auth.ts +++ b/src/client/auth/auth.ts @@ -70,14 +70,37 @@ export async function startAuthorization( const codeVerifier = challenge.code_verifier; const codeChallenge = challenge.code_challenge; - const authorizationUrl = metadata?.authorization_endpoint - ? new URL(metadata?.authorization_endpoint) - : new URL("/authorize", serverUrl); + const responseType = "code"; + const codeChallengeMethod = "S256"; - // TODO: Validate that these parameters are listed as supported in the metadata, if present. - authorizationUrl.searchParams.set("response_type", "code"); + let authorizationUrl: URL; + if (metadata) { + authorizationUrl = new URL(metadata.authorization_endpoint); + + if (!(responseType in metadata.response_types_supported)) { + throw new Error( + `Incompatible auth server: does not support response type ${responseType}`, + ); + } + + if ( + !metadata.code_challenge_methods_supported || + !(codeChallengeMethod in metadata.code_challenge_methods_supported) + ) { + throw new Error( + `Incompatible auth server: does not support code challenge method ${codeChallengeMethod}`, + ); + } + } else { + authorizationUrl = new URL("/authorize", serverUrl); + } + + authorizationUrl.searchParams.set("response_type", responseType); authorizationUrl.searchParams.set("code_challenge", codeChallenge); - authorizationUrl.searchParams.set("code_challenge_method", "S256"); + authorizationUrl.searchParams.set( + "code_challenge_method", + codeChallengeMethod, + ); authorizationUrl.searchParams.set("redirect_uri", String(redirectUrl)); return { authorizationUrl, codeVerifier }; From 956f6658e8989ac7c84356ee8437ded48eddbdbe Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Thu, 6 Feb 2025 14:52:37 +0000 Subject: [PATCH 05/25] Implement token exchange --- src/client/auth/auth.ts | 80 ++++++++++++++++++++++++++++++++++++++--- 1 file changed, 75 insertions(+), 5 deletions(-) diff --git a/src/client/auth/auth.ts b/src/client/auth/auth.ts index 32b68b76..8cd3315f 100644 --- a/src/client/auth/auth.ts +++ b/src/client/auth/auth.ts @@ -32,7 +32,18 @@ export const OAuthMetadataSchema = z }) .passthrough(); +export const OAuthTokensSchema = z + .object({ + access_token: z.string(), + token_type: z.string(), + expires_in: z.number().optional(), + scope: z.string().optional(), + refresh_token: z.string().optional(), + }) + .strip(); + export type OAuthMetadata = z.infer; +export type OAuthTokens = z.infer; /** * Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata. @@ -58,6 +69,9 @@ export async function discoverOAuthMetadata( return OAuthMetadataSchema.parse(await response.json()); } +/** + * Begins the authorization flow with the given server, by generating a PKCE challenge and constructing the authorization URL. + */ export async function startAuthorization( serverUrl: string | URL, { @@ -65,11 +79,6 @@ export async function startAuthorization( redirectUrl, }: { metadata: OAuthMetadata; redirectUrl: string | URL }, ): Promise<{ authorizationUrl: URL; codeVerifier: string }> { - // Generate PKCE challenge - const challenge = await pkceChallenge(); - const codeVerifier = challenge.code_verifier; - const codeChallenge = challenge.code_challenge; - const responseType = "code"; const codeChallengeMethod = "S256"; @@ -95,6 +104,11 @@ export async function startAuthorization( authorizationUrl = new URL("/authorize", serverUrl); } + // Generate PKCE challenge + const challenge = await pkceChallenge(); + const codeVerifier = challenge.code_verifier; + const codeChallenge = challenge.code_challenge; + authorizationUrl.searchParams.set("response_type", responseType); authorizationUrl.searchParams.set("code_challenge", codeChallenge); authorizationUrl.searchParams.set( @@ -105,3 +119,59 @@ export async function startAuthorization( return { authorizationUrl, codeVerifier }; } + +/** + * Exchanges an authorization code for an access token with the given server. + */ +export async function exchangeAuthorization( + serverUrl: string | URL, + { + metadata, + authorizationCode, + codeVerifier, + redirectUrl, + }: { + metadata: OAuthMetadata; + authorizationCode: string; + codeVerifier: string; + redirectUrl: string | URL; + }, +): Promise { + const grantType = "authorization_code"; + + let tokenUrl: URL; + if (metadata) { + tokenUrl = new URL(metadata.token_endpoint); + + if ( + metadata.grant_types_supported && + !(grantType in metadata.grant_types_supported) + ) { + throw new Error( + `Incompatible auth server: does not support grant type ${grantType}`, + ); + } + } else { + tokenUrl = new URL("/token", serverUrl); + } + + // Exchange code for tokens + const response = await fetch(tokenUrl, { + method: "POST", + headers: { + "Content-Type": "application/x-www-form-urlencoded", + }, + body: new URLSearchParams({ + grant_type: grantType, + code: authorizationCode, + code_verifier: codeVerifier, + redirect_uri: String(redirectUrl), + }), + }); + + if (!response.ok) { + throw new Error(`Token exchange failed: HTTP ${response.status}`); + } + + return OAuthTokensSchema.parse(await response.json()); +} From fe24aa5c26f4270fc02e9a772a55166975b1a275 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Thu, 6 Feb 2025 16:34:33 +0000 Subject: [PATCH 06/25] Remove nested folder for auth --- src/client/{auth => }/auth.ts | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename src/client/{auth => }/auth.ts (100%) diff --git a/src/client/auth/auth.ts b/src/client/auth.ts similarity index 100% rename from src/client/auth/auth.ts rename to src/client/auth.ts From 62d67ad874f4ded2cd5bb5a4f90ddc3f0a5d037c Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Thu, 6 Feb 2025 16:37:47 +0000 Subject: [PATCH 07/25] Add refresh token support --- src/client/auth.ts | 52 +++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 49 insertions(+), 3 deletions(-) diff --git a/src/client/auth.ts b/src/client/auth.ts index 8cd3315f..01bf8471 100644 --- a/src/client/auth.ts +++ b/src/client/auth.ts @@ -129,12 +129,10 @@ export async function exchangeAuthorization( metadata, authorizationCode, codeVerifier, - redirectUrl, }: { metadata: OAuthMetadata; authorizationCode: string; codeVerifier: string; - redirectUrl: string | URL; }, ): Promise { const grantType = "authorization_code"; @@ -165,7 +163,55 @@ export async function exchangeAuthorization( grant_type: grantType, code: authorizationCode, code_verifier: codeVerifier, - redirect_uri: String(redirectUrl), + }), + }); + + if (!response.ok) { + throw new Error(`Token exchange failed: HTTP ${response.status}`); + } + + return OAuthTokensSchema.parse(await response.json()); +} + +/** + * Exchange a refresh token for an updated access token. + */ +export async function refreshAuthorization( + serverUrl: string | URL, + { + metadata, + refreshToken, + }: { + metadata: OAuthMetadata; + refreshToken: string; + }, +): Promise { + const grantType = "refresh_token"; + + let tokenUrl: URL; + if (metadata) { + tokenUrl = new URL(metadata.token_endpoint); + + if ( + metadata.grant_types_supported && + !(grantType in metadata.grant_types_supported) + ) { + throw new Error( + `Incompatible auth server: does not support grant type ${grantType}`, + ); + } + } else { + tokenUrl = new URL("/token", serverUrl); + } + + const response = await fetch(tokenUrl, { + method: "POST", + headers: { + "Content-Type": "application/x-www-form-urlencoded", + }, + body: new URLSearchParams({ + grant_type: grantType, + refresh_token: refreshToken, }), }); From 2b9a58182c8410cf81d523f4739fcf57da08afe1 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Tue, 11 Feb 2025 13:32:10 +0000 Subject: [PATCH 08/25] Add dynamic client registration --- src/client/auth.ts | 87 ++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 84 insertions(+), 3 deletions(-) diff --git a/src/client/auth.ts b/src/client/auth.ts index 01bf8471..ea68914b 100644 --- a/src/client/auth.ts +++ b/src/client/auth.ts @@ -42,9 +42,43 @@ export const OAuthTokensSchema = z }) .strip(); +/** + * Client metadata schema according to RFC 7591 OAuth 2.0 Dynamic Client Registration + */ +export const ClientMetadataSchema = z.object({ + redirect_uris: z.array(z.string()), + token_endpoint_auth_method: z.string().optional(), + grant_types: z.array(z.string()).optional(), + response_types: z.array(z.string()).optional(), + client_name: z.string().optional(), + client_uri: z.string().optional(), + logo_uri: z.string().optional(), + scope: z.string().optional(), + contacts: z.array(z.string()).optional(), + tos_uri: z.string().optional(), + policy_uri: z.string().optional(), + jwks_uri: z.string().optional(), + jwks: z.any().optional(), + software_id: z.string().optional(), + software_version: z.string().optional(), +}).passthrough(); + +/** + * Client information response schema according to RFC 7591 + */ +export const ClientInformationSchema = z.object({ + client_id: z.string(), + client_secret: z.string().optional(), + client_id_issued_at: z.number().optional(), + client_secret_expires_at: z.number().optional(), +}).merge(ClientMetadataSchema); + export type OAuthMetadata = z.infer; export type OAuthTokens = z.infer; +export type ClientMetadata = z.infer; +export type ClientInformation = z.infer; + /** * Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata. * @@ -77,7 +111,7 @@ export async function startAuthorization( { metadata, redirectUrl, - }: { metadata: OAuthMetadata; redirectUrl: string | URL }, + }: { metadata?: OAuthMetadata; redirectUrl: string | URL }, ): Promise<{ authorizationUrl: URL; codeVerifier: string }> { const responseType = "code"; const codeChallengeMethod = "S256"; @@ -130,7 +164,7 @@ export async function exchangeAuthorization( authorizationCode, codeVerifier, }: { - metadata: OAuthMetadata; + metadata?: OAuthMetadata; authorizationCode: string; codeVerifier: string; }, @@ -182,7 +216,7 @@ export async function refreshAuthorization( metadata, refreshToken, }: { - metadata: OAuthMetadata; + metadata?: OAuthMetadata; refreshToken: string; }, ): Promise { @@ -221,3 +255,50 @@ export async function refreshAuthorization( return OAuthTokensSchema.parse(await response.json()); } + +/** + * Performs OAuth 2.0 Dynamic Client Registration according to RFC 7591. + * + * @param serverUrl - The base URL of the authorization server + * @param options - Registration options + * @param options.metadata - OAuth server metadata containing the registration endpoint + * @param options.clientMetadata - Client metadata for registration + * @returns The registered client information + * @throws Error if the server doesn't support dynamic registration or if registration fails + */ +export async function registerClient( + serverUrl: string | URL, + { + metadata, + clientMetadata, + }: { + metadata?: OAuthMetadata; + clientMetadata: ClientMetadata; + }, +): Promise { + let registrationUrl: URL; + + if (metadata) { + if (!metadata.registration_endpoint) { + throw new Error("Incompatible auth server: does not support dynamic client registration"); + } + + registrationUrl = new URL(metadata.registration_endpoint); + } else { + registrationUrl = new URL("/register", serverUrl); + } + + const response = await fetch(registrationUrl, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify(clientMetadata), + }); + + if (!response.ok) { + throw new Error(`Dynamic client registration failed: HTTP ${response.status}`); + } + + return ClientInformationSchema.parse(await response.json()); +} From d17a382a65f3026f16375ca416690d8f624cdf82 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Tue, 11 Feb 2025 13:47:58 +0000 Subject: [PATCH 09/25] Auth tests and fixes --- src/client/auth.test.ts | 391 ++++++++++++++++++++++++++++++++++++++++ src/client/auth.ts | 10 +- 2 files changed, 396 insertions(+), 5 deletions(-) create mode 100644 src/client/auth.test.ts diff --git a/src/client/auth.test.ts b/src/client/auth.test.ts new file mode 100644 index 00000000..88a35082 --- /dev/null +++ b/src/client/auth.test.ts @@ -0,0 +1,391 @@ +import { + discoverOAuthMetadata, + startAuthorization, + exchangeAuthorization, + refreshAuthorization, + registerClient, +} from "./auth"; + +// Mock pkce-challenge +jest.mock("pkce-challenge", () => ({ + __esModule: true, + default: () => ({ + code_verifier: "test_verifier", + code_challenge: "test_challenge", + }), +})); + +// Mock fetch globally +const mockFetch = jest.fn(); +global.fetch = mockFetch; + +describe("OAuth Authorization", () => { + beforeEach(() => { + mockFetch.mockReset(); + }); + + describe("discoverOAuthMetadata", () => { + const validMetadata = { + issuer: "https://auth.example.com", + authorization_endpoint: "https://auth.example.com/authorize", + token_endpoint: "https://auth.example.com/token", + registration_endpoint: "https://auth.example.com/register", + response_types_supported: ["code"], + code_challenge_methods_supported: ["S256"], + }; + + it("returns metadata when discovery succeeds", async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validMetadata, + }); + + const metadata = await discoverOAuthMetadata("https://auth.example.com"); + expect(metadata).toEqual(validMetadata); + expect(mockFetch).toHaveBeenCalledWith( + expect.objectContaining({ + href: "https://auth.example.com/.well-known/oauth-authorization-server", + }) + ); + }); + + it("returns undefined when discovery endpoint returns 404", async () => { + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 404, + }); + + const metadata = await discoverOAuthMetadata("https://auth.example.com"); + expect(metadata).toBeUndefined(); + }); + + it("throws on non-404 errors", async () => { + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 500, + }); + + await expect( + discoverOAuthMetadata("https://auth.example.com") + ).rejects.toThrow("HTTP 500"); + }); + + it("validates metadata schema", async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => ({ + // Missing required fields + issuer: "https://auth.example.com", + }), + }); + + await expect( + discoverOAuthMetadata("https://auth.example.com") + ).rejects.toThrow(); + }); + }); + + describe("startAuthorization", () => { + const validMetadata = { + issuer: "https://auth.example.com", + authorization_endpoint: "https://auth.example.com/authorize", + token_endpoint: "https://auth.example.com/token", + response_types_supported: ["code"], + code_challenge_methods_supported: ["S256"], + }; + + it("generates authorization URL with PKCE challenge", async () => { + const { authorizationUrl, codeVerifier } = await startAuthorization( + "https://auth.example.com", + { + redirectUrl: "http://localhost:3000/callback", + } + ); + + expect(authorizationUrl.toString()).toMatch( + /^https:\/\/auth\.example\.com\/authorize\?/ + ); + expect(authorizationUrl.searchParams.get("response_type")).toBe("code"); + expect(authorizationUrl.searchParams.get("code_challenge")).toBe("test_challenge"); + expect(authorizationUrl.searchParams.get("code_challenge_method")).toBe( + "S256" + ); + expect(authorizationUrl.searchParams.get("redirect_uri")).toBe( + "http://localhost:3000/callback" + ); + expect(codeVerifier).toBe("test_verifier"); + }); + + it("uses metadata authorization_endpoint when provided", async () => { + const { authorizationUrl } = await startAuthorization( + "https://auth.example.com", + { + metadata: validMetadata, + redirectUrl: "http://localhost:3000/callback", + } + ); + + expect(authorizationUrl.toString()).toMatch( + /^https:\/\/auth\.example\.com\/authorize\?/ + ); + }); + + it("validates response type support", async () => { + const metadata = { + ...validMetadata, + response_types_supported: ["token"], // Does not support 'code' + }; + + await expect( + startAuthorization("https://auth.example.com", { + metadata, + redirectUrl: "http://localhost:3000/callback", + }) + ).rejects.toThrow(/does not support response type/); + }); + + it("validates PKCE support", async () => { + const metadata = { + ...validMetadata, + response_types_supported: ["code"], + code_challenge_methods_supported: ["plain"], // Does not support 'S256' + }; + + await expect( + startAuthorization("https://auth.example.com", { + metadata, + redirectUrl: "http://localhost:3000/callback", + }) + ).rejects.toThrow(/does not support code challenge method/); + }); + }); + + describe("exchangeAuthorization", () => { + const validTokens = { + access_token: "access123", + token_type: "Bearer", + expires_in: 3600, + refresh_token: "refresh123", + }; + + it("exchanges code for tokens", async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validTokens, + }); + + const tokens = await exchangeAuthorization("https://auth.example.com", { + authorizationCode: "code123", + codeVerifier: "verifier123", + }); + + expect(tokens).toEqual(validTokens); + expect(mockFetch).toHaveBeenCalledWith( + expect.objectContaining({ + href: "https://auth.example.com/token", + }), + expect.objectContaining({ + method: "POST", + headers: { + "Content-Type": "application/x-www-form-urlencoded", + }, + }) + ); + + const body = mockFetch.mock.calls[0][1].body as URLSearchParams; + expect(body.get("grant_type")).toBe("authorization_code"); + expect(body.get("code")).toBe("code123"); + expect(body.get("code_verifier")).toBe("verifier123"); + }); + + it("validates token response schema", async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => ({ + // Missing required fields + access_token: "access123", + }), + }); + + await expect( + exchangeAuthorization("https://auth.example.com", { + authorizationCode: "code123", + codeVerifier: "verifier123", + }) + ).rejects.toThrow(); + }); + + it("throws on error response", async () => { + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 400, + }); + + await expect( + exchangeAuthorization("https://auth.example.com", { + authorizationCode: "code123", + codeVerifier: "verifier123", + }) + ).rejects.toThrow("Token exchange failed"); + }); + }); + + describe("refreshAuthorization", () => { + const validTokens = { + access_token: "newaccess123", + token_type: "Bearer", + expires_in: 3600, + refresh_token: "newrefresh123", + }; + + it("exchanges refresh token for new tokens", async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validTokens, + }); + + const tokens = await refreshAuthorization("https://auth.example.com", { + refreshToken: "refresh123", + }); + + expect(tokens).toEqual(validTokens); + expect(mockFetch).toHaveBeenCalledWith( + expect.objectContaining({ + href: "https://auth.example.com/token", + }), + expect.objectContaining({ + method: "POST", + headers: { + "Content-Type": "application/x-www-form-urlencoded", + }, + }) + ); + + const body = mockFetch.mock.calls[0][1].body as URLSearchParams; + expect(body.get("grant_type")).toBe("refresh_token"); + expect(body.get("refresh_token")).toBe("refresh123"); + }); + + it("validates token response schema", async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => ({ + // Missing required fields + access_token: "newaccess123", + }), + }); + + await expect( + refreshAuthorization("https://auth.example.com", { + refreshToken: "refresh123", + }) + ).rejects.toThrow(); + }); + + it("throws on error response", async () => { + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 400, + }); + + await expect( + refreshAuthorization("https://auth.example.com", { + refreshToken: "refresh123", + }) + ).rejects.toThrow("Token exchange failed"); + }); + }); + + describe("registerClient", () => { + const validClientMetadata = { + redirect_uris: ["http://localhost:3000/callback"], + client_name: "Test Client", + }; + + const validClientInfo = { + client_id: "client123", + client_secret: "secret123", + client_id_issued_at: 1612137600, + client_secret_expires_at: 1612224000, + ...validClientMetadata, + }; + + it("registers client and returns client information", async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validClientInfo, + }); + + const clientInfo = await registerClient("https://auth.example.com", { + clientMetadata: validClientMetadata, + }); + + expect(clientInfo).toEqual(validClientInfo); + expect(mockFetch).toHaveBeenCalledWith( + expect.objectContaining({ + href: "https://auth.example.com/register", + }), + expect.objectContaining({ + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify(validClientMetadata), + }) + ); + }); + + it("validates client information response schema", async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => ({ + // Missing required fields + client_secret: "secret123", + }), + }); + + await expect( + registerClient("https://auth.example.com", { + clientMetadata: validClientMetadata, + }) + ).rejects.toThrow(); + }); + + it("throws when registration endpoint not available in metadata", async () => { + const metadata = { + issuer: "https://auth.example.com", + authorization_endpoint: "https://auth.example.com/authorize", + token_endpoint: "https://auth.example.com/token", + response_types_supported: ["code"], + }; + + await expect( + registerClient("https://auth.example.com", { + metadata, + clientMetadata: validClientMetadata, + }) + ).rejects.toThrow(/does not support dynamic client registration/); + }); + + it("throws on error response", async () => { + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 400, + }); + + await expect( + registerClient("https://auth.example.com", { + clientMetadata: validClientMetadata, + }) + ).rejects.toThrow("Dynamic client registration failed"); + }); + }); +}); \ No newline at end of file diff --git a/src/client/auth.ts b/src/client/auth.ts index ea68914b..78c6545a 100644 --- a/src/client/auth.ts +++ b/src/client/auth.ts @@ -120,7 +120,7 @@ export async function startAuthorization( if (metadata) { authorizationUrl = new URL(metadata.authorization_endpoint); - if (!(responseType in metadata.response_types_supported)) { + if (!metadata.response_types_supported.includes(responseType)) { throw new Error( `Incompatible auth server: does not support response type ${responseType}`, ); @@ -128,7 +128,7 @@ export async function startAuthorization( if ( !metadata.code_challenge_methods_supported || - !(codeChallengeMethod in metadata.code_challenge_methods_supported) + !metadata.code_challenge_methods_supported.includes(codeChallengeMethod) ) { throw new Error( `Incompatible auth server: does not support code challenge method ${codeChallengeMethod}`, @@ -177,7 +177,7 @@ export async function exchangeAuthorization( if ( metadata.grant_types_supported && - !(grantType in metadata.grant_types_supported) + !metadata.grant_types_supported.includes(grantType) ) { throw new Error( `Incompatible auth server: does not support grant type ${grantType}`, @@ -228,7 +228,7 @@ export async function refreshAuthorization( if ( metadata.grant_types_supported && - !(grantType in metadata.grant_types_supported) + !metadata.grant_types_supported.includes(grantType) ) { throw new Error( `Incompatible auth server: does not support grant type ${grantType}`, @@ -301,4 +301,4 @@ export async function registerClient( } return ClientInformationSchema.parse(await response.json()); -} +} \ No newline at end of file From 4c87fb50abfb8f914f5ccea9d7170cc9490c67ef Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Tue, 11 Feb 2025 14:17:02 +0000 Subject: [PATCH 10/25] Prefix client metadata + info with `OAuth` --- src/client/auth.ts | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/client/auth.ts b/src/client/auth.ts index 78c6545a..57fed8f2 100644 --- a/src/client/auth.ts +++ b/src/client/auth.ts @@ -45,7 +45,7 @@ export const OAuthTokensSchema = z /** * Client metadata schema according to RFC 7591 OAuth 2.0 Dynamic Client Registration */ -export const ClientMetadataSchema = z.object({ +export const OAuthClientMetadataSchema = z.object({ redirect_uris: z.array(z.string()), token_endpoint_auth_method: z.string().optional(), grant_types: z.array(z.string()).optional(), @@ -66,18 +66,18 @@ export const ClientMetadataSchema = z.object({ /** * Client information response schema according to RFC 7591 */ -export const ClientInformationSchema = z.object({ +export const OAuthClientInformationSchema = z.object({ client_id: z.string(), client_secret: z.string().optional(), client_id_issued_at: z.number().optional(), client_secret_expires_at: z.number().optional(), -}).merge(ClientMetadataSchema); +}).merge(OAuthClientMetadataSchema); export type OAuthMetadata = z.infer; export type OAuthTokens = z.infer; -export type ClientMetadata = z.infer; -export type ClientInformation = z.infer; +export type OAuthClientMetadata = z.infer; +export type OAuthClientInformation = z.infer; /** * Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata. @@ -273,9 +273,9 @@ export async function registerClient( clientMetadata, }: { metadata?: OAuthMetadata; - clientMetadata: ClientMetadata; + clientMetadata: OAuthClientMetadata; }, -): Promise { +): Promise { let registrationUrl: URL; if (metadata) { @@ -300,5 +300,5 @@ export async function registerClient( throw new Error(`Dynamic client registration failed: HTTP ${response.status}`); } - return ClientInformationSchema.parse(await response.json()); + return OAuthClientInformationSchema.parse(await response.json()); } \ No newline at end of file From fa50204a11afc3b5d18bcf35866a7edde0c74251 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Tue, 11 Feb 2025 14:17:10 +0000 Subject: [PATCH 11/25] Fix import of auth.js in test --- src/client/auth.test.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/client/auth.test.ts b/src/client/auth.test.ts index 88a35082..39be8216 100644 --- a/src/client/auth.test.ts +++ b/src/client/auth.test.ts @@ -4,7 +4,7 @@ import { exchangeAuthorization, refreshAuthorization, registerClient, -} from "./auth"; +} from "./auth.js"; // Mock pkce-challenge jest.mock("pkce-challenge", () => ({ From 47fd6df229f1658a565f6c7b838f2d29458639f6 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Tue, 11 Feb 2025 15:13:17 +0000 Subject: [PATCH 12/25] Higher-level auth flow --- src/client/auth.ts | 150 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 150 insertions(+) diff --git a/src/client/auth.ts b/src/client/auth.ts index 57fed8f2..b5b043f6 100644 --- a/src/client/auth.ts +++ b/src/client/auth.ts @@ -79,6 +79,156 @@ export type OAuthTokens = z.infer; export type OAuthClientMetadata = z.infer; export type OAuthClientInformation = z.infer; +/** + * Implements an end-to-end OAuth client to be used with one MCP server. + * + * This client relies upon a concept of an authorized "session," the exact + * meaning of which is application-defined. Tokens, authorization codes, and + * code verifiers should not cross different sessions. + */ +export interface OAuthClientProvider { + /** + * The URL to redirect the user agent to after authorization. + * + * If the client is not redirecting to localhost, `clientInformation` must be + * implemented. + */ + get redirectUrl(): string | URL; + + /** + * Metadata about this OAuth client. + */ + get clientMetadata(): OAuthClientMetadata; + + /** + * Loads information about this OAuth client, as registered already with the + * server, or returns `undefined` if the client is not registered with the + * server. + * + * This method must be implemented _unless_ redirecting to `localhost`. + */ + clientInformation?(): OAuthClientInformation | undefined | Promise; + + /** + * If implemented, this permits the OAuth client to dynamically register with + * the server. Client information saved this way should later be read via + * `clientInformation()`. + * + * This method is not required to be implemented if redirecting to + * `localhost`, or if client information is statically known (e.g., + * pre-registered). + */ + saveClientInformation?(clientInformation: OAuthClientInformation): void | Promise; + + /** + * Loads any existing OAuth tokens for the current session, or returns + * `undefined` if there are no saved tokens. + */ + tokens(): OAuthTokens | undefined | Promise; + + /** + * Stores new OAuth tokens for the current session, after a successful + * authorization. + */ + saveTokens(tokens: OAuthTokens): void | Promise; + + /** + * Invoked to redirect the user agent to the given URL to begin the authorization flow. + */ + redirectToAuthorization(authorizationUrl: URL): void | Promise; + + /** + * Saves a PKCE code verifier for the current session, before redirecting to + * the authorization flow. + */ + saveCodeVerifier(codeVerifier: string): void | Promise; + + /** + * Loads the PKCE code verifier for the current session, necessary to validate + * the authorization result. + */ + codeVerifier(): string | Promise; +} + +export type AuthResult = "AUTHORIZED" | "REDIRECT"; + +/** + * Orchestrates the full auth flow with a server. + * + * This can be used as a single entry point for all authorization functionality, + * instead of linking together the other lower-level functions in this module. + */ +export async function auth( + provider: OAuthClientProvider, + { serverUrl, authorizationCode }: { serverUrl: string | URL, authorizationCode?: string }): Promise { + const metadata = await discoverOAuthMetadata(serverUrl); + + // Handle client registration if needed + const hostname = new URL(provider.redirectUrl).hostname; + if (hostname !== "localhost" && hostname !== "127.0.0.1") { + if (!provider.clientInformation) { + throw new Error("OAuth client information is required when not redirecting to localhost") + } + + let clientInformation = await Promise.resolve(provider.clientInformation()); + if (!clientInformation) { + if (authorizationCode !== undefined) { + throw new Error("Existing OAuth client information is required when exchanging an authorization code"); + } + + if (!provider.saveClientInformation) { + throw new Error("OAuth client information must be saveable when not provided and not redirecting to localhost"); + } + + clientInformation = await registerClient(serverUrl, { + metadata, + clientMetadata: provider.clientMetadata, + }); + + await provider.saveClientInformation(clientInformation); + } + + // TODO: Send clientInformation into auth flow + } + + // Exchange authorization code for tokens + if (authorizationCode !== undefined) { + const codeVerifier = await provider.codeVerifier(); + const tokens = await exchangeAuthorization(serverUrl, { + metadata, + authorizationCode, + codeVerifier, + }); + + await provider.saveTokens(tokens); + return "AUTHORIZED"; + } + + const tokens = await provider.tokens(); + + // Handle token refresh or new authorization + if (tokens?.refresh_token) { + try { + // Attempt to refresh the token + const newTokens = await refreshAuthorization(serverUrl, { + metadata, + refreshToken: tokens.refresh_token, + }); + + await provider.saveTokens(newTokens); + return "AUTHORIZED"; + } catch (error) { + console.error("Could not refresh OAuth tokens:", error); + } + } + + // Start new authorization flow + const { authorizationUrl, codeVerifier } = await startAuthorization(serverUrl, { metadata, redirectUrl: provider.redirectUrl }); + await provider.saveCodeVerifier(codeVerifier); + await provider.redirectToAuthorization(authorizationUrl); + return "REDIRECT"; +} + /** * Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata. * From 9fea0af99800647c7e9d0e7fca3fc42469cc8258 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Tue, 11 Feb 2025 15:59:16 +0000 Subject: [PATCH 13/25] Correctly provide client ID (+ secret) to every auth step --- src/client/auth.test.ts | 37 ++++++++++++++- src/client/auth.ts | 101 +++++++++++++++++++++++----------------- 2 files changed, 95 insertions(+), 43 deletions(-) diff --git a/src/client/auth.test.ts b/src/client/auth.test.ts index 39be8216..7332ae36 100644 --- a/src/client/auth.test.ts +++ b/src/client/auth.test.ts @@ -96,10 +96,18 @@ describe("OAuth Authorization", () => { code_challenge_methods_supported: ["S256"], }; + const validClientInfo = { + client_id: "client123", + client_secret: "secret123", + redirect_uris: ["http://localhost:3000/callback"], + client_name: "Test Client", + }; + it("generates authorization URL with PKCE challenge", async () => { const { authorizationUrl, codeVerifier } = await startAuthorization( "https://auth.example.com", { + clientInformation: validClientInfo, redirectUrl: "http://localhost:3000/callback", } ); @@ -123,6 +131,7 @@ describe("OAuth Authorization", () => { "https://auth.example.com", { metadata: validMetadata, + clientInformation: validClientInfo, redirectUrl: "http://localhost:3000/callback", } ); @@ -141,6 +150,7 @@ describe("OAuth Authorization", () => { await expect( startAuthorization("https://auth.example.com", { metadata, + clientInformation: validClientInfo, redirectUrl: "http://localhost:3000/callback", }) ).rejects.toThrow(/does not support response type/); @@ -156,6 +166,7 @@ describe("OAuth Authorization", () => { await expect( startAuthorization("https://auth.example.com", { metadata, + clientInformation: validClientInfo, redirectUrl: "http://localhost:3000/callback", }) ).rejects.toThrow(/does not support code challenge method/); @@ -170,6 +181,13 @@ describe("OAuth Authorization", () => { refresh_token: "refresh123", }; + const validClientInfo = { + client_id: "client123", + client_secret: "secret123", + redirect_uris: ["http://localhost:3000/callback"], + client_name: "Test Client", + }; + it("exchanges code for tokens", async () => { mockFetch.mockResolvedValueOnce({ ok: true, @@ -178,6 +196,7 @@ describe("OAuth Authorization", () => { }); const tokens = await exchangeAuthorization("https://auth.example.com", { + clientInformation: validClientInfo, authorizationCode: "code123", codeVerifier: "verifier123", }); @@ -199,6 +218,8 @@ describe("OAuth Authorization", () => { expect(body.get("grant_type")).toBe("authorization_code"); expect(body.get("code")).toBe("code123"); expect(body.get("code_verifier")).toBe("verifier123"); + expect(body.get("client_id")).toBe("client123"); + expect(body.get("client_secret")).toBe("secret123"); }); it("validates token response schema", async () => { @@ -213,6 +234,7 @@ describe("OAuth Authorization", () => { await expect( exchangeAuthorization("https://auth.example.com", { + clientInformation: validClientInfo, authorizationCode: "code123", codeVerifier: "verifier123", }) @@ -227,6 +249,7 @@ describe("OAuth Authorization", () => { await expect( exchangeAuthorization("https://auth.example.com", { + clientInformation: validClientInfo, authorizationCode: "code123", codeVerifier: "verifier123", }) @@ -242,6 +265,13 @@ describe("OAuth Authorization", () => { refresh_token: "newrefresh123", }; + const validClientInfo = { + client_id: "client123", + client_secret: "secret123", + redirect_uris: ["http://localhost:3000/callback"], + client_name: "Test Client", + }; + it("exchanges refresh token for new tokens", async () => { mockFetch.mockResolvedValueOnce({ ok: true, @@ -250,6 +280,7 @@ describe("OAuth Authorization", () => { }); const tokens = await refreshAuthorization("https://auth.example.com", { + clientInformation: validClientInfo, refreshToken: "refresh123", }); @@ -269,6 +300,8 @@ describe("OAuth Authorization", () => { const body = mockFetch.mock.calls[0][1].body as URLSearchParams; expect(body.get("grant_type")).toBe("refresh_token"); expect(body.get("refresh_token")).toBe("refresh123"); + expect(body.get("client_id")).toBe("client123"); + expect(body.get("client_secret")).toBe("secret123"); }); it("validates token response schema", async () => { @@ -283,6 +316,7 @@ describe("OAuth Authorization", () => { await expect( refreshAuthorization("https://auth.example.com", { + clientInformation: validClientInfo, refreshToken: "refresh123", }) ).rejects.toThrow(); @@ -296,9 +330,10 @@ describe("OAuth Authorization", () => { await expect( refreshAuthorization("https://auth.example.com", { + clientInformation: validClientInfo, refreshToken: "refresh123", }) - ).rejects.toThrow("Token exchange failed"); + ).rejects.toThrow("Token refresh failed"); }); }); diff --git a/src/client/auth.ts b/src/client/auth.ts index b5b043f6..6fef1439 100644 --- a/src/client/auth.ts +++ b/src/client/auth.ts @@ -89,9 +89,6 @@ export type OAuthClientInformation = z.infer; + clientInformation(): OAuthClientInformation | undefined | Promise; /** * If implemented, this permits the OAuth client to dynamically register with * the server. Client information saved this way should later be read via * `clientInformation()`. * - * This method is not required to be implemented if redirecting to - * `localhost`, or if client information is statically known (e.g., - * pre-registered). + * This method is not required to be implemented if client information is + * statically known (e.g., pre-registered). */ saveClientInformation?(clientInformation: OAuthClientInformation): void | Promise; @@ -164,31 +158,22 @@ export async function auth( const metadata = await discoverOAuthMetadata(serverUrl); // Handle client registration if needed - const hostname = new URL(provider.redirectUrl).hostname; - if (hostname !== "localhost" && hostname !== "127.0.0.1") { - if (!provider.clientInformation) { - throw new Error("OAuth client information is required when not redirecting to localhost") + let clientInformation = await Promise.resolve(provider.clientInformation()); + if (!clientInformation) { + if (authorizationCode !== undefined) { + throw new Error("Existing OAuth client information is required when exchanging an authorization code"); } - let clientInformation = await Promise.resolve(provider.clientInformation()); - if (!clientInformation) { - if (authorizationCode !== undefined) { - throw new Error("Existing OAuth client information is required when exchanging an authorization code"); - } - - if (!provider.saveClientInformation) { - throw new Error("OAuth client information must be saveable when not provided and not redirecting to localhost"); - } - - clientInformation = await registerClient(serverUrl, { - metadata, - clientMetadata: provider.clientMetadata, - }); - - await provider.saveClientInformation(clientInformation); + if (!provider.saveClientInformation) { + throw new Error("OAuth client information must be saveable for dynamic registration"); } - // TODO: Send clientInformation into auth flow + clientInformation = await registerClient(serverUrl, { + metadata, + clientMetadata: provider.clientMetadata, + }); + + await provider.saveClientInformation(clientInformation); } // Exchange authorization code for tokens @@ -196,6 +181,7 @@ export async function auth( const codeVerifier = await provider.codeVerifier(); const tokens = await exchangeAuthorization(serverUrl, { metadata, + clientInformation, authorizationCode, codeVerifier, }); @@ -212,6 +198,7 @@ export async function auth( // Attempt to refresh the token const newTokens = await refreshAuthorization(serverUrl, { metadata, + clientInformation, refreshToken: tokens.refresh_token, }); @@ -223,7 +210,12 @@ export async function auth( } // Start new authorization flow - const { authorizationUrl, codeVerifier } = await startAuthorization(serverUrl, { metadata, redirectUrl: provider.redirectUrl }); + const { authorizationUrl, codeVerifier } = await startAuthorization(serverUrl, { + metadata, + clientInformation, + redirectUrl: provider.redirectUrl + }); + await provider.saveCodeVerifier(codeVerifier); await provider.redirectToAuthorization(authorizationUrl); return "REDIRECT"; @@ -260,8 +252,13 @@ export async function startAuthorization( serverUrl: string | URL, { metadata, + clientInformation, redirectUrl, - }: { metadata?: OAuthMetadata; redirectUrl: string | URL }, + }: { + metadata?: OAuthMetadata; + clientInformation: OAuthClientInformation; + redirectUrl: string | URL; + }, ): Promise<{ authorizationUrl: URL; codeVerifier: string }> { const responseType = "code"; const codeChallengeMethod = "S256"; @@ -294,6 +291,7 @@ export async function startAuthorization( const codeChallenge = challenge.code_challenge; authorizationUrl.searchParams.set("response_type", responseType); + authorizationUrl.searchParams.set("client_id", clientInformation.client_id); authorizationUrl.searchParams.set("code_challenge", codeChallenge); authorizationUrl.searchParams.set( "code_challenge_method", @@ -311,10 +309,12 @@ export async function exchangeAuthorization( serverUrl: string | URL, { metadata, + clientInformation, authorizationCode, codeVerifier, }: { metadata?: OAuthMetadata; + clientInformation: OAuthClientInformation; authorizationCode: string; codeVerifier: string; }, @@ -338,16 +338,23 @@ export async function exchangeAuthorization( } // Exchange code for tokens + const params = new URLSearchParams({ + grant_type: grantType, + client_id: clientInformation.client_id, + code: authorizationCode, + code_verifier: codeVerifier, + }); + + if (clientInformation.client_secret) { + params.set("client_secret", clientInformation.client_secret); + } + const response = await fetch(tokenUrl, { method: "POST", headers: { "Content-Type": "application/x-www-form-urlencoded", }, - body: new URLSearchParams({ - grant_type: grantType, - code: authorizationCode, - code_verifier: codeVerifier, - }), + body: params, }); if (!response.ok) { @@ -364,9 +371,11 @@ export async function refreshAuthorization( serverUrl: string | URL, { metadata, + clientInformation, refreshToken, }: { metadata?: OAuthMetadata; + clientInformation: OAuthClientInformation; refreshToken: string; }, ): Promise { @@ -388,19 +397,27 @@ export async function refreshAuthorization( tokenUrl = new URL("/token", serverUrl); } + // Exchange refresh token + const params = new URLSearchParams({ + grant_type: grantType, + client_id: clientInformation.client_id, + refresh_token: refreshToken, + }); + + if (clientInformation.client_secret) { + params.set("client_secret", clientInformation.client_secret); + } + const response = await fetch(tokenUrl, { method: "POST", headers: { "Content-Type": "application/x-www-form-urlencoded", }, - body: new URLSearchParams({ - grant_type: grantType, - refresh_token: refreshToken, - }), + body: params, }); if (!response.ok) { - throw new Error(`Token exchange failed: HTTP ${response.status}`); + throw new Error(`Token refresh failed: HTTP ${response.status}`); } return OAuthTokensSchema.parse(await response.json()); From 8bff256cb289898b3f20ec120f902f1f6439c64c Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Tue, 11 Feb 2025 16:02:16 +0000 Subject: [PATCH 14/25] Preemptively bump package version --- package-lock.json | 4 ++-- package.json | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/package-lock.json b/package-lock.json index f09bdc2c..e4bbd079 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "@modelcontextprotocol/sdk", - "version": "1.4.1", + "version": "1.5.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@modelcontextprotocol/sdk", - "version": "1.4.1", + "version": "1.5.0", "license": "MIT", "dependencies": { "content-type": "^1.0.5", diff --git a/package.json b/package.json index 66d28112..9558905d 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@modelcontextprotocol/sdk", - "version": "1.4.1", + "version": "1.5.0", "description": "Model Context Protocol implementation for TypeScript", "license": "MIT", "author": "Anthropic, PBC (https://anthropic.com)", @@ -74,4 +74,4 @@ "resolutions": { "strip-ansi": "6.0.1" } -} +} \ No newline at end of file From 141f3f5532120d1f479d640d777d42d9e87a3e9c Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Tue, 11 Feb 2025 16:54:57 +0000 Subject: [PATCH 15/25] Don't need full OAuth client metadata in response --- src/client/auth.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/client/auth.ts b/src/client/auth.ts index 6fef1439..8b6eed7c 100644 --- a/src/client/auth.ts +++ b/src/client/auth.ts @@ -71,7 +71,7 @@ export const OAuthClientInformationSchema = z.object({ client_secret: z.string().optional(), client_id_issued_at: z.number().optional(), client_secret_expires_at: z.number().optional(), -}).merge(OAuthClientMetadataSchema); +}).passthrough(); export type OAuthMetadata = z.infer; export type OAuthTokens = z.infer; From fceebdcf5a49837187b8436ccc323d1da4067287 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Tue, 11 Feb 2025 17:58:58 +0000 Subject: [PATCH 16/25] Add `MCP-Protocol-Version` header --- src/client/auth.ts | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/client/auth.ts b/src/client/auth.ts index 8b6eed7c..2c11ea8a 100644 --- a/src/client/auth.ts +++ b/src/client/auth.ts @@ -1,5 +1,6 @@ import pkceChallenge from "pkce-challenge"; import { z } from "zod"; +import { LATEST_PROTOCOL_VERSION } from "../types.js"; export const OAuthMetadataSchema = z .object({ @@ -229,9 +230,15 @@ export async function auth( */ export async function discoverOAuthMetadata( serverUrl: string | URL, + opts?: { protocolVersion?: string }, ): Promise { const url = new URL("/.well-known/oauth-authorization-server", serverUrl); - const response = await fetch(url); + const response = await fetch(url, { + headers: { + "MCP-Protocol-Version": opts?.protocolVersion ?? LATEST_PROTOCOL_VERSION + } + }); + if (response.status === 404) { return undefined; } From 03e583b70f685fc5c0751554f87b5441f7690c96 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Tue, 11 Feb 2025 20:35:36 +0000 Subject: [PATCH 17/25] Add support for AuthProvider in SSEClientTransport --- src/client/sse.ts | 91 ++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 83 insertions(+), 8 deletions(-) diff --git a/src/client/sse.ts b/src/client/sse.ts index 14921f57..173da8ab 100644 --- a/src/client/sse.ts +++ b/src/client/sse.ts @@ -1,6 +1,7 @@ import { EventSource, type ErrorEvent, type EventSourceInit } from "eventsource"; import { Transport } from "../shared/transport.js"; import { JSONRPCMessage, JSONRPCMessageSchema } from "../types.js"; +import { auth, AuthResult, OAuthClientProvider } from "./auth.js"; export class SseError extends Error { constructor( @@ -12,6 +13,34 @@ export class SseError extends Error { } } +/** + * Configuration options for the `SSEClientTransport`. + */ +export type SSEClientTransportOptions = { + /** + * An OAuth client provider to use for authentication. + * + * If given, the transport will automatically attach an `Authorization` header + * if an access token is available, or begin the authorization flow if not. + */ + authProvider?: OAuthClientProvider; + + /** + * Customizes the initial SSE request to the server (the request that begins the stream). + * + * NOTE: Setting this property will prevent an `Authorization` header from + * being automatically attached to the SSE request, if an `authProvider` is + * also given. This can be worked around by setting the `Authorization` header + * manually. + */ + eventSourceInit?: EventSourceInit; + + /** + * Customizes recurring POST requests to the server. + */ + requestInit?: RequestInit; +}; + /** * Client transport for SSE: this will connect to a server using Server-Sent Events for receiving * messages and make separate POST requests for sending messages. @@ -23,6 +52,7 @@ export class SSEClientTransport implements Transport { private _url: URL; private _eventSourceInit?: EventSourceInit; private _requestInit?: RequestInit; + private _authProvider?: OAuthClientProvider; onclose?: () => void; onerror?: (error: Error) => void; @@ -30,28 +60,62 @@ export class SSEClientTransport implements Transport { constructor( url: URL, - opts?: { eventSourceInit?: EventSourceInit; requestInit?: RequestInit }, + opts?: SSEClientTransportOptions, ) { this._url = url; this._eventSourceInit = opts?.eventSourceInit; this._requestInit = opts?.requestInit; + this._authProvider = opts?.authProvider; } - start(): Promise { - if (this._eventSource) { - throw new Error( - "SSEClientTransport already started! If using Client class, note that connect() calls start() automatically.", - ); + private async _authThenStart(): Promise { + if (!this._authProvider) { + throw new Error("No auth provider"); } + let result: AuthResult; + try { + result = await auth(this._authProvider, { serverUrl: this._url }); + } catch (error) { + this.onerror?.(error as Error); + throw error; + } + + if (result !== "AUTHORIZED") { + throw new Error("Unauthorized"); + } + + return await this._startOrAuth(); + } + + private async _commonHeaders(): Promise { + const headers: HeadersInit = {}; + if (this._authProvider) { + const tokens = await this._authProvider.tokens(); + if (tokens) { + headers["Authorization"] = `Bearer ${tokens.access_token}`; + } + } + + return headers; + } + + private _startOrAuth(): Promise { return new Promise((resolve, reject) => { this._eventSource = new EventSource( this._url.href, - this._eventSourceInit, + this._eventSourceInit ?? { + fetch: (url, init) => this._commonHeaders().then((headers) => fetch(url, { ...init, headers })), + }, ); this._abortController = new AbortController(); this._eventSource.onerror = (event) => { + if (event.code === 401 && this._authProvider) { + this._authThenStart().then(resolve, reject); + return; + } + const error = new SseError(event.code, event.message, event); reject(error); this.onerror?.(error); @@ -97,6 +161,16 @@ export class SSEClientTransport implements Transport { }); } + async start() { + if (this._eventSource) { + throw new Error( + "SSEClientTransport already started! If using Client class, note that connect() calls start() automatically.", + ); + } + + return await this._startOrAuth(); + } + async close(): Promise { this._abortController?.abort(); this._eventSource?.close(); @@ -109,7 +183,8 @@ export class SSEClientTransport implements Transport { } try { - const headers = new Headers(this._requestInit?.headers); + const commonHeaders = await this._commonHeaders(); + const headers = new Headers({ ...commonHeaders, ...this._requestInit?.headers }); headers.set("content-type", "application/json"); const init = { ...this._requestInit, From 92621516f8d15689d109f8ac127ecc8037583caf Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Tue, 11 Feb 2025 21:37:49 +0000 Subject: [PATCH 18/25] Re-auth upon 401 when POSTing --- src/client/sse.ts | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/client/sse.ts b/src/client/sse.ts index 173da8ab..e8de3e38 100644 --- a/src/client/sse.ts +++ b/src/client/sse.ts @@ -195,8 +195,17 @@ export class SSEClientTransport implements Transport { }; const response = await fetch(this._endpoint, init); - if (!response.ok) { + if (response.status === 401 && this._authProvider) { + const result = await auth(this._authProvider, { serverUrl: this._url }); + if (result !== "AUTHORIZED") { + throw new Error("Unauthorized"); + } + + // Purposely _not_ awaited, so we don't call onerror twice + return this.send(message); + } + const text = await response.text().catch(() => null); throw new Error( `Error POSTing to endpoint (HTTP ${response.status}): ${text}`, From 3e2dd35dcab440702dc9a94d9c32d42c8b2987fa Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Tue, 11 Feb 2025 22:02:20 +0000 Subject: [PATCH 19/25] Fix tests --- jest.config.js | 4 ++++ src/__mocks__/pkce-challenge.ts | 6 ++++++ src/client/auth.test.ts | 20 +++++++------------- src/client/sse.ts | 8 +++++++- tsconfig.cjs.json | 2 +- tsconfig.prod.json | 2 +- 6 files changed, 26 insertions(+), 16 deletions(-) create mode 100644 src/__mocks__/pkce-challenge.ts diff --git a/jest.config.js b/jest.config.js index 17ab1fbe..f8f621c8 100644 --- a/jest.config.js +++ b/jest.config.js @@ -7,6 +7,10 @@ export default { ...defaultEsmPreset, moduleNameMapper: { "^(\\.{1,2}/.*)\\.js$": "$1", + "^pkce-challenge$": "/src/__mocks__/pkce-challenge.ts" }, + transformIgnorePatterns: [ + "/node_modules/(?!eventsource)/" + ], testPathIgnorePatterns: ["/node_modules/", "/dist/"], }; diff --git a/src/__mocks__/pkce-challenge.ts b/src/__mocks__/pkce-challenge.ts new file mode 100644 index 00000000..10e13054 --- /dev/null +++ b/src/__mocks__/pkce-challenge.ts @@ -0,0 +1,6 @@ +export default function pkceChallenge() { + return { + code_verifier: "test_verifier", + code_challenge: "test_challenge", + }; +} \ No newline at end of file diff --git a/src/client/auth.test.ts b/src/client/auth.test.ts index 7332ae36..6d60806f 100644 --- a/src/client/auth.test.ts +++ b/src/client/auth.test.ts @@ -6,14 +6,6 @@ import { registerClient, } from "./auth.js"; -// Mock pkce-challenge -jest.mock("pkce-challenge", () => ({ - __esModule: true, - default: () => ({ - code_verifier: "test_verifier", - code_challenge: "test_challenge", - }), -})); // Mock fetch globally const mockFetch = jest.fn(); @@ -43,11 +35,13 @@ describe("OAuth Authorization", () => { const metadata = await discoverOAuthMetadata("https://auth.example.com"); expect(metadata).toEqual(validMetadata); - expect(mockFetch).toHaveBeenCalledWith( - expect.objectContaining({ - href: "https://auth.example.com/.well-known/oauth-authorization-server", - }) - ); + const calls = mockFetch.mock.calls; + expect(calls.length).toBe(1); + const [url, options] = calls[0]; + expect(url.toString()).toBe("https://auth.example.com/.well-known/oauth-authorization-server"); + expect(options.headers).toEqual({ + "MCP-Protocol-Version": "2024-11-05" + }); }); it("returns undefined when discovery endpoint returns 404", async () => { diff --git a/src/client/sse.ts b/src/client/sse.ts index e8de3e38..59eef4d3 100644 --- a/src/client/sse.ts +++ b/src/client/sse.ts @@ -105,7 +105,13 @@ export class SSEClientTransport implements Transport { this._eventSource = new EventSource( this._url.href, this._eventSourceInit ?? { - fetch: (url, init) => this._commonHeaders().then((headers) => fetch(url, { ...init, headers })), + fetch: (url, init) => this._commonHeaders().then((headers) => fetch(url, { + ...init, + headers: { + ...headers, + Accept: "text/event-stream" + } + })), }, ); this._abortController = new AbortController(); diff --git a/tsconfig.cjs.json b/tsconfig.cjs.json index 058a5d9a..b2f344a8 100644 --- a/tsconfig.cjs.json +++ b/tsconfig.cjs.json @@ -5,5 +5,5 @@ "moduleResolution": "node", "outDir": "./dist/cjs" }, - "exclude": ["**/*.test.ts"] + "exclude": ["**/*.test.ts", "src/__mocks__/**/*"] } diff --git a/tsconfig.prod.json b/tsconfig.prod.json index 2c68666e..2302dd84 100644 --- a/tsconfig.prod.json +++ b/tsconfig.prod.json @@ -3,5 +3,5 @@ "compilerOptions": { "outDir": "./dist/esm" }, - "exclude": ["**/*.test.ts"] + "exclude": ["**/*.test.ts", "src/__mocks__/**/*"] } From d47d88be9e527832e68548a4484d9b039a044ec7 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Wed, 12 Feb 2025 13:47:59 +0000 Subject: [PATCH 20/25] Auth tests for SSEClientTransport --- src/client/sse.test.ts | 177 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 177 insertions(+) diff --git a/src/client/sse.test.ts b/src/client/sse.test.ts index f59c45fe..04193fff 100644 --- a/src/client/sse.test.ts +++ b/src/client/sse.test.ts @@ -2,6 +2,7 @@ import { createServer, type IncomingMessage, type Server } from "http"; import { AddressInfo } from "net"; import { JSONRPCMessage } from "../types.js"; import { SSEClientTransport } from "./sse.js"; +import { auth, OAuthClientProvider } from "./auth.js"; describe("SSEClientTransport", () => { let server: Server; @@ -284,4 +285,180 @@ describe("SSEClientTransport", () => { expect(calledHeaders.get("content-type")).toBe("application/json"); }); }); + + describe("auth handling", () => { + let mockAuthProvider: jest.Mocked; + + beforeEach(() => { + mockAuthProvider = { + get redirectUrl() { return "http://localhost/callback"; }, + get clientMetadata() { return { redirect_uris: ["http://localhost/callback"] }; }, + clientInformation: jest.fn(() => ({ client_id: "test-client-id" })), + tokens: jest.fn(), + saveTokens: jest.fn(), + redirectToAuthorization: jest.fn(), + saveCodeVerifier: jest.fn(), + codeVerifier: jest.fn(), + }; + }); + + it("attaches auth header from provider on SSE connection", async () => { + mockAuthProvider.tokens.mockResolvedValue({ + access_token: "test-token", + token_type: "Bearer" + }); + + transport = new SSEClientTransport(baseUrl, { + authProvider: mockAuthProvider, + }); + + await transport.start(); + + expect(lastServerRequest.headers.authorization).toBe("Bearer test-token"); + expect(mockAuthProvider.tokens).toHaveBeenCalled(); + }); + + it("attaches auth header from provider on POST requests", async () => { + mockAuthProvider.tokens.mockResolvedValue({ + access_token: "test-token", + token_type: "Bearer" + }); + + transport = new SSEClientTransport(baseUrl, { + authProvider: mockAuthProvider, + }); + + await transport.start(); + + const message: JSONRPCMessage = { + jsonrpc: "2.0", + id: "1", + method: "test", + params: {}, + }; + + await transport.send(message); + + expect(lastServerRequest.headers.authorization).toBe("Bearer test-token"); + expect(mockAuthProvider.tokens).toHaveBeenCalled(); + }); + + it("attempts auth flow on 401 during SSE connection", async () => { + // Create server that returns 401s + server.close(); + await new Promise(resolve => server.on("close", resolve)); + + server = createServer((req, res) => { + lastServerRequest = req; + if (req.url !== "/") { + res.writeHead(404).end(); + } else { + res.writeHead(401).end(); + } + }); + + await new Promise(resolve => { + server.listen(0, "127.0.0.1", () => { + const addr = server.address() as AddressInfo; + baseUrl = new URL(`http://127.0.0.1:${addr.port}`); + resolve(); + }); + }); + + transport = new SSEClientTransport(baseUrl, { + authProvider: mockAuthProvider, + }); + + await expect(() => transport.start()).rejects.toThrow("Unauthorized"); + expect(mockAuthProvider.redirectToAuthorization.mock.calls).toHaveLength(1); + }); + + it("attempts auth flow on 401 during POST request", async () => { + // Create server that accepts SSE but returns 401 on POST + server.close(); + await new Promise(resolve => server.on("close", resolve)); + + server = createServer((req, res) => { + lastServerRequest = req; + + switch (req.method) { + case "GET": + if (req.url !== "/") { + res.writeHead(404).end(); + return; + } + + res.writeHead(200, { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + Connection: "keep-alive", + }); + res.write("event: endpoint\n"); + res.write(`data: ${baseUrl.href}\n\n`); + break; + + case "POST": + res.writeHead(401); + res.end(); + break; + } + }); + + await new Promise(resolve => { + server.listen(0, "127.0.0.1", () => { + const addr = server.address() as AddressInfo; + baseUrl = new URL(`http://127.0.0.1:${addr.port}`); + resolve(); + }); + }); + + transport = new SSEClientTransport(baseUrl, { + authProvider: mockAuthProvider, + }); + + await transport.start(); + + const message: JSONRPCMessage = { + jsonrpc: "2.0", + id: "1", + method: "test", + params: {}, + }; + + await expect(() => transport.send(message)).rejects.toThrow("Unauthorized"); + expect(mockAuthProvider.redirectToAuthorization.mock.calls).toHaveLength(1); + }); + + it("respects custom headers when using auth provider", async () => { + mockAuthProvider.tokens.mockResolvedValue({ + access_token: "test-token", + token_type: "Bearer" + }); + + const customHeaders = { + "X-Custom-Header": "custom-value", + }; + + transport = new SSEClientTransport(baseUrl, { + authProvider: mockAuthProvider, + requestInit: { + headers: customHeaders, + }, + }); + + await transport.start(); + + const message: JSONRPCMessage = { + jsonrpc: "2.0", + id: "1", + method: "test", + params: {}, + }; + + await transport.send(message); + + expect(lastServerRequest.headers.authorization).toBe("Bearer test-token"); + expect(lastServerRequest.headers["x-custom-header"]).toBe("custom-value"); + }); + }); }); From 36999dcef7587161b68d5829988eb88481aa5bc4 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Wed, 12 Feb 2025 14:00:21 +0000 Subject: [PATCH 21/25] Fix `fetch` test polluting auth tests --- src/client/sse.test.ts | 77 +++++++++++++++++++++++------------------- 1 file changed, 42 insertions(+), 35 deletions(-) diff --git a/src/client/sse.test.ts b/src/client/sse.test.ts index 04193fff..8bae8528 100644 --- a/src/client/sse.test.ts +++ b/src/client/sse.test.ts @@ -60,6 +60,8 @@ describe("SSEClientTransport", () => { afterEach(async () => { await transport.close(); await server.close(); + + jest.clearAllMocks(); }); describe("connection handling", () => { @@ -73,8 +75,7 @@ describe("SSEClientTransport", () => { it("rejects if server returns non-200 status", async () => { // Create a server that returns 403 - server.close(); - await new Promise((resolve) => server.on("close", resolve)); + await server.close(); server = createServer((req, res) => { res.writeHead(403); @@ -252,37 +253,45 @@ describe("SSEClientTransport", () => { await transport.start(); - // Mock fetch for the message sending test - global.fetch = jest.fn().mockResolvedValue({ - ok: true, - }); - - const message: JSONRPCMessage = { - jsonrpc: "2.0", - id: "1", - method: "test", - params: {}, - }; - - await transport.send(message); + // Store original fetch + const originalFetch = global.fetch; - // Verify fetch was called with correct headers - expect(global.fetch).toHaveBeenCalledWith( - expect.any(URL), - expect.objectContaining({ - headers: expect.any(Headers), - }), - ); + try { + // Mock fetch for the message sending test + global.fetch = jest.fn().mockResolvedValue({ + ok: true, + }); - const calledHeaders = (global.fetch as jest.Mock).mock.calls[0][1] - .headers; - expect(calledHeaders.get("Authorization")).toBe( - customHeaders.Authorization, - ); - expect(calledHeaders.get("X-Custom-Header")).toBe( - customHeaders["X-Custom-Header"], - ); - expect(calledHeaders.get("content-type")).toBe("application/json"); + const message: JSONRPCMessage = { + jsonrpc: "2.0", + id: "1", + method: "test", + params: {}, + }; + + await transport.send(message); + + // Verify fetch was called with correct headers + expect(global.fetch).toHaveBeenCalledWith( + expect.any(URL), + expect.objectContaining({ + headers: expect.any(Headers), + }), + ); + + const calledHeaders = (global.fetch as jest.Mock).mock.calls[0][1] + .headers; + expect(calledHeaders.get("Authorization")).toBe( + customHeaders.Authorization, + ); + expect(calledHeaders.get("X-Custom-Header")).toBe( + customHeaders["X-Custom-Header"], + ); + expect(calledHeaders.get("content-type")).toBe("application/json"); + } finally { + // Restore original fetch + global.fetch = originalFetch; + } }); }); @@ -345,8 +354,7 @@ describe("SSEClientTransport", () => { it("attempts auth flow on 401 during SSE connection", async () => { // Create server that returns 401s - server.close(); - await new Promise(resolve => server.on("close", resolve)); + await server.close(); server = createServer((req, res) => { lastServerRequest = req; @@ -375,8 +383,7 @@ describe("SSEClientTransport", () => { it("attempts auth flow on 401 during POST request", async () => { // Create server that accepts SSE but returns 401 on POST - server.close(); - await new Promise(resolve => server.on("close", resolve)); + await server.close(); server = createServer((req, res) => { lastServerRequest = req; From 79d2db3e71af1795dbccfa81180d3dff2937330e Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Wed, 12 Feb 2025 14:01:22 +0000 Subject: [PATCH 22/25] Missed one `await server.close()` --- src/client/sse.test.ts | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/client/sse.test.ts b/src/client/sse.test.ts index 8bae8528..539d742a 100644 --- a/src/client/sse.test.ts +++ b/src/client/sse.test.ts @@ -176,8 +176,7 @@ describe("SSEClientTransport", () => { it("handles POST request failures", async () => { // Create a server that returns 500 for POST - server.close(); - await new Promise((resolve) => server.on("close", resolve)); + await server.close(); server = createServer((req, res) => { if (req.method === "GET") { From c0ebe9015488685aee1cf6972ef96795ce429bb6 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Wed, 12 Feb 2025 14:17:03 +0000 Subject: [PATCH 23/25] Test transparent token refresh in SSEClientTransport --- src/client/sse.test.ts | 256 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 254 insertions(+), 2 deletions(-) diff --git a/src/client/sse.test.ts b/src/client/sse.test.ts index 539d742a..b5659d51 100644 --- a/src/client/sse.test.ts +++ b/src/client/sse.test.ts @@ -2,7 +2,7 @@ import { createServer, type IncomingMessage, type Server } from "http"; import { AddressInfo } from "net"; import { JSONRPCMessage } from "../types.js"; import { SSEClientTransport } from "./sse.js"; -import { auth, OAuthClientProvider } from "./auth.js"; +import { OAuthClientProvider, OAuthTokens } from "./auth.js"; describe("SSEClientTransport", () => { let server: Server; @@ -301,7 +301,7 @@ describe("SSEClientTransport", () => { mockAuthProvider = { get redirectUrl() { return "http://localhost/callback"; }, get clientMetadata() { return { redirect_uris: ["http://localhost/callback"] }; }, - clientInformation: jest.fn(() => ({ client_id: "test-client-id" })), + clientInformation: jest.fn(() => ({ client_id: "test-client-id", client_secret: "test-client-secret" })), tokens: jest.fn(), saveTokens: jest.fn(), redirectToAuthorization: jest.fn(), @@ -466,5 +466,257 @@ describe("SSEClientTransport", () => { expect(lastServerRequest.headers.authorization).toBe("Bearer test-token"); expect(lastServerRequest.headers["x-custom-header"]).toBe("custom-value"); }); + + it("refreshes expired token during SSE connection", async () => { + // Mock tokens() to return expired token until saveTokens is called + let currentTokens: OAuthTokens = { + access_token: "expired-token", + token_type: "Bearer", + refresh_token: "refresh-token" + }; + mockAuthProvider.tokens.mockImplementation(() => currentTokens); + mockAuthProvider.saveTokens.mockImplementation((tokens) => { + currentTokens = tokens; + }); + + // Create server that returns 401 for expired token, then accepts new token + await server.close(); + + let connectionAttempts = 0; + server = createServer((req, res) => { + lastServerRequest = req; + + if (req.url === "/token" && req.method === "POST") { + // Handle token refresh request + let body = ""; + req.on("data", chunk => { body += chunk; }); + req.on("end", () => { + const params = new URLSearchParams(body); + if (params.get("grant_type") === "refresh_token" && + params.get("refresh_token") === "refresh-token" && + params.get("client_id") === "test-client-id" && + params.get("client_secret") === "test-client-secret") { + res.writeHead(200, { "Content-Type": "application/json" }); + res.end(JSON.stringify({ + access_token: "new-token", + token_type: "Bearer", + refresh_token: "new-refresh-token" + })); + } else { + res.writeHead(400).end(); + } + }); + return; + } + + if (req.url !== "/") { + res.writeHead(404).end(); + return; + } + + const auth = req.headers.authorization; + if (auth === "Bearer expired-token") { + res.writeHead(401).end(); + return; + } + + if (auth === "Bearer new-token") { + res.writeHead(200, { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + Connection: "keep-alive", + }); + res.write("event: endpoint\n"); + res.write(`data: ${baseUrl.href}\n\n`); + connectionAttempts++; + return; + } + + res.writeHead(401).end(); + }); + + await new Promise(resolve => { + server.listen(0, "127.0.0.1", () => { + const addr = server.address() as AddressInfo; + baseUrl = new URL(`http://127.0.0.1:${addr.port}`); + resolve(); + }); + }); + + transport = new SSEClientTransport(baseUrl, { + authProvider: mockAuthProvider, + }); + + await transport.start(); + + expect(mockAuthProvider.saveTokens).toHaveBeenCalledWith({ + access_token: "new-token", + token_type: "Bearer", + refresh_token: "new-refresh-token" + }); + expect(connectionAttempts).toBe(1); + expect(lastServerRequest.headers.authorization).toBe("Bearer new-token"); + }); + + it("refreshes expired token during POST request", async () => { + // Mock tokens() to return expired token until saveTokens is called + let currentTokens: OAuthTokens = { + access_token: "expired-token", + token_type: "Bearer", + refresh_token: "refresh-token" + }; + mockAuthProvider.tokens.mockImplementation(() => currentTokens); + mockAuthProvider.saveTokens.mockImplementation((tokens) => { + currentTokens = tokens; + }); + + // Create server that accepts SSE but returns 401 on POST with expired token + await server.close(); + + let postAttempts = 0; + server = createServer((req, res) => { + lastServerRequest = req; + + if (req.url === "/token" && req.method === "POST") { + // Handle token refresh request + let body = ""; + req.on("data", chunk => { body += chunk; }); + req.on("end", () => { + const params = new URLSearchParams(body); + if (params.get("grant_type") === "refresh_token" && + params.get("refresh_token") === "refresh-token" && + params.get("client_id") === "test-client-id" && + params.get("client_secret") === "test-client-secret") { + res.writeHead(200, { "Content-Type": "application/json" }); + res.end(JSON.stringify({ + access_token: "new-token", + token_type: "Bearer", + refresh_token: "new-refresh-token" + })); + } else { + res.writeHead(400).end(); + } + }); + return; + } + + switch (req.method) { + case "GET": + if (req.url !== "/") { + res.writeHead(404).end(); + return; + } + + res.writeHead(200, { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + Connection: "keep-alive", + }); + res.write("event: endpoint\n"); + res.write(`data: ${baseUrl.href}\n\n`); + break; + + case "POST": { + if (req.url !== "/") { + res.writeHead(404).end(); + return; + } + + const auth = req.headers.authorization; + if (auth === "Bearer expired-token") { + res.writeHead(401).end(); + return; + } + + if (auth === "Bearer new-token") { + res.writeHead(200).end(); + postAttempts++; + return; + } + + res.writeHead(401).end(); + break; + } + } + }); + + await new Promise(resolve => { + server.listen(0, "127.0.0.1", () => { + const addr = server.address() as AddressInfo; + baseUrl = new URL(`http://127.0.0.1:${addr.port}`); + resolve(); + }); + }); + + transport = new SSEClientTransport(baseUrl, { + authProvider: mockAuthProvider, + }); + + await transport.start(); + + const message: JSONRPCMessage = { + jsonrpc: "2.0", + id: "1", + method: "test", + params: {}, + }; + + await transport.send(message); + + expect(mockAuthProvider.saveTokens).toHaveBeenCalledWith({ + access_token: "new-token", + token_type: "Bearer", + refresh_token: "new-refresh-token" + }); + expect(postAttempts).toBe(1); + expect(lastServerRequest.headers.authorization).toBe("Bearer new-token"); + }); + + it("redirects to authorization if refresh token flow fails", async () => { + // Mock tokens() to return expired token until saveTokens is called + let currentTokens: OAuthTokens = { + access_token: "expired-token", + token_type: "Bearer", + refresh_token: "refresh-token" + }; + mockAuthProvider.tokens.mockImplementation(() => currentTokens); + mockAuthProvider.saveTokens.mockImplementation((tokens) => { + currentTokens = tokens; + }); + + // Create server that returns 401 for all tokens + await server.close(); + + server = createServer((req, res) => { + lastServerRequest = req; + + if (req.url === "/token" && req.method === "POST") { + // Handle token refresh request - always fail + res.writeHead(400).end(); + return; + } + + if (req.url !== "/") { + res.writeHead(404).end(); + return; + } + res.writeHead(401).end(); + }); + + await new Promise(resolve => { + server.listen(0, "127.0.0.1", () => { + const addr = server.address() as AddressInfo; + baseUrl = new URL(`http://127.0.0.1:${addr.port}`); + resolve(); + }); + }); + + transport = new SSEClientTransport(baseUrl, { + authProvider: mockAuthProvider, + }); + + await expect(transport.start()).rejects.toThrow("Unauthorized"); + expect(mockAuthProvider.redirectToAuthorization).toHaveBeenCalled(); + }); }); }); From 44a440854f5135ea684110b0e63f208c90095da8 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Mon, 17 Feb 2025 10:34:57 +0000 Subject: [PATCH 24/25] Code review comments --- src/client/auth.test.ts | 6 +++--- src/client/auth.ts | 7 ------- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/src/client/auth.test.ts b/src/client/auth.test.ts index 6d60806f..c65a5f3a 100644 --- a/src/client/auth.test.ts +++ b/src/client/auth.test.ts @@ -84,8 +84,8 @@ describe("OAuth Authorization", () => { describe("startAuthorization", () => { const validMetadata = { issuer: "https://auth.example.com", - authorization_endpoint: "https://auth.example.com/authorize", - token_endpoint: "https://auth.example.com/token", + authorization_endpoint: "https://auth.example.com/auth", + token_endpoint: "https://auth.example.com/tkn", response_types_supported: ["code"], code_challenge_methods_supported: ["S256"], }; @@ -131,7 +131,7 @@ describe("OAuth Authorization", () => { ); expect(authorizationUrl.toString()).toMatch( - /^https:\/\/auth\.example\.com\/authorize\?/ + /^https:\/\/auth\.example\.com\/auth\?/ ); }); diff --git a/src/client/auth.ts b/src/client/auth.ts index 2c11ea8a..e38b9bfc 100644 --- a/src/client/auth.ts +++ b/src/client/auth.ts @@ -432,13 +432,6 @@ export async function refreshAuthorization( /** * Performs OAuth 2.0 Dynamic Client Registration according to RFC 7591. - * - * @param serverUrl - The base URL of the authorization server - * @param options - Registration options - * @param options.metadata - OAuth server metadata containing the registration endpoint - * @param options.clientMetadata - Client metadata for registration - * @returns The registered client information - * @throws Error if the server doesn't support dynamic registration or if registration fails */ export async function registerClient( serverUrl: string | URL, From b1f4b65709215c05a92cdb5815ceb9b9574a13e2 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Mon, 17 Feb 2025 10:45:34 +0000 Subject: [PATCH 25/25] Improve auth docs and add `finishAuth` convenience method --- src/client/auth.ts | 6 ++++++ src/client/sse.test.ts | 26 +++++++++++++------------- src/client/sse.ts | 34 ++++++++++++++++++++++++++++------ 3 files changed, 47 insertions(+), 19 deletions(-) diff --git a/src/client/auth.ts b/src/client/auth.ts index e38b9bfc..b30134b8 100644 --- a/src/client/auth.ts +++ b/src/client/auth.ts @@ -147,6 +147,12 @@ export interface OAuthClientProvider { export type AuthResult = "AUTHORIZED" | "REDIRECT"; +export class UnauthorizedError extends Error { + constructor(message?: string) { + super(message ?? "Unauthorized"); + } +} + /** * Orchestrates the full auth flow with a server. * diff --git a/src/client/sse.test.ts b/src/client/sse.test.ts index b5659d51..57497013 100644 --- a/src/client/sse.test.ts +++ b/src/client/sse.test.ts @@ -2,7 +2,7 @@ import { createServer, type IncomingMessage, type Server } from "http"; import { AddressInfo } from "net"; import { JSONRPCMessage } from "../types.js"; import { SSEClientTransport } from "./sse.js"; -import { OAuthClientProvider, OAuthTokens } from "./auth.js"; +import { OAuthClientProvider, OAuthTokens, UnauthorizedError } from "./auth.js"; describe("SSEClientTransport", () => { let server: Server; @@ -376,7 +376,7 @@ describe("SSEClientTransport", () => { authProvider: mockAuthProvider, }); - await expect(() => transport.start()).rejects.toThrow("Unauthorized"); + await expect(() => transport.start()).rejects.toThrow(UnauthorizedError); expect(mockAuthProvider.redirectToAuthorization.mock.calls).toHaveLength(1); }); @@ -431,7 +431,7 @@ describe("SSEClientTransport", () => { params: {}, }; - await expect(() => transport.send(message)).rejects.toThrow("Unauthorized"); + await expect(() => transport.send(message)).rejects.toThrow(UnauthorizedError); expect(mockAuthProvider.redirectToAuthorization.mock.calls).toHaveLength(1); }); @@ -485,17 +485,17 @@ describe("SSEClientTransport", () => { let connectionAttempts = 0; server = createServer((req, res) => { lastServerRequest = req; - + if (req.url === "/token" && req.method === "POST") { // Handle token refresh request let body = ""; req.on("data", chunk => { body += chunk; }); req.on("end", () => { const params = new URLSearchParams(body); - if (params.get("grant_type") === "refresh_token" && - params.get("refresh_token") === "refresh-token" && - params.get("client_id") === "test-client-id" && - params.get("client_secret") === "test-client-secret") { + if (params.get("grant_type") === "refresh_token" && + params.get("refresh_token") === "refresh-token" && + params.get("client_id") === "test-client-id" && + params.get("client_secret") === "test-client-secret") { res.writeHead(200, { "Content-Type": "application/json" }); res.end(JSON.stringify({ access_token: "new-token", @@ -583,10 +583,10 @@ describe("SSEClientTransport", () => { req.on("data", chunk => { body += chunk; }); req.on("end", () => { const params = new URLSearchParams(body); - if (params.get("grant_type") === "refresh_token" && - params.get("refresh_token") === "refresh-token" && - params.get("client_id") === "test-client-id" && - params.get("client_secret") === "test-client-secret") { + if (params.get("grant_type") === "refresh_token" && + params.get("refresh_token") === "refresh-token" && + params.get("client_id") === "test-client-id" && + params.get("client_secret") === "test-client-secret") { res.writeHead(200, { "Content-Type": "application/json" }); res.end(JSON.stringify({ access_token: "new-token", @@ -715,7 +715,7 @@ describe("SSEClientTransport", () => { authProvider: mockAuthProvider, }); - await expect(transport.start()).rejects.toThrow("Unauthorized"); + await expect(() => transport.start()).rejects.toThrow(UnauthorizedError); expect(mockAuthProvider.redirectToAuthorization).toHaveBeenCalled(); }); }); diff --git a/src/client/sse.ts b/src/client/sse.ts index 59eef4d3..5e9f0cf0 100644 --- a/src/client/sse.ts +++ b/src/client/sse.ts @@ -1,7 +1,7 @@ import { EventSource, type ErrorEvent, type EventSourceInit } from "eventsource"; import { Transport } from "../shared/transport.js"; import { JSONRPCMessage, JSONRPCMessageSchema } from "../types.js"; -import { auth, AuthResult, OAuthClientProvider } from "./auth.js"; +import { auth, AuthResult, OAuthClientProvider, UnauthorizedError } from "./auth.js"; export class SseError extends Error { constructor( @@ -20,8 +20,16 @@ export type SSEClientTransportOptions = { /** * An OAuth client provider to use for authentication. * - * If given, the transport will automatically attach an `Authorization` header - * if an access token is available, or begin the authorization flow if not. + * When an `authProvider` is specified and the SSE connection is started: + * 1. The connection is attempted with any existing access token from the `authProvider`. + * 2. If the access token has expired, the `authProvider` is used to refresh the token. + * 3. If token refresh fails or no access token exists, and auth is required, `OAuthClientProvider.redirectToAuthorization` is called, and an `UnauthorizedError` will be thrown from `connect`/`start`. + * + * After the user has finished authorizing via their user agent, and is redirected back to the MCP client application, call `SSEClientTransport.finishAuth` with the authorization code before retrying the connection. + * + * If an `authProvider` is not provided, and auth is required, an `UnauthorizedError` will be thrown. + * + * `UnauthorizedError` might also be thrown when sending any message over the SSE transport, indicating that the session has expired, and needs to be re-authed and reconnected. */ authProvider?: OAuthClientProvider; @@ -70,7 +78,7 @@ export class SSEClientTransport implements Transport { private async _authThenStart(): Promise { if (!this._authProvider) { - throw new Error("No auth provider"); + throw new UnauthorizedError("No auth provider"); } let result: AuthResult; @@ -82,7 +90,7 @@ export class SSEClientTransport implements Transport { } if (result !== "AUTHORIZED") { - throw new Error("Unauthorized"); + throw new UnauthorizedError(); } return await this._startOrAuth(); @@ -177,6 +185,20 @@ export class SSEClientTransport implements Transport { return await this._startOrAuth(); } + /** + * Call this method after the user has finished authorizing via their user agent and is redirected back to the MCP client application. This will exchange the authorization code for an access token, enabling the next connection attempt to successfully auth. + */ + async finishAuth(authorizationCode: string): Promise { + if (!this._authProvider) { + throw new UnauthorizedError("No auth provider"); + } + + const result = await auth(this._authProvider, { serverUrl: this._url, authorizationCode }); + if (result !== "AUTHORIZED") { + throw new UnauthorizedError("Failed to authorize"); + } + } + async close(): Promise { this._abortController?.abort(); this._eventSource?.close(); @@ -205,7 +227,7 @@ export class SSEClientTransport implements Transport { if (response.status === 401 && this._authProvider) { const result = await auth(this._authProvider, { serverUrl: this._url }); if (result !== "AUTHORIZED") { - throw new Error("Unauthorized"); + throw new UnauthorizedError(); } // Purposely _not_ awaited, so we don't call onerror twice