diff --git a/src/inMemory.ts b/src/inMemory.ts index 2763f38c..106a9e7e 100644 --- a/src/inMemory.ts +++ b/src/inMemory.ts @@ -11,6 +11,7 @@ export class InMemoryTransport implements Transport { onclose?: () => void; onerror?: (error: Error) => void; onmessage?: (message: JSONRPCMessage) => void; + sessionId?: string; /** * Creates a pair of linked in-memory transports that can communicate with each other. One should be passed to a Client and one to a Server. diff --git a/src/server/mcp.test.ts b/src/server/mcp.test.ts index 8f9bfa77..2e91a568 100644 --- a/src/server/mcp.test.ts +++ b/src/server/mcp.test.ts @@ -323,6 +323,59 @@ describe("tool()", () => { mcpServer.tool("tool2", () => ({ content: [] })); }); + test("should pass sessionId to tool callback via RequestHandlerExtra", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + tools: {}, + }, + }, + ); + + let receivedSessionId: string | undefined; + mcpServer.tool("test-tool", async (extra) => { + receivedSessionId = extra.sessionId; + return { + content: [ + { + type: "text", + text: "Test response", + }, + ], + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + // Set a test sessionId on the server transport + serverTransport.sessionId = "test-session-123"; + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + await client.request( + { + method: "tools/call", + params: { + name: "test-tool", + }, + }, + CallToolResultSchema, + ); + + expect(receivedSessionId).toBe("test-session-123"); + }); + test("should allow client to call server tools", async () => { const mcpServer = new McpServer({ name: "test server", diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 97213bf0..a5b6ad51 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -88,6 +88,11 @@ export type RequestHandlerExtra = { * An abort signal used to communicate if the request was cancelled from the sender's side. */ signal: AbortSignal; + + /** + * The session ID from the transport, if available. + */ + sessionId?: string; }; /** @@ -307,9 +312,15 @@ export abstract class Protocol< const abortController = new AbortController(); this._requestHandlerAbortControllers.set(request.id, abortController); + // Create extra object with both abort signal and sessionId from transport + const extra: RequestHandlerExtra = { + signal: abortController.signal, + sessionId: this._transport?.sessionId, + }; + // Starting with Promise.resolve() puts any synchronous errors into the monad as well. Promise.resolve() - .then(() => handler(request, { signal: abortController.signal })) + .then(() => handler(request, extra)) .then( (result) => { if (abortController.signal.aborted) { diff --git a/src/shared/transport.ts b/src/shared/transport.ts index 5843cf00..b80e2a51 100644 --- a/src/shared/transport.ts +++ b/src/shared/transport.ts @@ -41,4 +41,9 @@ export interface Transport { * Callback for when a message (request or response) is received over the connection. */ onmessage?: (message: JSONRPCMessage) => void; + + /** + * The session ID generated for this connection. + */ + sessionId?: string; }