diff --git a/jest.config.js b/jest.config.js index 17ab1fbe..f8f621c8 100644 --- a/jest.config.js +++ b/jest.config.js @@ -7,6 +7,10 @@ export default { ...defaultEsmPreset, moduleNameMapper: { "^(\\.{1,2}/.*)\\.js$": "$1", + "^pkce-challenge$": "/src/__mocks__/pkce-challenge.ts" }, + transformIgnorePatterns: [ + "/node_modules/(?!eventsource)/" + ], testPathIgnorePatterns: ["/node_modules/", "/dist/"], }; diff --git a/package-lock.json b/package-lock.json index 687a7c0c..e4bbd079 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,16 +1,17 @@ { "name": "@modelcontextprotocol/sdk", - "version": "1.3.2", + "version": "1.5.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@modelcontextprotocol/sdk", - "version": "1.3.2", + "version": "1.5.0", "license": "MIT", "dependencies": { "content-type": "^1.0.5", "eventsource": "^3.0.2", + "pkce-challenge": "^4.1.0", "raw-body": "^3.0.0", "zod": "^3.23.8", "zod-to-json-schema": "^3.24.1" @@ -5058,6 +5059,15 @@ "node": ">= 6" } }, + "node_modules/pkce-challenge": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/pkce-challenge/-/pkce-challenge-4.1.0.tgz", + "integrity": "sha512-ZBmhE1C9LcPoH9XZSdwiPtbPHZROwAnMy+kIFQVrnMCxY4Cudlz3gBOpzilgc0jOgRaiT3sIWfpMomW2ar2orQ==", + "license": "MIT", + "engines": { + "node": ">=16.20.0" + } + }, "node_modules/pkg-dir": { "version": "4.2.0", "resolved": "https://registry.npmjs.org/pkg-dir/-/pkg-dir-4.2.0.tgz", diff --git a/package.json b/package.json index 77c268ed..9558905d 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@modelcontextprotocol/sdk", - "version": "1.4.1", + "version": "1.5.0", "description": "Model Context Protocol implementation for TypeScript", "license": "MIT", "author": "Anthropic, PBC (https://anthropic.com)", @@ -48,6 +48,7 @@ "dependencies": { "content-type": "^1.0.5", "eventsource": "^3.0.2", + "pkce-challenge": "^4.1.0", "raw-body": "^3.0.0", "zod": "^3.23.8", "zod-to-json-schema": "^3.24.1" @@ -73,4 +74,4 @@ "resolutions": { "strip-ansi": "6.0.1" } -} +} \ No newline at end of file diff --git a/src/__mocks__/pkce-challenge.ts b/src/__mocks__/pkce-challenge.ts new file mode 100644 index 00000000..10e13054 --- /dev/null +++ b/src/__mocks__/pkce-challenge.ts @@ -0,0 +1,6 @@ +export default function pkceChallenge() { + return { + code_verifier: "test_verifier", + code_challenge: "test_challenge", + }; +} \ No newline at end of file diff --git a/src/client/auth.test.ts b/src/client/auth.test.ts new file mode 100644 index 00000000..c65a5f3a --- /dev/null +++ b/src/client/auth.test.ts @@ -0,0 +1,420 @@ +import { + discoverOAuthMetadata, + startAuthorization, + exchangeAuthorization, + refreshAuthorization, + registerClient, +} from "./auth.js"; + + +// Mock fetch globally +const mockFetch = jest.fn(); +global.fetch = mockFetch; + +describe("OAuth Authorization", () => { + beforeEach(() => { + mockFetch.mockReset(); + }); + + describe("discoverOAuthMetadata", () => { + const validMetadata = { + issuer: "https://auth.example.com", + authorization_endpoint: "https://auth.example.com/authorize", + token_endpoint: "https://auth.example.com/token", + registration_endpoint: "https://auth.example.com/register", + response_types_supported: ["code"], + code_challenge_methods_supported: ["S256"], + }; + + it("returns metadata when discovery succeeds", async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validMetadata, + }); + + const metadata = await discoverOAuthMetadata("https://auth.example.com"); + expect(metadata).toEqual(validMetadata); + const calls = mockFetch.mock.calls; + expect(calls.length).toBe(1); + const [url, options] = calls[0]; + expect(url.toString()).toBe("https://auth.example.com/.well-known/oauth-authorization-server"); + expect(options.headers).toEqual({ + "MCP-Protocol-Version": "2024-11-05" + }); + }); + + it("returns undefined when discovery endpoint returns 404", async () => { + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 404, + }); + + const metadata = await discoverOAuthMetadata("https://auth.example.com"); + expect(metadata).toBeUndefined(); + }); + + it("throws on non-404 errors", async () => { + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 500, + }); + + await expect( + discoverOAuthMetadata("https://auth.example.com") + ).rejects.toThrow("HTTP 500"); + }); + + it("validates metadata schema", async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => ({ + // Missing required fields + issuer: "https://auth.example.com", + }), + }); + + await expect( + discoverOAuthMetadata("https://auth.example.com") + ).rejects.toThrow(); + }); + }); + + describe("startAuthorization", () => { + const validMetadata = { + issuer: "https://auth.example.com", + authorization_endpoint: "https://auth.example.com/auth", + token_endpoint: "https://auth.example.com/tkn", + response_types_supported: ["code"], + code_challenge_methods_supported: ["S256"], + }; + + const validClientInfo = { + client_id: "client123", + client_secret: "secret123", + redirect_uris: ["http://localhost:3000/callback"], + client_name: "Test Client", + }; + + it("generates authorization URL with PKCE challenge", async () => { + const { authorizationUrl, codeVerifier } = await startAuthorization( + "https://auth.example.com", + { + clientInformation: validClientInfo, + redirectUrl: "http://localhost:3000/callback", + } + ); + + expect(authorizationUrl.toString()).toMatch( + /^https:\/\/auth\.example\.com\/authorize\?/ + ); + expect(authorizationUrl.searchParams.get("response_type")).toBe("code"); + expect(authorizationUrl.searchParams.get("code_challenge")).toBe("test_challenge"); + expect(authorizationUrl.searchParams.get("code_challenge_method")).toBe( + "S256" + ); + expect(authorizationUrl.searchParams.get("redirect_uri")).toBe( + "http://localhost:3000/callback" + ); + expect(codeVerifier).toBe("test_verifier"); + }); + + it("uses metadata authorization_endpoint when provided", async () => { + const { authorizationUrl } = await startAuthorization( + "https://auth.example.com", + { + metadata: validMetadata, + clientInformation: validClientInfo, + redirectUrl: "http://localhost:3000/callback", + } + ); + + expect(authorizationUrl.toString()).toMatch( + /^https:\/\/auth\.example\.com\/auth\?/ + ); + }); + + it("validates response type support", async () => { + const metadata = { + ...validMetadata, + response_types_supported: ["token"], // Does not support 'code' + }; + + await expect( + startAuthorization("https://auth.example.com", { + metadata, + clientInformation: validClientInfo, + redirectUrl: "http://localhost:3000/callback", + }) + ).rejects.toThrow(/does not support response type/); + }); + + it("validates PKCE support", async () => { + const metadata = { + ...validMetadata, + response_types_supported: ["code"], + code_challenge_methods_supported: ["plain"], // Does not support 'S256' + }; + + await expect( + startAuthorization("https://auth.example.com", { + metadata, + clientInformation: validClientInfo, + redirectUrl: "http://localhost:3000/callback", + }) + ).rejects.toThrow(/does not support code challenge method/); + }); + }); + + describe("exchangeAuthorization", () => { + const validTokens = { + access_token: "access123", + token_type: "Bearer", + expires_in: 3600, + refresh_token: "refresh123", + }; + + const validClientInfo = { + client_id: "client123", + client_secret: "secret123", + redirect_uris: ["http://localhost:3000/callback"], + client_name: "Test Client", + }; + + it("exchanges code for tokens", async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validTokens, + }); + + const tokens = await exchangeAuthorization("https://auth.example.com", { + clientInformation: validClientInfo, + authorizationCode: "code123", + codeVerifier: "verifier123", + }); + + expect(tokens).toEqual(validTokens); + expect(mockFetch).toHaveBeenCalledWith( + expect.objectContaining({ + href: "https://auth.example.com/token", + }), + expect.objectContaining({ + method: "POST", + headers: { + "Content-Type": "application/x-www-form-urlencoded", + }, + }) + ); + + const body = mockFetch.mock.calls[0][1].body as URLSearchParams; + expect(body.get("grant_type")).toBe("authorization_code"); + expect(body.get("code")).toBe("code123"); + expect(body.get("code_verifier")).toBe("verifier123"); + expect(body.get("client_id")).toBe("client123"); + expect(body.get("client_secret")).toBe("secret123"); + }); + + it("validates token response schema", async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => ({ + // Missing required fields + access_token: "access123", + }), + }); + + await expect( + exchangeAuthorization("https://auth.example.com", { + clientInformation: validClientInfo, + authorizationCode: "code123", + codeVerifier: "verifier123", + }) + ).rejects.toThrow(); + }); + + it("throws on error response", async () => { + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 400, + }); + + await expect( + exchangeAuthorization("https://auth.example.com", { + clientInformation: validClientInfo, + authorizationCode: "code123", + codeVerifier: "verifier123", + }) + ).rejects.toThrow("Token exchange failed"); + }); + }); + + describe("refreshAuthorization", () => { + const validTokens = { + access_token: "newaccess123", + token_type: "Bearer", + expires_in: 3600, + refresh_token: "newrefresh123", + }; + + const validClientInfo = { + client_id: "client123", + client_secret: "secret123", + redirect_uris: ["http://localhost:3000/callback"], + client_name: "Test Client", + }; + + it("exchanges refresh token for new tokens", async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validTokens, + }); + + const tokens = await refreshAuthorization("https://auth.example.com", { + clientInformation: validClientInfo, + refreshToken: "refresh123", + }); + + expect(tokens).toEqual(validTokens); + expect(mockFetch).toHaveBeenCalledWith( + expect.objectContaining({ + href: "https://auth.example.com/token", + }), + expect.objectContaining({ + method: "POST", + headers: { + "Content-Type": "application/x-www-form-urlencoded", + }, + }) + ); + + const body = mockFetch.mock.calls[0][1].body as URLSearchParams; + expect(body.get("grant_type")).toBe("refresh_token"); + expect(body.get("refresh_token")).toBe("refresh123"); + expect(body.get("client_id")).toBe("client123"); + expect(body.get("client_secret")).toBe("secret123"); + }); + + it("validates token response schema", async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => ({ + // Missing required fields + access_token: "newaccess123", + }), + }); + + await expect( + refreshAuthorization("https://auth.example.com", { + clientInformation: validClientInfo, + refreshToken: "refresh123", + }) + ).rejects.toThrow(); + }); + + it("throws on error response", async () => { + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 400, + }); + + await expect( + refreshAuthorization("https://auth.example.com", { + clientInformation: validClientInfo, + refreshToken: "refresh123", + }) + ).rejects.toThrow("Token refresh failed"); + }); + }); + + describe("registerClient", () => { + const validClientMetadata = { + redirect_uris: ["http://localhost:3000/callback"], + client_name: "Test Client", + }; + + const validClientInfo = { + client_id: "client123", + client_secret: "secret123", + client_id_issued_at: 1612137600, + client_secret_expires_at: 1612224000, + ...validClientMetadata, + }; + + it("registers client and returns client information", async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => validClientInfo, + }); + + const clientInfo = await registerClient("https://auth.example.com", { + clientMetadata: validClientMetadata, + }); + + expect(clientInfo).toEqual(validClientInfo); + expect(mockFetch).toHaveBeenCalledWith( + expect.objectContaining({ + href: "https://auth.example.com/register", + }), + expect.objectContaining({ + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify(validClientMetadata), + }) + ); + }); + + it("validates client information response schema", async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => ({ + // Missing required fields + client_secret: "secret123", + }), + }); + + await expect( + registerClient("https://auth.example.com", { + clientMetadata: validClientMetadata, + }) + ).rejects.toThrow(); + }); + + it("throws when registration endpoint not available in metadata", async () => { + const metadata = { + issuer: "https://auth.example.com", + authorization_endpoint: "https://auth.example.com/authorize", + token_endpoint: "https://auth.example.com/token", + response_types_supported: ["code"], + }; + + await expect( + registerClient("https://auth.example.com", { + metadata, + clientMetadata: validClientMetadata, + }) + ).rejects.toThrow(/does not support dynamic client registration/); + }); + + it("throws on error response", async () => { + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 400, + }); + + await expect( + registerClient("https://auth.example.com", { + clientMetadata: validClientMetadata, + }) + ).rejects.toThrow("Dynamic client registration failed"); + }); + }); +}); \ No newline at end of file diff --git a/src/client/auth.ts b/src/client/auth.ts new file mode 100644 index 00000000..b30134b8 --- /dev/null +++ b/src/client/auth.ts @@ -0,0 +1,477 @@ +import pkceChallenge from "pkce-challenge"; +import { z } from "zod"; +import { LATEST_PROTOCOL_VERSION } from "../types.js"; + +export const OAuthMetadataSchema = z + .object({ + issuer: z.string(), + authorization_endpoint: z.string(), + token_endpoint: z.string(), + registration_endpoint: z.string().optional(), + scopes_supported: z.array(z.string()).optional(), + response_types_supported: z.array(z.string()), + response_modes_supported: z.array(z.string()).optional(), + grant_types_supported: z.array(z.string()).optional(), + token_endpoint_auth_methods_supported: z.array(z.string()).optional(), + token_endpoint_auth_signing_alg_values_supported: z + .array(z.string()) + .optional(), + service_documentation: z.string().optional(), + revocation_endpoint: z.string().optional(), + revocation_endpoint_auth_methods_supported: z.array(z.string()).optional(), + revocation_endpoint_auth_signing_alg_values_supported: z + .array(z.string()) + .optional(), + introspection_endpoint: z.string().optional(), + introspection_endpoint_auth_methods_supported: z + .array(z.string()) + .optional(), + introspection_endpoint_auth_signing_alg_values_supported: z + .array(z.string()) + .optional(), + code_challenge_methods_supported: z.array(z.string()).optional(), + }) + .passthrough(); + +export const OAuthTokensSchema = z + .object({ + access_token: z.string(), + token_type: z.string(), + expires_in: z.number().optional(), + scope: z.string().optional(), + refresh_token: z.string().optional(), + }) + .strip(); + +/** + * Client metadata schema according to RFC 7591 OAuth 2.0 Dynamic Client Registration + */ +export const OAuthClientMetadataSchema = z.object({ + redirect_uris: z.array(z.string()), + token_endpoint_auth_method: z.string().optional(), + grant_types: z.array(z.string()).optional(), + response_types: z.array(z.string()).optional(), + client_name: z.string().optional(), + client_uri: z.string().optional(), + logo_uri: z.string().optional(), + scope: z.string().optional(), + contacts: z.array(z.string()).optional(), + tos_uri: z.string().optional(), + policy_uri: z.string().optional(), + jwks_uri: z.string().optional(), + jwks: z.any().optional(), + software_id: z.string().optional(), + software_version: z.string().optional(), +}).passthrough(); + +/** + * Client information response schema according to RFC 7591 + */ +export const OAuthClientInformationSchema = z.object({ + client_id: z.string(), + client_secret: z.string().optional(), + client_id_issued_at: z.number().optional(), + client_secret_expires_at: z.number().optional(), +}).passthrough(); + +export type OAuthMetadata = z.infer; +export type OAuthTokens = z.infer; + +export type OAuthClientMetadata = z.infer; +export type OAuthClientInformation = z.infer; + +/** + * Implements an end-to-end OAuth client to be used with one MCP server. + * + * This client relies upon a concept of an authorized "session," the exact + * meaning of which is application-defined. Tokens, authorization codes, and + * code verifiers should not cross different sessions. + */ +export interface OAuthClientProvider { + /** + * The URL to redirect the user agent to after authorization. + */ + get redirectUrl(): string | URL; + + /** + * Metadata about this OAuth client. + */ + get clientMetadata(): OAuthClientMetadata; + + /** + * Loads information about this OAuth client, as registered already with the + * server, or returns `undefined` if the client is not registered with the + * server. + */ + clientInformation(): OAuthClientInformation | undefined | Promise; + + /** + * If implemented, this permits the OAuth client to dynamically register with + * the server. Client information saved this way should later be read via + * `clientInformation()`. + * + * This method is not required to be implemented if client information is + * statically known (e.g., pre-registered). + */ + saveClientInformation?(clientInformation: OAuthClientInformation): void | Promise; + + /** + * Loads any existing OAuth tokens for the current session, or returns + * `undefined` if there are no saved tokens. + */ + tokens(): OAuthTokens | undefined | Promise; + + /** + * Stores new OAuth tokens for the current session, after a successful + * authorization. + */ + saveTokens(tokens: OAuthTokens): void | Promise; + + /** + * Invoked to redirect the user agent to the given URL to begin the authorization flow. + */ + redirectToAuthorization(authorizationUrl: URL): void | Promise; + + /** + * Saves a PKCE code verifier for the current session, before redirecting to + * the authorization flow. + */ + saveCodeVerifier(codeVerifier: string): void | Promise; + + /** + * Loads the PKCE code verifier for the current session, necessary to validate + * the authorization result. + */ + codeVerifier(): string | Promise; +} + +export type AuthResult = "AUTHORIZED" | "REDIRECT"; + +export class UnauthorizedError extends Error { + constructor(message?: string) { + super(message ?? "Unauthorized"); + } +} + +/** + * Orchestrates the full auth flow with a server. + * + * This can be used as a single entry point for all authorization functionality, + * instead of linking together the other lower-level functions in this module. + */ +export async function auth( + provider: OAuthClientProvider, + { serverUrl, authorizationCode }: { serverUrl: string | URL, authorizationCode?: string }): Promise { + const metadata = await discoverOAuthMetadata(serverUrl); + + // Handle client registration if needed + let clientInformation = await Promise.resolve(provider.clientInformation()); + if (!clientInformation) { + if (authorizationCode !== undefined) { + throw new Error("Existing OAuth client information is required when exchanging an authorization code"); + } + + if (!provider.saveClientInformation) { + throw new Error("OAuth client information must be saveable for dynamic registration"); + } + + clientInformation = await registerClient(serverUrl, { + metadata, + clientMetadata: provider.clientMetadata, + }); + + await provider.saveClientInformation(clientInformation); + } + + // Exchange authorization code for tokens + if (authorizationCode !== undefined) { + const codeVerifier = await provider.codeVerifier(); + const tokens = await exchangeAuthorization(serverUrl, { + metadata, + clientInformation, + authorizationCode, + codeVerifier, + }); + + await provider.saveTokens(tokens); + return "AUTHORIZED"; + } + + const tokens = await provider.tokens(); + + // Handle token refresh or new authorization + if (tokens?.refresh_token) { + try { + // Attempt to refresh the token + const newTokens = await refreshAuthorization(serverUrl, { + metadata, + clientInformation, + refreshToken: tokens.refresh_token, + }); + + await provider.saveTokens(newTokens); + return "AUTHORIZED"; + } catch (error) { + console.error("Could not refresh OAuth tokens:", error); + } + } + + // Start new authorization flow + const { authorizationUrl, codeVerifier } = await startAuthorization(serverUrl, { + metadata, + clientInformation, + redirectUrl: provider.redirectUrl + }); + + await provider.saveCodeVerifier(codeVerifier); + await provider.redirectToAuthorization(authorizationUrl); + return "REDIRECT"; +} + +/** + * Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata. + * + * If the server returns a 404 for the well-known endpoint, this function will + * return `undefined`. Any other errors will be thrown as exceptions. + */ +export async function discoverOAuthMetadata( + serverUrl: string | URL, + opts?: { protocolVersion?: string }, +): Promise { + const url = new URL("/.well-known/oauth-authorization-server", serverUrl); + const response = await fetch(url, { + headers: { + "MCP-Protocol-Version": opts?.protocolVersion ?? LATEST_PROTOCOL_VERSION + } + }); + + if (response.status === 404) { + return undefined; + } + + if (!response.ok) { + throw new Error( + `HTTP ${response.status} trying to load well-known OAuth metadata`, + ); + } + + return OAuthMetadataSchema.parse(await response.json()); +} + +/** + * Begins the authorization flow with the given server, by generating a PKCE challenge and constructing the authorization URL. + */ +export async function startAuthorization( + serverUrl: string | URL, + { + metadata, + clientInformation, + redirectUrl, + }: { + metadata?: OAuthMetadata; + clientInformation: OAuthClientInformation; + redirectUrl: string | URL; + }, +): Promise<{ authorizationUrl: URL; codeVerifier: string }> { + const responseType = "code"; + const codeChallengeMethod = "S256"; + + let authorizationUrl: URL; + if (metadata) { + authorizationUrl = new URL(metadata.authorization_endpoint); + + if (!metadata.response_types_supported.includes(responseType)) { + throw new Error( + `Incompatible auth server: does not support response type ${responseType}`, + ); + } + + if ( + !metadata.code_challenge_methods_supported || + !metadata.code_challenge_methods_supported.includes(codeChallengeMethod) + ) { + throw new Error( + `Incompatible auth server: does not support code challenge method ${codeChallengeMethod}`, + ); + } + } else { + authorizationUrl = new URL("/authorize", serverUrl); + } + + // Generate PKCE challenge + const challenge = await pkceChallenge(); + const codeVerifier = challenge.code_verifier; + const codeChallenge = challenge.code_challenge; + + authorizationUrl.searchParams.set("response_type", responseType); + authorizationUrl.searchParams.set("client_id", clientInformation.client_id); + authorizationUrl.searchParams.set("code_challenge", codeChallenge); + authorizationUrl.searchParams.set( + "code_challenge_method", + codeChallengeMethod, + ); + authorizationUrl.searchParams.set("redirect_uri", String(redirectUrl)); + + return { authorizationUrl, codeVerifier }; +} + +/** + * Exchanges an authorization code for an access token with the given server. + */ +export async function exchangeAuthorization( + serverUrl: string | URL, + { + metadata, + clientInformation, + authorizationCode, + codeVerifier, + }: { + metadata?: OAuthMetadata; + clientInformation: OAuthClientInformation; + authorizationCode: string; + codeVerifier: string; + }, +): Promise { + const grantType = "authorization_code"; + + let tokenUrl: URL; + if (metadata) { + tokenUrl = new URL(metadata.token_endpoint); + + if ( + metadata.grant_types_supported && + !metadata.grant_types_supported.includes(grantType) + ) { + throw new Error( + `Incompatible auth server: does not support grant type ${grantType}`, + ); + } + } else { + tokenUrl = new URL("/token", serverUrl); + } + + // Exchange code for tokens + const params = new URLSearchParams({ + grant_type: grantType, + client_id: clientInformation.client_id, + code: authorizationCode, + code_verifier: codeVerifier, + }); + + if (clientInformation.client_secret) { + params.set("client_secret", clientInformation.client_secret); + } + + const response = await fetch(tokenUrl, { + method: "POST", + headers: { + "Content-Type": "application/x-www-form-urlencoded", + }, + body: params, + }); + + if (!response.ok) { + throw new Error(`Token exchange failed: HTTP ${response.status}`); + } + + return OAuthTokensSchema.parse(await response.json()); +} + +/** + * Exchange a refresh token for an updated access token. + */ +export async function refreshAuthorization( + serverUrl: string | URL, + { + metadata, + clientInformation, + refreshToken, + }: { + metadata?: OAuthMetadata; + clientInformation: OAuthClientInformation; + refreshToken: string; + }, +): Promise { + const grantType = "refresh_token"; + + let tokenUrl: URL; + if (metadata) { + tokenUrl = new URL(metadata.token_endpoint); + + if ( + metadata.grant_types_supported && + !metadata.grant_types_supported.includes(grantType) + ) { + throw new Error( + `Incompatible auth server: does not support grant type ${grantType}`, + ); + } + } else { + tokenUrl = new URL("/token", serverUrl); + } + + // Exchange refresh token + const params = new URLSearchParams({ + grant_type: grantType, + client_id: clientInformation.client_id, + refresh_token: refreshToken, + }); + + if (clientInformation.client_secret) { + params.set("client_secret", clientInformation.client_secret); + } + + const response = await fetch(tokenUrl, { + method: "POST", + headers: { + "Content-Type": "application/x-www-form-urlencoded", + }, + body: params, + }); + + if (!response.ok) { + throw new Error(`Token refresh failed: HTTP ${response.status}`); + } + + return OAuthTokensSchema.parse(await response.json()); +} + +/** + * Performs OAuth 2.0 Dynamic Client Registration according to RFC 7591. + */ +export async function registerClient( + serverUrl: string | URL, + { + metadata, + clientMetadata, + }: { + metadata?: OAuthMetadata; + clientMetadata: OAuthClientMetadata; + }, +): Promise { + let registrationUrl: URL; + + if (metadata) { + if (!metadata.registration_endpoint) { + throw new Error("Incompatible auth server: does not support dynamic client registration"); + } + + registrationUrl = new URL(metadata.registration_endpoint); + } else { + registrationUrl = new URL("/register", serverUrl); + } + + const response = await fetch(registrationUrl, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify(clientMetadata), + }); + + if (!response.ok) { + throw new Error(`Dynamic client registration failed: HTTP ${response.status}`); + } + + return OAuthClientInformationSchema.parse(await response.json()); +} \ No newline at end of file diff --git a/src/client/sse.test.ts b/src/client/sse.test.ts index f59c45fe..57497013 100644 --- a/src/client/sse.test.ts +++ b/src/client/sse.test.ts @@ -2,6 +2,7 @@ import { createServer, type IncomingMessage, type Server } from "http"; import { AddressInfo } from "net"; import { JSONRPCMessage } from "../types.js"; import { SSEClientTransport } from "./sse.js"; +import { OAuthClientProvider, OAuthTokens, UnauthorizedError } from "./auth.js"; describe("SSEClientTransport", () => { let server: Server; @@ -59,6 +60,8 @@ describe("SSEClientTransport", () => { afterEach(async () => { await transport.close(); await server.close(); + + jest.clearAllMocks(); }); describe("connection handling", () => { @@ -72,8 +75,7 @@ describe("SSEClientTransport", () => { it("rejects if server returns non-200 status", async () => { // Create a server that returns 403 - server.close(); - await new Promise((resolve) => server.on("close", resolve)); + await server.close(); server = createServer((req, res) => { res.writeHead(403); @@ -174,8 +176,7 @@ describe("SSEClientTransport", () => { it("handles POST request failures", async () => { // Create a server that returns 500 for POST - server.close(); - await new Promise((resolve) => server.on("close", resolve)); + await server.close(); server = createServer((req, res) => { if (req.method === "GET") { @@ -251,11 +252,92 @@ describe("SSEClientTransport", () => { await transport.start(); - // Mock fetch for the message sending test - global.fetch = jest.fn().mockResolvedValue({ - ok: true, + // Store original fetch + const originalFetch = global.fetch; + + try { + // Mock fetch for the message sending test + global.fetch = jest.fn().mockResolvedValue({ + ok: true, + }); + + const message: JSONRPCMessage = { + jsonrpc: "2.0", + id: "1", + method: "test", + params: {}, + }; + + await transport.send(message); + + // Verify fetch was called with correct headers + expect(global.fetch).toHaveBeenCalledWith( + expect.any(URL), + expect.objectContaining({ + headers: expect.any(Headers), + }), + ); + + const calledHeaders = (global.fetch as jest.Mock).mock.calls[0][1] + .headers; + expect(calledHeaders.get("Authorization")).toBe( + customHeaders.Authorization, + ); + expect(calledHeaders.get("X-Custom-Header")).toBe( + customHeaders["X-Custom-Header"], + ); + expect(calledHeaders.get("content-type")).toBe("application/json"); + } finally { + // Restore original fetch + global.fetch = originalFetch; + } + }); + }); + + describe("auth handling", () => { + let mockAuthProvider: jest.Mocked; + + beforeEach(() => { + mockAuthProvider = { + get redirectUrl() { return "http://localhost/callback"; }, + get clientMetadata() { return { redirect_uris: ["http://localhost/callback"] }; }, + clientInformation: jest.fn(() => ({ client_id: "test-client-id", client_secret: "test-client-secret" })), + tokens: jest.fn(), + saveTokens: jest.fn(), + redirectToAuthorization: jest.fn(), + saveCodeVerifier: jest.fn(), + codeVerifier: jest.fn(), + }; + }); + + it("attaches auth header from provider on SSE connection", async () => { + mockAuthProvider.tokens.mockResolvedValue({ + access_token: "test-token", + token_type: "Bearer" + }); + + transport = new SSEClientTransport(baseUrl, { + authProvider: mockAuthProvider, + }); + + await transport.start(); + + expect(lastServerRequest.headers.authorization).toBe("Bearer test-token"); + expect(mockAuthProvider.tokens).toHaveBeenCalled(); + }); + + it("attaches auth header from provider on POST requests", async () => { + mockAuthProvider.tokens.mockResolvedValue({ + access_token: "test-token", + token_type: "Bearer" + }); + + transport = new SSEClientTransport(baseUrl, { + authProvider: mockAuthProvider, }); + await transport.start(); + const message: JSONRPCMessage = { jsonrpc: "2.0", id: "1", @@ -265,23 +347,376 @@ describe("SSEClientTransport", () => { await transport.send(message); - // Verify fetch was called with correct headers - expect(global.fetch).toHaveBeenCalledWith( - expect.any(URL), - expect.objectContaining({ - headers: expect.any(Headers), - }), - ); + expect(lastServerRequest.headers.authorization).toBe("Bearer test-token"); + expect(mockAuthProvider.tokens).toHaveBeenCalled(); + }); - const calledHeaders = (global.fetch as jest.Mock).mock.calls[0][1] - .headers; - expect(calledHeaders.get("Authorization")).toBe( - customHeaders.Authorization, - ); - expect(calledHeaders.get("X-Custom-Header")).toBe( - customHeaders["X-Custom-Header"], - ); - expect(calledHeaders.get("content-type")).toBe("application/json"); + it("attempts auth flow on 401 during SSE connection", async () => { + // Create server that returns 401s + await server.close(); + + server = createServer((req, res) => { + lastServerRequest = req; + if (req.url !== "/") { + res.writeHead(404).end(); + } else { + res.writeHead(401).end(); + } + }); + + await new Promise(resolve => { + server.listen(0, "127.0.0.1", () => { + const addr = server.address() as AddressInfo; + baseUrl = new URL(`http://127.0.0.1:${addr.port}`); + resolve(); + }); + }); + + transport = new SSEClientTransport(baseUrl, { + authProvider: mockAuthProvider, + }); + + await expect(() => transport.start()).rejects.toThrow(UnauthorizedError); + expect(mockAuthProvider.redirectToAuthorization.mock.calls).toHaveLength(1); + }); + + it("attempts auth flow on 401 during POST request", async () => { + // Create server that accepts SSE but returns 401 on POST + await server.close(); + + server = createServer((req, res) => { + lastServerRequest = req; + + switch (req.method) { + case "GET": + if (req.url !== "/") { + res.writeHead(404).end(); + return; + } + + res.writeHead(200, { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + Connection: "keep-alive", + }); + res.write("event: endpoint\n"); + res.write(`data: ${baseUrl.href}\n\n`); + break; + + case "POST": + res.writeHead(401); + res.end(); + break; + } + }); + + await new Promise(resolve => { + server.listen(0, "127.0.0.1", () => { + const addr = server.address() as AddressInfo; + baseUrl = new URL(`http://127.0.0.1:${addr.port}`); + resolve(); + }); + }); + + transport = new SSEClientTransport(baseUrl, { + authProvider: mockAuthProvider, + }); + + await transport.start(); + + const message: JSONRPCMessage = { + jsonrpc: "2.0", + id: "1", + method: "test", + params: {}, + }; + + await expect(() => transport.send(message)).rejects.toThrow(UnauthorizedError); + expect(mockAuthProvider.redirectToAuthorization.mock.calls).toHaveLength(1); + }); + + it("respects custom headers when using auth provider", async () => { + mockAuthProvider.tokens.mockResolvedValue({ + access_token: "test-token", + token_type: "Bearer" + }); + + const customHeaders = { + "X-Custom-Header": "custom-value", + }; + + transport = new SSEClientTransport(baseUrl, { + authProvider: mockAuthProvider, + requestInit: { + headers: customHeaders, + }, + }); + + await transport.start(); + + const message: JSONRPCMessage = { + jsonrpc: "2.0", + id: "1", + method: "test", + params: {}, + }; + + await transport.send(message); + + expect(lastServerRequest.headers.authorization).toBe("Bearer test-token"); + expect(lastServerRequest.headers["x-custom-header"]).toBe("custom-value"); + }); + + it("refreshes expired token during SSE connection", async () => { + // Mock tokens() to return expired token until saveTokens is called + let currentTokens: OAuthTokens = { + access_token: "expired-token", + token_type: "Bearer", + refresh_token: "refresh-token" + }; + mockAuthProvider.tokens.mockImplementation(() => currentTokens); + mockAuthProvider.saveTokens.mockImplementation((tokens) => { + currentTokens = tokens; + }); + + // Create server that returns 401 for expired token, then accepts new token + await server.close(); + + let connectionAttempts = 0; + server = createServer((req, res) => { + lastServerRequest = req; + + if (req.url === "/token" && req.method === "POST") { + // Handle token refresh request + let body = ""; + req.on("data", chunk => { body += chunk; }); + req.on("end", () => { + const params = new URLSearchParams(body); + if (params.get("grant_type") === "refresh_token" && + params.get("refresh_token") === "refresh-token" && + params.get("client_id") === "test-client-id" && + params.get("client_secret") === "test-client-secret") { + res.writeHead(200, { "Content-Type": "application/json" }); + res.end(JSON.stringify({ + access_token: "new-token", + token_type: "Bearer", + refresh_token: "new-refresh-token" + })); + } else { + res.writeHead(400).end(); + } + }); + return; + } + + if (req.url !== "/") { + res.writeHead(404).end(); + return; + } + + const auth = req.headers.authorization; + if (auth === "Bearer expired-token") { + res.writeHead(401).end(); + return; + } + + if (auth === "Bearer new-token") { + res.writeHead(200, { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + Connection: "keep-alive", + }); + res.write("event: endpoint\n"); + res.write(`data: ${baseUrl.href}\n\n`); + connectionAttempts++; + return; + } + + res.writeHead(401).end(); + }); + + await new Promise(resolve => { + server.listen(0, "127.0.0.1", () => { + const addr = server.address() as AddressInfo; + baseUrl = new URL(`http://127.0.0.1:${addr.port}`); + resolve(); + }); + }); + + transport = new SSEClientTransport(baseUrl, { + authProvider: mockAuthProvider, + }); + + await transport.start(); + + expect(mockAuthProvider.saveTokens).toHaveBeenCalledWith({ + access_token: "new-token", + token_type: "Bearer", + refresh_token: "new-refresh-token" + }); + expect(connectionAttempts).toBe(1); + expect(lastServerRequest.headers.authorization).toBe("Bearer new-token"); + }); + + it("refreshes expired token during POST request", async () => { + // Mock tokens() to return expired token until saveTokens is called + let currentTokens: OAuthTokens = { + access_token: "expired-token", + token_type: "Bearer", + refresh_token: "refresh-token" + }; + mockAuthProvider.tokens.mockImplementation(() => currentTokens); + mockAuthProvider.saveTokens.mockImplementation((tokens) => { + currentTokens = tokens; + }); + + // Create server that accepts SSE but returns 401 on POST with expired token + await server.close(); + + let postAttempts = 0; + server = createServer((req, res) => { + lastServerRequest = req; + + if (req.url === "/token" && req.method === "POST") { + // Handle token refresh request + let body = ""; + req.on("data", chunk => { body += chunk; }); + req.on("end", () => { + const params = new URLSearchParams(body); + if (params.get("grant_type") === "refresh_token" && + params.get("refresh_token") === "refresh-token" && + params.get("client_id") === "test-client-id" && + params.get("client_secret") === "test-client-secret") { + res.writeHead(200, { "Content-Type": "application/json" }); + res.end(JSON.stringify({ + access_token: "new-token", + token_type: "Bearer", + refresh_token: "new-refresh-token" + })); + } else { + res.writeHead(400).end(); + } + }); + return; + } + + switch (req.method) { + case "GET": + if (req.url !== "/") { + res.writeHead(404).end(); + return; + } + + res.writeHead(200, { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + Connection: "keep-alive", + }); + res.write("event: endpoint\n"); + res.write(`data: ${baseUrl.href}\n\n`); + break; + + case "POST": { + if (req.url !== "/") { + res.writeHead(404).end(); + return; + } + + const auth = req.headers.authorization; + if (auth === "Bearer expired-token") { + res.writeHead(401).end(); + return; + } + + if (auth === "Bearer new-token") { + res.writeHead(200).end(); + postAttempts++; + return; + } + + res.writeHead(401).end(); + break; + } + } + }); + + await new Promise(resolve => { + server.listen(0, "127.0.0.1", () => { + const addr = server.address() as AddressInfo; + baseUrl = new URL(`http://127.0.0.1:${addr.port}`); + resolve(); + }); + }); + + transport = new SSEClientTransport(baseUrl, { + authProvider: mockAuthProvider, + }); + + await transport.start(); + + const message: JSONRPCMessage = { + jsonrpc: "2.0", + id: "1", + method: "test", + params: {}, + }; + + await transport.send(message); + + expect(mockAuthProvider.saveTokens).toHaveBeenCalledWith({ + access_token: "new-token", + token_type: "Bearer", + refresh_token: "new-refresh-token" + }); + expect(postAttempts).toBe(1); + expect(lastServerRequest.headers.authorization).toBe("Bearer new-token"); + }); + + it("redirects to authorization if refresh token flow fails", async () => { + // Mock tokens() to return expired token until saveTokens is called + let currentTokens: OAuthTokens = { + access_token: "expired-token", + token_type: "Bearer", + refresh_token: "refresh-token" + }; + mockAuthProvider.tokens.mockImplementation(() => currentTokens); + mockAuthProvider.saveTokens.mockImplementation((tokens) => { + currentTokens = tokens; + }); + + // Create server that returns 401 for all tokens + await server.close(); + + server = createServer((req, res) => { + lastServerRequest = req; + + if (req.url === "/token" && req.method === "POST") { + // Handle token refresh request - always fail + res.writeHead(400).end(); + return; + } + + if (req.url !== "/") { + res.writeHead(404).end(); + return; + } + res.writeHead(401).end(); + }); + + await new Promise(resolve => { + server.listen(0, "127.0.0.1", () => { + const addr = server.address() as AddressInfo; + baseUrl = new URL(`http://127.0.0.1:${addr.port}`); + resolve(); + }); + }); + + transport = new SSEClientTransport(baseUrl, { + authProvider: mockAuthProvider, + }); + + await expect(() => transport.start()).rejects.toThrow(UnauthorizedError); + expect(mockAuthProvider.redirectToAuthorization).toHaveBeenCalled(); }); }); }); diff --git a/src/client/sse.ts b/src/client/sse.ts index 14921f57..5e9f0cf0 100644 --- a/src/client/sse.ts +++ b/src/client/sse.ts @@ -1,6 +1,7 @@ import { EventSource, type ErrorEvent, type EventSourceInit } from "eventsource"; import { Transport } from "../shared/transport.js"; import { JSONRPCMessage, JSONRPCMessageSchema } from "../types.js"; +import { auth, AuthResult, OAuthClientProvider, UnauthorizedError } from "./auth.js"; export class SseError extends Error { constructor( @@ -12,6 +13,42 @@ export class SseError extends Error { } } +/** + * Configuration options for the `SSEClientTransport`. + */ +export type SSEClientTransportOptions = { + /** + * An OAuth client provider to use for authentication. + * + * When an `authProvider` is specified and the SSE connection is started: + * 1. The connection is attempted with any existing access token from the `authProvider`. + * 2. If the access token has expired, the `authProvider` is used to refresh the token. + * 3. If token refresh fails or no access token exists, and auth is required, `OAuthClientProvider.redirectToAuthorization` is called, and an `UnauthorizedError` will be thrown from `connect`/`start`. + * + * After the user has finished authorizing via their user agent, and is redirected back to the MCP client application, call `SSEClientTransport.finishAuth` with the authorization code before retrying the connection. + * + * If an `authProvider` is not provided, and auth is required, an `UnauthorizedError` will be thrown. + * + * `UnauthorizedError` might also be thrown when sending any message over the SSE transport, indicating that the session has expired, and needs to be re-authed and reconnected. + */ + authProvider?: OAuthClientProvider; + + /** + * Customizes the initial SSE request to the server (the request that begins the stream). + * + * NOTE: Setting this property will prevent an `Authorization` header from + * being automatically attached to the SSE request, if an `authProvider` is + * also given. This can be worked around by setting the `Authorization` header + * manually. + */ + eventSourceInit?: EventSourceInit; + + /** + * Customizes recurring POST requests to the server. + */ + requestInit?: RequestInit; +}; + /** * Client transport for SSE: this will connect to a server using Server-Sent Events for receiving * messages and make separate POST requests for sending messages. @@ -23,6 +60,7 @@ export class SSEClientTransport implements Transport { private _url: URL; private _eventSourceInit?: EventSourceInit; private _requestInit?: RequestInit; + private _authProvider?: OAuthClientProvider; onclose?: () => void; onerror?: (error: Error) => void; @@ -30,28 +68,68 @@ export class SSEClientTransport implements Transport { constructor( url: URL, - opts?: { eventSourceInit?: EventSourceInit; requestInit?: RequestInit }, + opts?: SSEClientTransportOptions, ) { this._url = url; this._eventSourceInit = opts?.eventSourceInit; this._requestInit = opts?.requestInit; + this._authProvider = opts?.authProvider; } - start(): Promise { - if (this._eventSource) { - throw new Error( - "SSEClientTransport already started! If using Client class, note that connect() calls start() automatically.", - ); + private async _authThenStart(): Promise { + if (!this._authProvider) { + throw new UnauthorizedError("No auth provider"); + } + + let result: AuthResult; + try { + result = await auth(this._authProvider, { serverUrl: this._url }); + } catch (error) { + this.onerror?.(error as Error); + throw error; + } + + if (result !== "AUTHORIZED") { + throw new UnauthorizedError(); } + return await this._startOrAuth(); + } + + private async _commonHeaders(): Promise { + const headers: HeadersInit = {}; + if (this._authProvider) { + const tokens = await this._authProvider.tokens(); + if (tokens) { + headers["Authorization"] = `Bearer ${tokens.access_token}`; + } + } + + return headers; + } + + private _startOrAuth(): Promise { return new Promise((resolve, reject) => { this._eventSource = new EventSource( this._url.href, - this._eventSourceInit, + this._eventSourceInit ?? { + fetch: (url, init) => this._commonHeaders().then((headers) => fetch(url, { + ...init, + headers: { + ...headers, + Accept: "text/event-stream" + } + })), + }, ); this._abortController = new AbortController(); this._eventSource.onerror = (event) => { + if (event.code === 401 && this._authProvider) { + this._authThenStart().then(resolve, reject); + return; + } + const error = new SseError(event.code, event.message, event); reject(error); this.onerror?.(error); @@ -97,6 +175,30 @@ export class SSEClientTransport implements Transport { }); } + async start() { + if (this._eventSource) { + throw new Error( + "SSEClientTransport already started! If using Client class, note that connect() calls start() automatically.", + ); + } + + return await this._startOrAuth(); + } + + /** + * Call this method after the user has finished authorizing via their user agent and is redirected back to the MCP client application. This will exchange the authorization code for an access token, enabling the next connection attempt to successfully auth. + */ + async finishAuth(authorizationCode: string): Promise { + if (!this._authProvider) { + throw new UnauthorizedError("No auth provider"); + } + + const result = await auth(this._authProvider, { serverUrl: this._url, authorizationCode }); + if (result !== "AUTHORIZED") { + throw new UnauthorizedError("Failed to authorize"); + } + } + async close(): Promise { this._abortController?.abort(); this._eventSource?.close(); @@ -109,7 +211,8 @@ export class SSEClientTransport implements Transport { } try { - const headers = new Headers(this._requestInit?.headers); + const commonHeaders = await this._commonHeaders(); + const headers = new Headers({ ...commonHeaders, ...this._requestInit?.headers }); headers.set("content-type", "application/json"); const init = { ...this._requestInit, @@ -120,8 +223,17 @@ export class SSEClientTransport implements Transport { }; const response = await fetch(this._endpoint, init); - if (!response.ok) { + if (response.status === 401 && this._authProvider) { + const result = await auth(this._authProvider, { serverUrl: this._url }); + if (result !== "AUTHORIZED") { + throw new UnauthorizedError(); + } + + // Purposely _not_ awaited, so we don't call onerror twice + return this.send(message); + } + const text = await response.text().catch(() => null); throw new Error( `Error POSTing to endpoint (HTTP ${response.status}): ${text}`, diff --git a/tsconfig.cjs.json b/tsconfig.cjs.json index 058a5d9a..b2f344a8 100644 --- a/tsconfig.cjs.json +++ b/tsconfig.cjs.json @@ -5,5 +5,5 @@ "moduleResolution": "node", "outDir": "./dist/cjs" }, - "exclude": ["**/*.test.ts"] + "exclude": ["**/*.test.ts", "src/__mocks__/**/*"] } diff --git a/tsconfig.prod.json b/tsconfig.prod.json index 2c68666e..2302dd84 100644 --- a/tsconfig.prod.json +++ b/tsconfig.prod.json @@ -3,5 +3,5 @@ "compilerOptions": { "outDir": "./dist/esm" }, - "exclude": ["**/*.test.ts"] + "exclude": ["**/*.test.ts", "src/__mocks__/**/*"] }