diff --git a/.prettierrc b/.prettierrc new file mode 100644 index 00000000..4379c748 --- /dev/null +++ b/.prettierrc @@ -0,0 +1,8 @@ +{ + "printWidth": 80, + "tabWidth": 2, + "trailingComma": "all", + "jsxBracketSameLine": true, + "semi": true, + "singleQuote": false +} diff --git a/jest.config.js b/jest.config.js index f8f621c8..a0021104 100644 --- a/jest.config.js +++ b/jest.config.js @@ -12,5 +12,6 @@ export default { transformIgnorePatterns: [ "/node_modules/(?!eventsource)/" ], + collectCoverageFrom: ["src/**/*.ts"], testPathIgnorePatterns: ["/node_modules/", "/dist/"], }; diff --git a/package-lock.json b/package-lock.json index 73f1cbba..8338e3c4 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "@modelcontextprotocol/sdk", - "version": "1.7.0", + "version": "1.8.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@modelcontextprotocol/sdk", - "version": "1.7.0", + "version": "1.8.0", "license": "MIT", "dependencies": { "content-type": "^1.0.5", diff --git a/package.json b/package.json index e2d8b3d7..d14fd9e6 100644 --- a/package.json +++ b/package.json @@ -41,6 +41,7 @@ "prepack": "npm run build:esm && npm run build:cjs", "lint": "eslint src/", "test": "jest", + "coverage": "jest --coverage", "start": "npm run server", "server": "tsx watch --clear-screen=false src/cli.ts server", "client": "tsx src/cli.ts client" diff --git a/src/inMemory.ts b/src/inMemory.ts index 106a9e7e..65915baa 100644 --- a/src/inMemory.ts +++ b/src/inMemory.ts @@ -12,6 +12,7 @@ export class InMemoryTransport implements Transport { onerror?: (error: Error) => void; onmessage?: (message: JSONRPCMessage) => void; sessionId?: string; + user?: unknown; /** * Creates a pair of linked in-memory transports that can communicate with each other. One should be passed to a Client and one to a Server. diff --git a/src/server/mcp.test.ts b/src/server/mcp.test.ts index 2e91a568..08518b20 100644 --- a/src/server/mcp.test.ts +++ b/src/server/mcp.test.ts @@ -1,7 +1,8 @@ -import { McpServer } from "./mcp.js"; +import { McpServer, ToolCallback } from "./mcp.js"; import { Client } from "../client/index.js"; import { InMemoryTransport } from "../inMemory.js"; -import { z } from "zod"; +import { z, ZodRawShape } from "zod"; +import { zodToJsonSchema } from "zod-to-json-schema"; import { ListToolsResultSchema, CallToolResultSchema, @@ -11,10 +12,16 @@ import { ListPromptsResultSchema, GetPromptResultSchema, CompleteResultSchema, + CallToolRequestSchema, + CallToolRequest, + ListToolsRequestSchema, + ListToolsResult, + Tool, } from "../types.js"; import { ResourceTemplate } from "./mcp.js"; import { completable } from "./completable.js"; import { UriTemplate } from "../shared/uriTemplate.js"; +import { RequestHandlerExtra } from "../shared/protocol.js"; describe("McpServer", () => { test("should expose underlying Server instance", () => { @@ -318,7 +325,7 @@ describe("tool()", () => { // This should succeed mcpServer.tool("tool1", () => ({ content: [] })); - + // This should also succeed and not throw about request handlers mcpServer.tool("tool2", () => ({ content: [] })); }); @@ -354,7 +361,8 @@ describe("tool()", () => { }; }); - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); // Set a test sessionId on the server transport serverTransport.sessionId = "test-session-123"; @@ -815,7 +823,7 @@ describe("resource()", () => { }, ], })); - + // This should also succeed and not throw about request handlers mcpServer.resource("resource2", "test://resource2", async () => ({ contents: [ @@ -1321,7 +1329,7 @@ describe("prompt()", () => { }, ], })); - + // This should also succeed and not throw about request handlers mcpServer.prompt("prompt2", async () => ({ messages: [ @@ -1343,19 +1351,17 @@ describe("prompt()", () => { }); // This should succeed - mcpServer.prompt( - "echo", - { message: z.string() }, - ({ message }) => ({ - messages: [{ + mcpServer.prompt("echo", { message: z.string() }, ({ message }) => ({ + messages: [ + { role: "user", content: { type: "text", - text: `Please process this message: ${message}` - } - }] - }) - ); + text: `Please process this message: ${message}`, + }, + }, + ], + })); }); test("should allow registering both resources and prompts with completion handlers", () => { @@ -1388,14 +1394,16 @@ describe("prompt()", () => { "echo", { message: completable(z.string(), () => ["hello", "world"]) }, ({ message }) => ({ - messages: [{ - role: "user", - content: { - type: "text", - text: `Please process this message: ${message}` - } - }] - }) + messages: [ + { + role: "user", + content: { + type: "text", + text: `Please process this message: ${message}`, + }, + }, + ], + }), ); }); @@ -1582,3 +1590,351 @@ describe("prompt()", () => { expect(result.completion.total).toBe(1); }); }); + +describe("McpServer with Auth Extension", () => { + type SessionUser = { + role: string; + [key: string]: unknown; + }; + + type AccessPolicy = { + allow?: { + roles?: string[]; + }; + deny?: { + roles?: string[]; + }; + }; + + type RegisteredToolWithAuth = { + description?: string; + inputSchema?: z.ZodObject; + callback: ToolCallback; + accessPolicy?: AccessPolicy; + }; + + // Just a simple extension to McpServer that adds support for access policies on tools + class McpServerWithAuth extends McpServer { + protected override _registeredTools: { + [name: string]: RegisteredToolWithAuth; + } = {}; + checkPermissions(user?: SessionUser, policy?: AccessPolicy): boolean { + if (!policy) { + return true; + } + + if (!user) { + return false; + } + + // Check deny rules first + if (policy.deny) { + // Check denied roles + if (policy.deny.roles?.includes(user.role)) { + return false; + } + } + + // Check allow rules + if (policy.allow) { + let isAllowed = false; + + // If no allow rules are specified, default to allowed + if (!policy.allow.roles) { + isAllowed = true; + } else { + // Check allowed roles + if (policy.allow.roles?.includes(user.role)) { + isAllowed = true; + } + } + + return isAllowed; + } + + // If no rules specified, default to allowed + return true; + } + + override tool( + name: string, + cb: ToolCallback, + accessPolicy?: AccessPolicy, + ): void; + override tool( + name: string, + description: string, + cb: ToolCallback, + accessPolicy?: AccessPolicy, + ): void; + override tool( + name: string, + paramsSchema: Args, + cb: ToolCallback, + accessPolicy?: AccessPolicy, + ): void; + override tool( + name: string, + description: string, + paramsSchema: Args, + cb: ToolCallback, + accessPolicy?: AccessPolicy, + ): void; + override tool(name: string, ...rest: unknown[]): void { + let description: string | undefined; + let paramsSchema: ZodRawShape | undefined; + let accessPolicy: AccessPolicy | undefined; + let cb: ToolCallback; + + // Parse arguments based on their types + if (typeof rest[0] === "function") { + // Case: tool(name, cb, accessPolicy?) + cb = rest[0] as ToolCallback; + accessPolicy = rest[1] as AccessPolicy | undefined; + } else if (typeof rest[0] === "string") { + // Cases with description + description = rest[0]; + if (typeof rest[1] === "function") { + // Case: tool(name, description, cb, accessPolicy?) + cb = rest[1] as ToolCallback; + accessPolicy = rest[2] as AccessPolicy | undefined; + } else { + // Case: tool(name, description, paramsSchema, cb, accessPolicy?) + paramsSchema = rest[1] as ZodRawShape; + cb = rest[2] as ToolCallback; + accessPolicy = rest[3] as AccessPolicy | undefined; + } + } else { + // Case: tool(name, paramsSchema, cb, accessPolicy?) + paramsSchema = rest[0] as ZodRawShape; + cb = rest[1] as ToolCallback; + accessPolicy = rest[2] as AccessPolicy | undefined; + } + + // Register with base class + const args: unknown[] = [name]; + if (description) args.push(description); + if (paramsSchema) args.push(paramsSchema); + args.push(cb); + + // Set up request handlers if not already initialized + if (!this._toolHandlersInitialized) { + this.server.assertCanSetRequestHandler( + CallToolRequestSchema.shape.method.value, + ); + this.server.assertCanSetRequestHandler( + ListToolsRequestSchema.shape.method.value, + ); + this.server.registerCapabilities({ tools: {} }); + + // Add ListToolsRequestSchema handler + this.server.setRequestHandler( + ListToolsRequestSchema, + (request, extra): ListToolsResult => { + const user = extra.user as SessionUser | undefined; + + // Filter tools based on permissions + const accessibleTools = Object.entries(this._registeredTools) + .filter(([_, tool]) => + this.checkPermissions(user, tool.accessPolicy), + ) + .map( + ([name, tool]): Tool => ({ + name, + description: tool.description, + inputSchema: tool.inputSchema + ? (zodToJsonSchema(tool.inputSchema, { + strictUnions: true, + }) as Tool["inputSchema"]) + : { type: "object" }, + }), + ); + + return { tools: accessibleTools }; + }, + ); + + this.server.setRequestHandler( + CallToolRequestSchema, + async (request: CallToolRequest, extra: RequestHandlerExtra) => { + const tool = this._registeredTools[request.params.name]; + if (!tool) { + throw new Error(`Tool ${request.params.name} not found`); + } + + if ( + !this.checkPermissions( + extra.user as SessionUser, + tool.accessPolicy, + ) + ) { + throw new Error(`Access denied for tool: ${request.params.name}`); + } + + if (tool.inputSchema) { + const parseResult = await tool.inputSchema.safeParseAsync( + request.params.arguments, + ); + if (!parseResult.success) { + throw new Error( + `Invalid arguments for tool ${request.params.name}: ${parseResult.error.message}`, + ); + } + + const args = parseResult.data; + const cb = tool.callback as ToolCallback; + return await Promise.resolve(cb(args, extra)); + } else { + const cb = tool.callback as ToolCallback; + return await Promise.resolve(cb(extra)); + } + }, + ); + this._toolHandlersInitialized = true; + } + + McpServer.prototype.tool.apply( + this, + args as Parameters, + ); + this._registeredTools[name].accessPolicy = accessPolicy; + } + } + + const mcpServer = new McpServerWithAuth({ + name: "test server with auth", + version: "1.0", + }); + const client = new Client({ + name: "test client", + version: "1.0", + }); + + mcpServer.tool("public-tool", async () => ({ + content: [ + { + type: "text", + text: "Public tool response", + }, + ], + })); + + mcpServer.tool( + "protected-tool", + async () => ({ + content: [ + { + type: "text", + text: "Protected tool response", + }, + ], + }), + { + allow: { + roles: ["admin"], + }, + }, + ); + + test("should public tools work with list and call when unauthenticated", async () => { + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + // Public tool should be accessible + const result = await client.request( + { method: "tools/list" }, + ListToolsResultSchema, + ); + expect(result.tools).toHaveLength(1); + expect(result.tools[0].name).toBe("public-tool"); + const response = await client.request( + { + method: "tools/call", + params: { + name: "public-tool", + arguments: {}, + }, + }, + CallToolResultSchema, + ); + expect(response.content).toEqual([ + { + type: "text", + text: "Public tool response", + }, + ]); + }); + + test("should public tools work with list and call when unauthorized", async () => { + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + // Protected tool should be inaccessible when authenticated as a non-admin user + serverTransport.user = { role: "member" }; + const result = await client.request( + { method: "tools/list" }, + ListToolsResultSchema, + ); + expect(result.tools).toHaveLength(1); + expect(result.tools[0].name).toBe("public-tool"); + const response = await client.request( + { + method: "tools/call", + params: { + name: "public-tool", + arguments: {}, + }, + }, + CallToolResultSchema, + ); + expect(response.content).toEqual([ + { + type: "text", + text: "Public tool response", + }, + ]); + }); + + test("should protected tools work with list and call when authorized", async () => { + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + serverTransport.user = { role: "admin" }; + const result = await client.request( + { method: "tools/list" }, + ListToolsResultSchema, + ); + expect(result.tools).toHaveLength(2); + expect(result.tools[0].name).toBe("public-tool"); + expect(result.tools[1].name).toBe("protected-tool"); + + const response = await client.request( + { + method: "tools/call", + params: { + name: "protected-tool", + arguments: {}, + }, + }, + CallToolResultSchema, + ); + expect(response.content).toEqual([ + { + type: "text", + text: "Protected tool response", + }, + ]); + }); +}); diff --git a/src/server/mcp.ts b/src/server/mcp.ts index 8f4a909c..ed93256f 100644 --- a/src/server/mcp.ts +++ b/src/server/mcp.ts @@ -54,12 +54,17 @@ export class McpServer { */ public readonly server: Server; - private _registeredResources: { [uri: string]: RegisteredResource } = {}; - private _registeredResourceTemplates: { + protected _registeredResources: { [uri: string]: RegisteredResource } = {}; + protected _registeredResourceTemplates: { [name: string]: RegisteredResourceTemplate; } = {}; - private _registeredTools: { [name: string]: RegisteredTool } = {}; - private _registeredPrompts: { [name: string]: RegisteredPrompt } = {}; + protected _registeredTools: { [name: string]: RegisteredTool } = {}; + protected _registeredPrompts: { [name: string]: RegisteredPrompt } = {}; + + protected _toolHandlersInitialized = false; + protected _completionHandlerInitialized = false; + protected _resourceHandlersInitialized = false; + protected _promptHandlersInitialized = false; constructor(serverInfo: Implementation, options?: ServerOptions) { this.server = new Server(serverInfo, options); @@ -81,13 +86,11 @@ export class McpServer { await this.server.close(); } - private _toolHandlersInitialized = false; - private setToolRequestHandlers() { if (this._toolHandlersInitialized) { return; } - + this.server.assertCanSetRequestHandler( ListToolsRequestSchema.shape.method.value, ); @@ -177,8 +180,6 @@ export class McpServer { this._toolHandlersInitialized = true; } - private _completionHandlerInitialized = false; - private setCompletionRequestHandler() { if (this._completionHandlerInitialized) { return; @@ -267,8 +268,6 @@ export class McpServer { return createCompletionResult(suggestions); } - private _resourceHandlersInitialized = false; - private setResourceRequestHandlers() { if (this._resourceHandlersInitialized) { return; @@ -366,12 +365,10 @@ export class McpServer { ); this.setCompletionRequestHandler(); - + this._resourceHandlersInitialized = true; } - private _promptHandlersInitialized = false; - private setPromptRequestHandlers() { if (this._promptHandlersInitialized) { return; @@ -438,7 +435,7 @@ export class McpServer { ); this.setCompletionRequestHandler(); - + this._promptHandlersInitialized = true; } @@ -770,7 +767,7 @@ type RegisteredPrompt = { callback: PromptCallback; }; -function promptArgumentsFromSchema( +export function promptArgumentsFromSchema( schema: ZodObject, ): PromptArgument[] { return Object.entries(schema.shape).map( diff --git a/src/server/sse.ts b/src/server/sse.ts index 84c1cbb9..73603021 100644 --- a/src/server/sse.ts +++ b/src/server/sse.ts @@ -19,6 +19,7 @@ export class SSEServerTransport implements Transport { onclose?: () => void; onerror?: (error: Error) => void; onmessage?: (message: JSONRPCMessage) => void; + user?: unknown; /** * Creates a new SSE server transport, which will direct the client to POST messages to the relative or absolute URL identified by `_endpoint`. diff --git a/src/server/stdio.ts b/src/server/stdio.ts index 30c80012..7c3b21c6 100644 --- a/src/server/stdio.ts +++ b/src/server/stdio.ts @@ -21,6 +21,7 @@ export class StdioServerTransport implements Transport { onclose?: () => void; onerror?: (error: Error) => void; onmessage?: (message: JSONRPCMessage) => void; + user?: unknown; // Arrow functions to bind `this` properly, while maintaining function identity. _ondata = (chunk: Buffer) => { @@ -73,7 +74,7 @@ export class StdioServerTransport implements Transport { // This prevents interfering with other parts of the application that might be using stdin this._stdin.pause(); } - + // Clear the buffer and notify closure this._readBuffer.clear(); this.onclose?.(); diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index a6e47184..5540d4fe 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -93,6 +93,11 @@ export type RequestHandlerExtra = { * The session ID from the transport, if available. */ sessionId?: string; + + /** + * The authenticated user, if available. + */ + user?: unknown; }; /** @@ -319,6 +324,7 @@ export abstract class Protocol< const extra: RequestHandlerExtra = { signal: abortController.signal, sessionId: this._transport?.sessionId, + user: this._transport?.user, }; // Starting with Promise.resolve() puts any synchronous errors into the monad as well. @@ -364,7 +370,7 @@ export abstract class Protocol< private _onprogress(notification: ProgressNotification): void { const { progressToken, ...params } = notification.params; const messageId = Number(progressToken); - + const handler = this._progressHandlers.get(messageId); if (!handler) { this._onerror(new Error(`Received a progress notification for an unknown token: ${JSON.stringify(notification)}`)); diff --git a/src/shared/transport.ts b/src/shared/transport.ts index b80e2a51..88cea7f2 100644 --- a/src/shared/transport.ts +++ b/src/shared/transport.ts @@ -46,4 +46,9 @@ export interface Transport { * The session ID generated for this connection. */ sessionId?: string; + + /** + * The authenticated user for this transport session. + */ + user?: unknown; }