diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index fb5ecd13..c14a3961 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -7,6 +7,7 @@ import { Request, Result, ServerCapabilities, + JSONRPCRequest, } from "../types.js"; import { Protocol, mergeCapabilities } from "./protocol.js"; import { Transport } from "./transport.js"; @@ -16,6 +17,7 @@ class MockTransport implements Transport { onclose?: () => void; onerror?: (error: Error) => void; onmessage?: (message: unknown) => void; + sessionId?: string; async start(): Promise {} async close(): Promise { @@ -256,6 +258,75 @@ describe("protocol tests", () => { await expect(requestPromise).resolves.toEqual({ result: "success" }); }); }); + + describe("request handler params preservation", () => { + let protocol: Protocol; + let transport: MockTransport; + + beforeEach(() => { + transport = new MockTransport(); + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + })(); + protocol.connect(transport); + }); + + it("should preserve request params when passed to handler via setRequestHandler", async () => { + const testMethod = "test/paramsMethod"; + const testParams = { key1: "value1", key2: 123 }; + const testId = 99; + + const requestSchema = z.object({ + method: z.literal(testMethod), + params: z.object({ + key1: z.string(), + key2: z.number(), + }), + }); + + const mockHandler = jest.fn().mockResolvedValue({ success: true }); + + protocol.setRequestHandler(requestSchema, mockHandler); + + const rawIncomingRequest: JSONRPCRequest = { + jsonrpc: "2.0", + id: testId, + method: testMethod, + params: testParams, + }; + + if (!transport.onmessage) { + throw new Error("transport.onmessage was not set by protocol.connect()"); + } + transport.onmessage(rawIncomingRequest); + + await new Promise(setImmediate); + + expect(mockHandler).toHaveBeenCalledTimes(1); + + const handlerArg = mockHandler.mock.calls[0][0]; + const handlerExtraArg = mockHandler.mock.calls[0][1]; + + expect(handlerArg).toEqual( + expect.objectContaining({ + jsonrpc: "2.0", + id: testId, + method: testMethod, + params: testParams, + }) + ); + + expect(handlerExtraArg).toEqual( + expect.objectContaining({ + signal: expect.any(AbortSignal), + sessionId: transport.sessionId, + }) + ); + }); + }); + }); describe("mergeCapabilities", () => { diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index a6e47184..e1bda8b5 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -1,4 +1,4 @@ -import { ZodLiteral, ZodObject, ZodType, z } from "zod"; +import { ZodLiteral, ZodObject, ZodType, z, ZodTypeAny } from "zod"; import { CancelledNotificationSchema, ClientCapabilities, @@ -15,6 +15,7 @@ import { ProgressNotificationSchema, Request, RequestId, + RequestIdSchema, Result, ServerCapabilities, } from "../types.js"; @@ -170,9 +171,21 @@ export abstract class Protocol< controller?.abort(notification.params.reason); }); - this.setNotificationHandler(ProgressNotificationSchema, (notification) => { - this._onprogress(notification as unknown as ProgressNotification); - }); + // Register the internal _onprogress handler DIRECTLY, bypassing the + // complex parsing wrapper logic added for general handlers, + // as we know the exact structure we need for progress notifications. + this._notificationHandlers.set( + ProgressNotificationSchema.shape.method.value, + async (rawNotification: JSONRPCNotification) => { + try { + // Directly parse the known ProgressNotificationSchema + const parsedNotification = ProgressNotificationSchema.parse(rawNotification); + this._onprogress(parsedNotification); + } catch (e) { + this._onerror(new Error(`Invalid progress notification structure: ${e instanceof Error ? e.message : String(e)}. Notification ignored. Raw: ${JSON.stringify(rawNotification)}`)); + } + } + ); this.setRequestHandler( PingRequestSchema, @@ -571,6 +584,9 @@ export abstract class Protocol< setRequestHandler< T extends ZodObject<{ method: ZodLiteral; + jsonrpc?: ZodLiteral<"2.0">; + id?: z.ZodUnion<[z.ZodString, z.ZodNumber]>; + params?: ZodTypeAny; }>, >( requestSchema: T, @@ -581,9 +597,31 @@ export abstract class Protocol< ): void { const method = requestSchema.shape.method.value; this.assertRequestHandlerCapability(method); - this._requestHandlers.set(method, (request, extra) => - Promise.resolve(handler(requestSchema.parse(request), extra)), - ); + + this._requestHandlers.set(method, async (rawRequest: JSONRPCRequest, extra: RequestHandlerExtra): Promise => { + const parsingSchemaDefinition = { + jsonrpc: z.literal("2.0" as const), + id: RequestIdSchema, + method: z.literal(method), + params: requestSchema.shape.params + ? requestSchema.shape.params + : z.unknown().optional(), + }; + const finalParsingSchema = z.object(parsingSchemaDefinition).passthrough(); + + let parsedRequest: z.infer; + try { + parsedRequest = await finalParsingSchema.parseAsync(rawRequest) as z.infer; + } catch (e) { + if (e instanceof z.ZodError) { + const errorDetails = e.errors.map(err => `${err.path.join('.') || 'params'}: ${err.message}`).join('; '); + throw new McpError(ErrorCode.InvalidParams, `Invalid request structure for method ${method}: ${errorDetails}`, e.flatten()); + } + throw e; + } + + return await Promise.resolve(handler(parsedRequest, extra)); + }); } /** @@ -612,16 +650,39 @@ export abstract class Protocol< setNotificationHandler< T extends ZodObject<{ method: ZodLiteral; + jsonrpc?: ZodLiteral<"2.0">; + params?: ZodTypeAny; }>, >( notificationSchema: T, handler: (notification: z.infer) => void | Promise, ): void { - this._notificationHandlers.set( - notificationSchema.shape.method.value, - (notification) => - Promise.resolve(handler(notificationSchema.parse(notification))), - ); + const method = notificationSchema.shape.method.value; + + this._notificationHandlers.set(method, async (rawNotification: JSONRPCNotification): Promise => { + const parsingSchemaDefinition = { + jsonrpc: z.literal("2.0" as const), + method: z.literal(method), + params: notificationSchema.shape.params + ? notificationSchema.shape.params + : z.unknown().optional(), + }; + const finalParsingSchema = z.object(parsingSchemaDefinition).passthrough(); + + let parsedNotification: z.infer; + try { + parsedNotification = await finalParsingSchema.parseAsync(rawNotification) as z.infer; + } catch (e) { + if (e instanceof z.ZodError) { + this._onerror(new Error(`Invalid notification structure for ${method}: ${e.message}. Notification ignored.`)); + return; + } + this._onerror(new Error(`Unexpected error parsing notification ${method}: ${e instanceof Error ? e.message : String(e)}`)); + return; + } + + await Promise.resolve(handler(parsedNotification)); + }); } /**