From cc9ae5b4ac288f2d1586e07924f7a8b8da17ee2d Mon Sep 17 00:00:00 2001 From: Geng Yan Date: Fri, 28 Mar 2025 09:00:07 +0800 Subject: [PATCH 01/17] =?UTF-8?q?=E2=9C=A8=20feat:=20StreamableHTTPServerT?= =?UTF-8?q?ransport=20implement?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/server/streamable-http.test.ts | 427 +++++++++++++++++++++++++++++ src/server/streamable-http.ts | 388 ++++++++++++++++++++++++++ 2 files changed, 815 insertions(+) create mode 100644 src/server/streamable-http.test.ts create mode 100644 src/server/streamable-http.ts diff --git a/src/server/streamable-http.test.ts b/src/server/streamable-http.test.ts new file mode 100644 index 00000000..8dd5dea7 --- /dev/null +++ b/src/server/streamable-http.test.ts @@ -0,0 +1,427 @@ +import { IncomingMessage, ServerResponse } from "node:http"; +import { StreamableHTTPServerTransport } from "./streamable-http.js"; +import { JSONRPCMessage } from "../types.js"; +import { Readable } from "node:stream"; + +// Mock IncomingMessage +function createMockRequest(options: { + method: string; + headers: Record; + body?: string; +}): IncomingMessage { + const readable = new Readable(); + readable._read = () => {}; + if (options.body) { + readable.push(options.body); + readable.push(null); + } + + return Object.assign(readable, { + method: options.method, + headers: options.headers, + }) as IncomingMessage; +} + +// Mock ServerResponse +function createMockResponse(): jest.Mocked { + const response = { + writeHead: jest.fn().mockReturnThis(), + write: jest.fn().mockReturnThis(), + end: jest.fn().mockReturnThis(), + on: jest.fn().mockReturnThis(), + emit: jest.fn().mockReturnThis(), + getHeader: jest.fn(), + setHeader: jest.fn(), + } as unknown as jest.Mocked; + return response; +} + +describe("StreamableHTTPServerTransport", () => { + const endpoint = "/mcp"; + let transport: StreamableHTTPServerTransport; + let mockResponse: jest.Mocked; + + beforeEach(() => { + transport = new StreamableHTTPServerTransport(endpoint); + mockResponse = createMockResponse(); + }); + + afterEach(() => { + jest.clearAllMocks(); + }); + + describe("Session Management", () => { + it("should generate a valid session ID", () => { + expect(transport.sessionId).toBeTruthy(); + expect(typeof transport.sessionId).toBe("string"); + }); + + it("should include session ID in response headers", async () => { + const req = createMockRequest({ + method: "GET", + headers: { + accept: "text/event-stream" + }, + }); + + await transport.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith( + 200, + expect.objectContaining({ + "Mcp-Session-Id": transport.sessionId, + }) + ); + }); + + it("should reject invalid session ID", async () => { + const req = createMockRequest({ + method: "GET", + headers: { + "mcp-session-id": "invalid-session-id", + }, + }); + + await transport.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith(404); + // check if the error response is a valid JSON-RPC error format + expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"jsonrpc":"2.0"')); + expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"error"')); + }); + }); + + describe("Request Handling", () => { + it("should reject GET requests without Accept: text/event-stream header", async () => { + const req = createMockRequest({ + method: "GET", + headers: {}, + }); + + await transport.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith(406); + expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"jsonrpc":"2.0"')); + }); + + it("should properly handle GET requests with Accept header and establish SSE connection", async () => { + const req = createMockRequest({ + method: "GET", + headers: { + accept: "text/event-stream", + }, + }); + + await transport.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith( + 200, + expect.objectContaining({ + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + Connection: "keep-alive", + }) + ); + }); + + it("should reject POST requests without proper Accept header", async () => { + const message: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test", + params: {}, + id: 1, + }; + + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + }, + body: JSON.stringify(message), + }); + + await transport.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith(406); + }); + + it("should properly handle JSON-RPC request messages in POST requests", async () => { + const message: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test", + params: {}, + id: 1, + }; + + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "text/event-stream", + }, + body: JSON.stringify(message), + }); + + const onMessageMock = jest.fn(); + transport.onmessage = onMessageMock; + + await transport.handleRequest(req, mockResponse); + + expect(onMessageMock).toHaveBeenCalledWith(message); + expect(mockResponse.writeHead).toHaveBeenCalledWith( + 200, + expect.objectContaining({ + "Content-Type": "text/event-stream", + }) + ); + }); + + it("should properly handle JSON-RPC notification or response messages in POST requests", async () => { + const notification: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test", + params: {}, + }; + + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + }, + body: JSON.stringify(notification), + }); + + const onMessageMock = jest.fn(); + transport.onmessage = onMessageMock; + + await transport.handleRequest(req, mockResponse); + + expect(onMessageMock).toHaveBeenCalledWith(notification); + expect(mockResponse.writeHead).toHaveBeenCalledWith(202); + }); + + it("should handle batch messages properly", async () => { + const batchMessages: JSONRPCMessage[] = [ + { jsonrpc: "2.0", method: "test1", params: {} }, + { jsonrpc: "2.0", method: "test2", params: {} }, + ]; + + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json", + }, + body: JSON.stringify(batchMessages), + }); + + const onMessageMock = jest.fn(); + transport.onmessage = onMessageMock; + + await transport.handleRequest(req, mockResponse); + + expect(onMessageMock).toHaveBeenCalledTimes(2); + expect(mockResponse.writeHead).toHaveBeenCalledWith(202); + }); + + it("should reject unsupported Content-Type", async () => { + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "text/plain", + "accept": "application/json", + }, + body: "test", + }); + + await transport.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith(415); + expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"jsonrpc":"2.0"')); + }); + + it("should properly handle DELETE requests and close session", async () => { + const req = createMockRequest({ + method: "DELETE", + headers: {}, + }); + + const onCloseMock = jest.fn(); + transport.onclose = onCloseMock; + + await transport.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith(200); + expect(onCloseMock).toHaveBeenCalled(); + }); + }); + + describe("Message Replay", () => { + it("should replay messages after specified Last-Event-ID", async () => { + // Establish first connection with Accept header + const req1 = createMockRequest({ + method: "GET", + headers: { + "accept": "text/event-stream" + }, + }); + await transport.handleRequest(req1, mockResponse); + + // Send a message to first connection + const message1: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test1", + params: {}, + id: 1 + }; + + await transport.send(message1); + + // Get message ID (captured from write call) + const writeCall = mockResponse.write.mock.calls[0][0] as string; + const idMatch = writeCall.match(/id: ([a-f0-9-]+)/); + if (!idMatch) { + throw new Error("Message ID not found in write call"); + } + const lastEventId = idMatch[1]; + + // Create a second connection with last-event-id + const mockResponse2 = createMockResponse(); + const req2 = createMockRequest({ + method: "GET", + headers: { + "accept": "text/event-stream", + "last-event-id": lastEventId, + }, + }); + + await transport.handleRequest(req2, mockResponse2); + + // Send a second message + const message2: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test2", + params: {}, + id: 2 + }; + + await transport.send(message2); + + // Verify the second message was received by both connections + expect(mockResponse.write).toHaveBeenCalledWith( + expect.stringContaining(JSON.stringify(message1)) + ); + expect(mockResponse2.write).toHaveBeenCalledWith( + expect.stringContaining(JSON.stringify(message2)) + ); + }); + }); + + describe("Message Targeting", () => { + it("should send response messages to the connection that sent the request", async () => { + // Create two connections + const mockResponse1 = createMockResponse(); + const req1 = createMockRequest({ + method: "GET", + headers: { + "accept": "text/event-stream", + }, + }); + await transport.handleRequest(req1, mockResponse1); + + const mockResponse2 = createMockResponse(); + const req2 = createMockRequest({ + method: "GET", + headers: { + "accept": "text/event-stream", + }, + }); + await transport.handleRequest(req2, mockResponse2); + + // Send a request through the first connection + const requestMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test", + params: {}, + id: "test-id", + }; + + const reqPost = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "text/event-stream", + }, + body: JSON.stringify(requestMessage), + }); + + await transport.handleRequest(reqPost, mockResponse1); + + // Send a response with matching ID + const responseMessage: JSONRPCMessage = { + jsonrpc: "2.0", + result: { success: true }, + id: "test-id", + }; + + await transport.send(responseMessage); + + // Verify response was sent to the right connection + expect(mockResponse1.write).toHaveBeenCalledWith( + expect.stringContaining(JSON.stringify(responseMessage)) + ); + + // Check if write was called with this exact message on the second connection + const writeCallsOnSecondConn = mockResponse2.write.mock.calls.filter(call => + typeof call[0] === 'string' && call[0].includes(JSON.stringify(responseMessage)) + ); + + // Verify the response wasn't broadcast to all connections + expect(writeCallsOnSecondConn.length).toBe(0); + }); + }); + + describe("Error Handling", () => { + it("should handle invalid JSON data", async () => { + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json", + }, + body: "invalid json", + }); + + const onErrorMock = jest.fn(); + transport.onerror = onErrorMock; + + await transport.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith(400); + expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"jsonrpc":"2.0"')); + expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"code":-32700')); + expect(onErrorMock).toHaveBeenCalled(); + }); + + it("should handle invalid JSON-RPC messages", async () => { + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json", + }, + body: JSON.stringify({ invalid: "message" }), + }); + + const onErrorMock = jest.fn(); + transport.onerror = onErrorMock; + + await transport.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith(400); + expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"jsonrpc":"2.0"')); + expect(onErrorMock).toHaveBeenCalled(); + }); + }); +}); \ No newline at end of file diff --git a/src/server/streamable-http.ts b/src/server/streamable-http.ts new file mode 100644 index 00000000..2567240c --- /dev/null +++ b/src/server/streamable-http.ts @@ -0,0 +1,388 @@ +import { randomUUID } from "node:crypto"; +import { IncomingMessage, ServerResponse } from "node:http"; +import { Transport } from "../shared/transport.js"; +import { JSONRPCMessage, JSONRPCMessageSchema } from "../types.js"; +import getRawBody from "raw-body"; +import contentType from "content-type"; + +const MAXIMUM_MESSAGE_SIZE = "4mb"; + +interface StreamConnection { + response: ServerResponse; + lastEventId?: string; + messages: Array<{ + id: string; + message: JSONRPCMessage; + }>; + // mark this connection as a response to a specific request + requestId?: string | null; +} + +/** + * Server transport for Streamable HTTP: this implements the MCP Streamable HTTP transport specification. + * It supports both SSE streaming and direct HTTP responses, with session management and message resumability. + */ +export class StreamableHTTPServerTransport implements Transport { + private _connections: Map = new Map(); + private _sessionId: string; + private _messageHistory: Map = new Map(); + private _started: boolean = false; + private _requestConnections: Map = new Map(); // request ID to connection ID mapping + + onclose?: () => void; + onerror?: (error: Error) => void; + onmessage?: (message: JSONRPCMessage) => void; + + constructor(private _endpoint: string) { + this._sessionId = randomUUID(); + } + + /** + * Starts the transport. This is required by the Transport interface but is a no-op + * for the Streamable HTTP transport as connections are managed per-request. + */ + async start(): Promise { + if (this._started) { + throw new Error("Transport already started"); + } + this._started = true; + } + + /** + * Handles an incoming HTTP request, whether GET or POST + */ + async handleRequest(req: IncomingMessage, res: ServerResponse): Promise { + // validate the session ID + const sessionId = req.headers["mcp-session-id"]; + if (sessionId && (Array.isArray(sessionId) ? sessionId[0] : sessionId) !== this._sessionId) { + res.writeHead(404).end(JSON.stringify({ + jsonrpc: "2.0", + error: { + code: -32001, + message: "Session not found" + }, + id: null + })); + return; + } + + if (req.method === "GET") { + await this.handleGetRequest(req, res); + } else if (req.method === "POST") { + await this.handlePostRequest(req, res); + } else if (req.method === "DELETE") { + await this.handleDeleteRequest(req, res); + } else { + res.writeHead(405).end(JSON.stringify({ + jsonrpc: "2.0", + error: { + code: -32000, + message: "Method not allowed" + }, + id: null + })); + } + } + + /** + * Handles GET requests to establish SSE connections + */ + private async handleGetRequest(req: IncomingMessage, res: ServerResponse): Promise { + // validate the Accept header + const acceptHeader = req.headers.accept; + if (!acceptHeader || !acceptHeader.includes("text/event-stream")) { + res.writeHead(406).end(JSON.stringify({ + jsonrpc: "2.0", + error: { + code: -32000, + message: "Not Acceptable: Client must accept text/event-stream" + }, + id: null + })); + return; + } + + const connectionId = randomUUID(); + const lastEventId = req.headers["last-event-id"]; + const lastEventIdStr = Array.isArray(lastEventId) ? lastEventId[0] : lastEventId; + + res.writeHead(200, { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + Connection: "keep-alive", + "Mcp-Session-Id": this._sessionId, + }); + + const connection: StreamConnection = { + response: res, + lastEventId: lastEventIdStr, + messages: [], + }; + + this._connections.set(connectionId, connection); + + // if there is a Last-Event-ID, replay messages on this connection + if (lastEventIdStr) { + this.replayMessages(connectionId, lastEventIdStr); + } + + res.on("close", () => { + this._connections.delete(connectionId); + // remove all request mappings associated with this connection + for (const [reqId, connId] of this._requestConnections.entries()) { + if (connId === connectionId) { + this._requestConnections.delete(reqId); + } + } + if (this._connections.size === 0) { + this.onclose?.(); + } + }); + } + + /** + * Handles POST requests containing JSON-RPC messages + */ + private async handlePostRequest(req: IncomingMessage, res: ServerResponse): Promise { + try { + // validate the Accept header + const acceptHeader = req.headers.accept; + if (!acceptHeader || + (!acceptHeader.includes("application/json") && !acceptHeader.includes("text/event-stream"))) { + res.writeHead(406).end(JSON.stringify({ + jsonrpc: "2.0", + error: { + code: -32000, + message: "Not Acceptable: Client must accept application/json and/or text/event-stream" + }, + id: null + })); + return; + } + + const ct = req.headers["content-type"]; + if (!ct || !ct.includes("application/json")) { + res.writeHead(415).end(JSON.stringify({ + jsonrpc: "2.0", + error: { + code: -32000, + message: "Unsupported Media Type: Content-Type must be application/json" + }, + id: null + })); + return; + } + + const parsedCt = contentType.parse(ct); + const body = await getRawBody(req, { + limit: MAXIMUM_MESSAGE_SIZE, + encoding: parsedCt.parameters.charset ?? "utf-8", + }); + + const rawMessage = JSON.parse(body.toString()); + let messages: JSONRPCMessage[]; + + // handle batch and single messages + if (Array.isArray(rawMessage)) { + messages = rawMessage.map(msg => JSONRPCMessageSchema.parse(msg)); + } else { + messages = [JSONRPCMessageSchema.parse(rawMessage)]; + } + + // check if it contains requests + const hasRequests = messages.some(msg => 'method' in msg && 'id' in msg); + const hasOnlyNotificationsOrResponses = messages.every(msg => + ('method' in msg && !('id' in msg)) || ('result' in msg || 'error' in msg)); + + if (hasOnlyNotificationsOrResponses) { + // if it only contains notifications or responses, return 202 + res.writeHead(202).end(); + + // handle each message + for (const message of messages) { + this.onmessage?.(message); + } + } else if (hasRequests) { + // if it contains requests, you can choose to return an SSE stream or a JSON response + const useSSE = acceptHeader.includes("text/event-stream"); + + if (useSSE) { + res.writeHead(200, { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + Connection: "keep-alive", + "Mcp-Session-Id": this._sessionId, + }); + + const connectionId = randomUUID(); + const connection: StreamConnection = { + response: res, + messages: [], + }; + + this._connections.set(connectionId, connection); + + // map each request to a connection ID + for (const message of messages) { + if ('method' in message && 'id' in message) { + this._requestConnections.set(String(message.id), connectionId); + } + this.onmessage?.(message); + } + + res.on("close", () => { + this._connections.delete(connectionId); + // remove all request mappings associated with this connection + for (const [reqId, connId] of this._requestConnections.entries()) { + if (connId === connectionId) { + this._requestConnections.delete(reqId); + } + } + if (this._connections.size === 0) { + this.onclose?.(); + } + }); + } else { + // use direct JSON response + res.writeHead(200, { + "Content-Type": "application/json", + "Mcp-Session-Id": this._sessionId, + }); + + // handle each message + for (const message of messages) { + this.onmessage?.(message); + } + + res.end(); + } + } + } catch (error) { + // return JSON-RPC formatted error + res.writeHead(400).end(JSON.stringify({ + jsonrpc: "2.0", + error: { + code: -32700, + message: "Parse error", + data: String(error) + }, + id: null + })); + this.onerror?.(error as Error); + } + } + + /** + * Handles DELETE requests to terminate sessions + */ + private async handleDeleteRequest(req: IncomingMessage, res: ServerResponse): Promise { + await this.close(); + res.writeHead(200).end(); + } + + /** + * Replays messages after the specified event ID for a specific connection + */ + private replayMessages(connectionId: string, lastEventId: string): void { + if (!lastEventId) return; + + // only replay messages that should be sent on this connection + const messages = Array.from(this._messageHistory.entries()) + .filter(([id, { connectionId: msgConnId }]) => + id > lastEventId && + (!msgConnId || msgConnId === connectionId)) // only replay messages that are not specified to a connection or specified to the current connection + .sort(([a], [b]) => a.localeCompare(b)); + + const connection = this._connections.get(connectionId); + if (!connection) return; + + for (const [id, { message }] of messages) { + connection.response.write( + `id: ${id}\nevent: message\ndata: ${JSON.stringify(message)}\n\n` + ); + } + } + + async close(): Promise { + for (const connection of this._connections.values()) { + connection.response.end(); + } + this._connections.clear(); + this._messageHistory.clear(); + this._requestConnections.clear(); + this.onclose?.(); + } + + async send(message: JSONRPCMessage): Promise { + if (this._connections.size === 0) { + throw new Error("No active connections"); + } + + let targetConnectionId = ""; + + // if it is a response, find the corresponding request connection + if ('id' in message && ('result' in message || 'error' in message)) { + const connId = this._requestConnections.get(String(message.id)); + + // if the corresponding connection is not found, the connection may be disconnected + if (!connId || !this._connections.has(connId)) { + // select an available connection + const firstConnId = this._connections.keys().next().value; + if (firstConnId) { + targetConnectionId = firstConnId; + } else { + throw new Error("No available connections"); + } + } else { + targetConnectionId = connId; + } + } else { + // for other messages, select an available connection + const firstConnId = this._connections.keys().next().value; + if (firstConnId) { + targetConnectionId = firstConnId; + } else { + throw new Error("No available connections"); + } + } + + const messageId = randomUUID(); + this._messageHistory.set(messageId, { + message, + connectionId: targetConnectionId + }); + + // keep the message history in a reasonable range + if (this._messageHistory.size > 1000) { + const oldestKey = Array.from(this._messageHistory.keys())[0]; + this._messageHistory.delete(oldestKey); + } + + // send the message to all active connections + for (const [connId, connection] of this._connections.entries()) { + // if it is a response message, only send to the target connection + if ('id' in message && ('result' in message || 'error' in message)) { + if (connId === targetConnectionId) { + connection.response.write( + `id: ${messageId}\nevent: message\ndata: ${JSON.stringify(message)}\n\n` + ); + } + } else { + // for other messages, send to all connections + connection.response.write( + `id: ${messageId}\nevent: message\ndata: ${JSON.stringify(message)}\n\n` + ); + } + } + } + + /** + * Returns the session ID for this transport + */ + get sessionId(): string { + return this._sessionId; + } +} \ No newline at end of file From 970cf9d5b168717459733d43b43e889121e501e6 Mon Sep 17 00:00:00 2001 From: Geng Yan Date: Tue, 1 Apr 2025 10:58:37 +0800 Subject: [PATCH 02/17] fix: Use lowercase header names --- src/server/streamable-http.test.ts | 2 +- src/server/streamable-http.ts | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/server/streamable-http.test.ts b/src/server/streamable-http.test.ts index 8dd5dea7..fa5140c0 100644 --- a/src/server/streamable-http.test.ts +++ b/src/server/streamable-http.test.ts @@ -69,7 +69,7 @@ describe("StreamableHTTPServerTransport", () => { expect(mockResponse.writeHead).toHaveBeenCalledWith( 200, expect.objectContaining({ - "Mcp-Session-Id": transport.sessionId, + "mcp-session-id": transport.sessionId, }) ); }); diff --git a/src/server/streamable-http.ts b/src/server/streamable-http.ts index 2567240c..bf0096d1 100644 --- a/src/server/streamable-http.ts +++ b/src/server/streamable-http.ts @@ -113,7 +113,7 @@ export class StreamableHTTPServerTransport implements Transport { "Content-Type": "text/event-stream", "Cache-Control": "no-cache", Connection: "keep-alive", - "Mcp-Session-Id": this._sessionId, + "mcp-session-id": this._sessionId, }); const connection: StreamConnection = { @@ -214,7 +214,7 @@ export class StreamableHTTPServerTransport implements Transport { "Content-Type": "text/event-stream", "Cache-Control": "no-cache", Connection: "keep-alive", - "Mcp-Session-Id": this._sessionId, + "mcp-session-id": this._sessionId, }); const connectionId = randomUUID(); @@ -249,7 +249,7 @@ export class StreamableHTTPServerTransport implements Transport { // use direct JSON response res.writeHead(200, { "Content-Type": "application/json", - "Mcp-Session-Id": this._sessionId, + "mcp-session-id": this._sessionId, }); // handle each message From 4c7c434e7c2c8ae9d97334640f48b18c5930a9c0 Mon Sep 17 00:00:00 2001 From: Geng Yan Date: Wed, 2 Apr 2025 08:47:35 +0800 Subject: [PATCH 03/17] =?UTF-8?q?=E2=9C=A8=20feat:=20try=20add=20stateless?= =?UTF-8?q?=20mod?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/server/streamable-http.test.ts | 268 ++++++++++++++++++++++++++++- src/server/streamable-http.ts | 131 +++++++++++--- 2 files changed, 368 insertions(+), 31 deletions(-) diff --git a/src/server/streamable-http.test.ts b/src/server/streamable-http.test.ts index fa5140c0..960d3800 100644 --- a/src/server/streamable-http.test.ts +++ b/src/server/streamable-http.test.ts @@ -57,11 +57,24 @@ describe("StreamableHTTPServerTransport", () => { }); it("should include session ID in response headers", async () => { + // Use POST with initialize method to avoid session ID requirement + const initializeMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "initialize", + params: { + clientInfo: { name: "test-client", version: "1.0" }, + protocolVersion: "2025-03-26" + }, + id: "init-1", + }; + const req = createMockRequest({ - method: "GET", + method: "POST", headers: { - accept: "text/event-stream" + "content-type": "application/json", + "accept": "application/json", }, + body: JSON.stringify(initializeMessage), }); await transport.handleRequest(req, mockResponse); @@ -79,6 +92,7 @@ describe("StreamableHTTPServerTransport", () => { method: "GET", headers: { "mcp-session-id": "invalid-session-id", + "accept": "text/event-stream" }, }); @@ -89,13 +103,241 @@ describe("StreamableHTTPServerTransport", () => { expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"jsonrpc":"2.0"')); expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"error"')); }); + + it("should reject non-initialization requests without session ID with 400 Bad Request", async () => { + const req = createMockRequest({ + method: "GET", + headers: { + accept: "text/event-stream", + // No mcp-session-id header + }, + }); + + await transport.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith(400); + expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"jsonrpc":"2.0"')); + expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"message":"Bad Request: Mcp-Session-Id header is required"')); + }); + + it("should always include session ID in initialization response even in stateless mode", async () => { + // Create a stateless transport for this test + const statelessTransport = new StreamableHTTPServerTransport(endpoint, { enableSessionManagement: false }); + + // Create an initialization request + const initializeMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "initialize", + params: { + clientInfo: { name: "test-client", version: "1.0" }, + protocolVersion: "2025-03-26" + }, + id: "init-1", + }; + + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json", + }, + body: JSON.stringify(initializeMessage), + }); + + await statelessTransport.handleRequest(req, mockResponse); + + // In stateless mode, session ID should also be included for initialize responses + expect(mockResponse.writeHead).toHaveBeenCalledWith( + 200, + expect.objectContaining({ + "mcp-session-id": statelessTransport.sessionId, + }) + ); + }); + }); + + describe("Stateless Mode", () => { + let statelessTransport: StreamableHTTPServerTransport; + let mockResponse: jest.Mocked; + + beforeEach(() => { + statelessTransport = new StreamableHTTPServerTransport(endpoint, { enableSessionManagement: false }); + mockResponse = createMockResponse(); + }); + + it("should not include session ID in response headers when in stateless mode", async () => { + // Use a non-initialization request + const message: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test", + params: {}, + id: 1, + }; + + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json", + }, + body: JSON.stringify(message), + }); + + await statelessTransport.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalled(); + // Extract the headers from writeHead call + const headers = mockResponse.writeHead.mock.calls[0][1]; + expect(headers).not.toHaveProperty("mcp-session-id"); + }); + + it("should not validate session ID in stateless mode", async () => { + const req = createMockRequest({ + method: "GET", + headers: { + accept: "text/event-stream", + "mcp-session-id": "invalid-session-id", // This would cause a 404 in stateful mode + }, + }); + + await statelessTransport.handleRequest(req, mockResponse); + + // Should still get 200 OK, not 404 Not Found + expect(mockResponse.writeHead).toHaveBeenCalledWith( + 200, + expect.not.objectContaining({ + "mcp-session-id": expect.anything(), + }) + ); + }); + + it("should handle POST requests without session validation in stateless mode", async () => { + const message: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test", + params: {}, + id: 1, + }; + + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json", + "mcp-session-id": "non-existent-session-id", // This would be rejected in stateful mode + }, + body: JSON.stringify(message), + }); + + const onMessageMock = jest.fn(); + statelessTransport.onmessage = onMessageMock; + + await statelessTransport.handleRequest(req, mockResponse); + + // Message should be processed despite invalid session ID + expect(onMessageMock).toHaveBeenCalledWith(message); + }); + + it("should work with a mix of requests with and without session IDs in stateless mode", async () => { + // First request without session ID + const req1 = createMockRequest({ + method: "GET", + headers: { + accept: "text/event-stream", + }, + }); + + await statelessTransport.handleRequest(req1, mockResponse); + expect(mockResponse.writeHead).toHaveBeenCalledWith( + 200, + expect.objectContaining({ + "Content-Type": "text/event-stream", + }) + ); + + // Reset mock for second request + mockResponse.writeHead.mockClear(); + + // Second request with a session ID (which would be invalid in stateful mode) + const req2 = createMockRequest({ + method: "GET", + headers: { + accept: "text/event-stream", + "mcp-session-id": "some-random-session-id", + }, + }); + + await statelessTransport.handleRequest(req2, mockResponse); + + // Should still succeed + expect(mockResponse.writeHead).toHaveBeenCalledWith( + 200, + expect.objectContaining({ + "Content-Type": "text/event-stream", + }) + ); + }); + + it("should handle initialization requests properly in both modes", async () => { + // Initialize message that would typically be sent during initialization + const initializeMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "initialize", + params: { + clientInfo: { name: "test-client", version: "1.0" }, + protocolVersion: "2025-03-26" + }, + id: "init-1", + }; + + // Test stateful transport (default) + const statefulReq = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json", + }, + body: JSON.stringify(initializeMessage), + }); + + await transport.handleRequest(statefulReq, mockResponse); + + // In stateful mode, session ID should be included in the response header + expect(mockResponse.writeHead).toHaveBeenCalledWith( + 200, + expect.objectContaining({ + "mcp-session-id": transport.sessionId, + }) + ); + + // Reset mocks for stateless test + mockResponse.writeHead.mockClear(); + + // Test stateless transport + const statelessReq = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json", + }, + body: JSON.stringify(initializeMessage), + }); + + await statelessTransport.handleRequest(statelessReq, mockResponse); + + // In stateless mode, session ID should also be included for initialize responses + const headers = mockResponse.writeHead.mock.calls[0][1]; + expect(headers).toHaveProperty("mcp-session-id", statelessTransport.sessionId); + }); }); describe("Request Handling", () => { it("should reject GET requests without Accept: text/event-stream header", async () => { const req = createMockRequest({ method: "GET", - headers: {}, + headers: { + "mcp-session-id": transport.sessionId, + }, }); await transport.handleRequest(req, mockResponse); @@ -108,7 +350,8 @@ describe("StreamableHTTPServerTransport", () => { const req = createMockRequest({ method: "GET", headers: { - accept: "text/event-stream", + "accept": "text/event-stream", + "mcp-session-id": transport.sessionId, }, }); @@ -127,7 +370,7 @@ describe("StreamableHTTPServerTransport", () => { it("should reject POST requests without proper Accept header", async () => { const message: JSONRPCMessage = { jsonrpc: "2.0", - method: "test", + method: "initialize", // Use initialize to bypass session ID check params: {}, id: 1, }; @@ -148,7 +391,7 @@ describe("StreamableHTTPServerTransport", () => { it("should properly handle JSON-RPC request messages in POST requests", async () => { const message: JSONRPCMessage = { jsonrpc: "2.0", - method: "test", + method: "initialize", // Use initialize to bypass session ID check params: {}, id: 1, }; @@ -188,6 +431,7 @@ describe("StreamableHTTPServerTransport", () => { headers: { "content-type": "application/json", "accept": "application/json, text/event-stream", + "mcp-session-id": transport.sessionId, }, body: JSON.stringify(notification), }); @@ -212,6 +456,7 @@ describe("StreamableHTTPServerTransport", () => { headers: { "content-type": "application/json", "accept": "application/json", + "mcp-session-id": transport.sessionId, }, body: JSON.stringify(batchMessages), }); @@ -231,6 +476,7 @@ describe("StreamableHTTPServerTransport", () => { headers: { "content-type": "text/plain", "accept": "application/json", + "mcp-session-id": transport.sessionId, }, body: "test", }); @@ -244,7 +490,9 @@ describe("StreamableHTTPServerTransport", () => { it("should properly handle DELETE requests and close session", async () => { const req = createMockRequest({ method: "DELETE", - headers: {}, + headers: { + "mcp-session-id": transport.sessionId, + }, }); const onCloseMock = jest.fn(); @@ -259,11 +507,12 @@ describe("StreamableHTTPServerTransport", () => { describe("Message Replay", () => { it("should replay messages after specified Last-Event-ID", async () => { - // Establish first connection with Accept header + // Establish first connection with Accept header and session ID const req1 = createMockRequest({ method: "GET", headers: { - "accept": "text/event-stream" + "accept": "text/event-stream", + "mcp-session-id": transport.sessionId }, }); await transport.handleRequest(req1, mockResponse); @@ -293,6 +542,7 @@ describe("StreamableHTTPServerTransport", () => { headers: { "accept": "text/event-stream", "last-event-id": lastEventId, + "mcp-session-id": transport.sessionId }, }); diff --git a/src/server/streamable-http.ts b/src/server/streamable-http.ts index bf0096d1..73eae32b 100644 --- a/src/server/streamable-http.ts +++ b/src/server/streamable-http.ts @@ -18,9 +18,44 @@ interface StreamConnection { requestId?: string | null; } +/** + * Configuration options for StreamableHTTPServerTransport + */ +export interface StreamableHTTPServerTransportOptions { + /** + * Whether to enable session management through mcp-session-id headers + * When set to false, the transport operates in stateless mode without session validation + * @default true + */ + enableSessionManagement?: boolean; +} + /** * Server transport for Streamable HTTP: this implements the MCP Streamable HTTP transport specification. * It supports both SSE streaming and direct HTTP responses, with session management and message resumability. + * + * Usage example: + * + * ```typescript + * // Stateful mode (default) - with session management + * const statefulTransport = new StreamableHTTPServerTransport("/mcp"); + * + * // Stateless mode - without session management + * const statelessTransport = new StreamableHTTPServerTransport("/mcp", { + * enableSessionManagement: false + * }); + * ``` + * + * In stateful mode: + * - Session ID is generated and included in response headers + * - Session ID is always included in initialization responses + * - Requests with invalid session IDs are rejected with 404 Not Found + * - Non-initialization requests without a session ID are rejected with 400 Bad Request + * - State is maintained in-memory (connections, message history) + * + * In stateless mode: + * - Session ID is only included in initialization responses + * - No session validation is performed */ export class StreamableHTTPServerTransport implements Transport { private _connections: Map = new Map(); @@ -31,13 +66,15 @@ export class StreamableHTTPServerTransport implements Transport { }> = new Map(); private _started: boolean = false; private _requestConnections: Map = new Map(); // request ID to connection ID mapping + private _enableSessionManagement: boolean; onclose?: () => void; onerror?: (error: Error) => void; onmessage?: (message: JSONRPCMessage) => void; - constructor(private _endpoint: string) { + constructor(private _endpoint: string, options?: StreamableHTTPServerTransportOptions) { this._sessionId = randomUUID(); + this._enableSessionManagement = options?.enableSessionManagement !== false; } /** @@ -55,18 +92,41 @@ export class StreamableHTTPServerTransport implements Transport { * Handles an incoming HTTP request, whether GET or POST */ async handleRequest(req: IncomingMessage, res: ServerResponse): Promise { - // validate the session ID - const sessionId = req.headers["mcp-session-id"]; - if (sessionId && (Array.isArray(sessionId) ? sessionId[0] : sessionId) !== this._sessionId) { - res.writeHead(404).end(JSON.stringify({ - jsonrpc: "2.0", - error: { - code: -32001, - message: "Session not found" - }, - id: null - })); - return; + // Only validate session ID for non-initialization requests when session management is enabled + if (this._enableSessionManagement) { + const sessionId = req.headers["mcp-session-id"]; + + // Check if this might be an initialization request + const isInitializationRequest = req.method === "POST" && + req.headers["content-type"]?.includes("application/json"); + + if (isInitializationRequest) { + // For POST requests with JSON content, we need to check if it's an initialization request + // This will be done in handlePostRequest, as we need to parse the body + // Continue processing normally + } else if (!sessionId) { + // Non-initialization requests without a session ID should return 400 Bad Request + res.writeHead(400).end(JSON.stringify({ + jsonrpc: "2.0", + error: { + code: -32000, + message: "Bad Request: Mcp-Session-Id header is required" + }, + id: null + })); + return; + } else if ((Array.isArray(sessionId) ? sessionId[0] : sessionId) !== this._sessionId) { + // Reject requests with invalid session ID with 404 Not Found + res.writeHead(404).end(JSON.stringify({ + jsonrpc: "2.0", + error: { + code: -32001, + message: "Session not found" + }, + id: null + })); + return; + } } if (req.method === "GET") { @@ -109,12 +169,19 @@ export class StreamableHTTPServerTransport implements Transport { const lastEventId = req.headers["last-event-id"]; const lastEventIdStr = Array.isArray(lastEventId) ? lastEventId[0] : lastEventId; - res.writeHead(200, { + // Prepare response headers + const headers: Record = { "Content-Type": "text/event-stream", "Cache-Control": "no-cache", Connection: "keep-alive", - "mcp-session-id": this._sessionId, - }); + }; + + // Only include session ID header if session management is enabled + if (this._enableSessionManagement) { + headers["mcp-session-id"] = this._sessionId; + } + + res.writeHead(200, headers); const connection: StreamConnection = { response: res, @@ -192,6 +259,12 @@ export class StreamableHTTPServerTransport implements Transport { messages = [JSONRPCMessageSchema.parse(rawMessage)]; } + // Check if this is an initialization request + // https://spec.modelcontextprotocol.io/specification/2025-03-26/basic/lifecycle/ + const isInitializationRequest = messages.some( + msg => 'method' in msg && msg.method === 'initialize' && 'id' in msg + ); + // check if it contains requests const hasRequests = messages.some(msg => 'method' in msg && 'id' in msg); const hasOnlyNotificationsOrResponses = messages.every(msg => @@ -210,12 +283,19 @@ export class StreamableHTTPServerTransport implements Transport { const useSSE = acceptHeader.includes("text/event-stream"); if (useSSE) { - res.writeHead(200, { + const headers: Record = { "Content-Type": "text/event-stream", "Cache-Control": "no-cache", Connection: "keep-alive", - "mcp-session-id": this._sessionId, - }); + }; + + // Only include session ID header if session management is enabled + // Always include session ID for initialization requests + if (this._enableSessionManagement || isInitializationRequest) { + headers["mcp-session-id"] = this._sessionId; + } + + res.writeHead(200, headers); const connectionId = randomUUID(); const connection: StreamConnection = { @@ -247,10 +327,17 @@ export class StreamableHTTPServerTransport implements Transport { }); } else { // use direct JSON response - res.writeHead(200, { + const headers: Record = { "Content-Type": "application/json", - "mcp-session-id": this._sessionId, - }); + }; + + // Only include session ID header if session management is enabled + // Always include session ID for initialization requests + if (this._enableSessionManagement || isInitializationRequest) { + headers["mcp-session-id"] = this._sessionId; + } + + res.writeHead(200, headers); // handle each message for (const message of messages) { From e3a61095aed1b967d9164cc321d7b2b122465779 Mon Sep 17 00:00:00 2001 From: Geng Yan Date: Wed, 2 Apr 2025 08:57:21 +0800 Subject: [PATCH 04/17] =?UTF-8?q?=E2=9C=A8=20feat:=20add=20pre-rarsed=20bo?= =?UTF-8?q?dy=20=20support?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/server/streamable-http.test.ts | 114 +++++++++++++++++++++++++++++ src/server/streamable-http.ts | 28 ++++--- 2 files changed, 133 insertions(+), 9 deletions(-) diff --git a/src/server/streamable-http.test.ts b/src/server/streamable-http.test.ts index 960d3800..13f70b59 100644 --- a/src/server/streamable-http.test.ts +++ b/src/server/streamable-http.test.ts @@ -674,4 +674,118 @@ describe("StreamableHTTPServerTransport", () => { expect(onErrorMock).toHaveBeenCalled(); }); }); + + describe("Handling Pre-Parsed Body", () => { + it("should accept pre-parsed request body", async () => { + const message: JSONRPCMessage = { + jsonrpc: "2.0", + method: "initialize", + params: { + clientInfo: { name: "test-client", version: "1.0" }, + protocolVersion: "2025-03-26" + }, + id: "pre-parsed-test", + }; + + // Create a request without actual body content + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json", + }, + // No body provided here - it will be passed as parsedBody + }); + + const onMessageMock = jest.fn(); + transport.onmessage = onMessageMock; + + // Pass the pre-parsed body directly + await transport.handleRequest(req, mockResponse, message); + + // Verify the message was processed correctly + expect(onMessageMock).toHaveBeenCalledWith(message); + expect(mockResponse.writeHead).toHaveBeenCalledWith( + 200, + expect.objectContaining({ + "Content-Type": "application/json", + }) + ); + }); + + it("should handle pre-parsed batch messages", async () => { + const batchMessages: JSONRPCMessage[] = [ + { + jsonrpc: "2.0", + method: "method1", + params: { data: "test1" }, + id: "batch1" + }, + { + jsonrpc: "2.0", + method: "method2", + params: { data: "test2" }, + id: "batch2" + }, + ]; + + // Create a request without actual body content + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "text/event-stream", + "mcp-session-id": transport.sessionId, + }, + // No body provided here - it will be passed as parsedBody + }); + + const onMessageMock = jest.fn(); + transport.onmessage = onMessageMock; + + // Pass the pre-parsed body directly + await transport.handleRequest(req, mockResponse, batchMessages); + + // Should be called for each message in the batch + expect(onMessageMock).toHaveBeenCalledTimes(2); + expect(onMessageMock).toHaveBeenCalledWith(batchMessages[0]); + expect(onMessageMock).toHaveBeenCalledWith(batchMessages[1]); + }); + + it("should prefer pre-parsed body over request body", async () => { + const requestBodyMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "fromRequestBody", + params: {}, + id: "request-body", + }; + + const parsedBodyMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "fromParsedBody", + params: {}, + id: "parsed-body", + }; + + // Create a request with actual body content + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json", + }, + body: JSON.stringify(requestBodyMessage), + }); + + const onMessageMock = jest.fn(); + transport.onmessage = onMessageMock; + + // Pass the pre-parsed body directly + await transport.handleRequest(req, mockResponse, parsedBodyMessage); + + // Should use the parsed body instead of the request body + expect(onMessageMock).toHaveBeenCalledWith(parsedBodyMessage); + expect(onMessageMock).not.toHaveBeenCalledWith(requestBodyMessage); + }); + }); }); \ No newline at end of file diff --git a/src/server/streamable-http.ts b/src/server/streamable-http.ts index 73eae32b..bec0a1ad 100644 --- a/src/server/streamable-http.ts +++ b/src/server/streamable-http.ts @@ -44,6 +44,11 @@ export interface StreamableHTTPServerTransportOptions { * const statelessTransport = new StreamableHTTPServerTransport("/mcp", { * enableSessionManagement: false * }); + * + * // Using with pre-parsed request body + * app.post('/mcp', (req, res) => { + * transport.handleRequest(req, res, req.body); + * }); * ``` * * In stateful mode: @@ -91,7 +96,7 @@ export class StreamableHTTPServerTransport implements Transport { /** * Handles an incoming HTTP request, whether GET or POST */ - async handleRequest(req: IncomingMessage, res: ServerResponse): Promise { + async handleRequest(req: IncomingMessage, res: ServerResponse, parsedBody?: unknown): Promise { // Only validate session ID for non-initialization requests when session management is enabled if (this._enableSessionManagement) { const sessionId = req.headers["mcp-session-id"]; @@ -132,7 +137,7 @@ export class StreamableHTTPServerTransport implements Transport { if (req.method === "GET") { await this.handleGetRequest(req, res); } else if (req.method === "POST") { - await this.handlePostRequest(req, res); + await this.handlePostRequest(req, res, parsedBody); } else if (req.method === "DELETE") { await this.handleDeleteRequest(req, res); } else { @@ -213,7 +218,7 @@ export class StreamableHTTPServerTransport implements Transport { /** * Handles POST requests containing JSON-RPC messages */ - private async handlePostRequest(req: IncomingMessage, res: ServerResponse): Promise { + private async handlePostRequest(req: IncomingMessage, res: ServerResponse, parsedBody?: unknown): Promise { try { // validate the Accept header const acceptHeader = req.headers.accept; @@ -243,13 +248,18 @@ export class StreamableHTTPServerTransport implements Transport { return; } - const parsedCt = contentType.parse(ct); - const body = await getRawBody(req, { - limit: MAXIMUM_MESSAGE_SIZE, - encoding: parsedCt.parameters.charset ?? "utf-8", - }); + let rawMessage; + if (parsedBody !== undefined) { + rawMessage = parsedBody; + } else { + const parsedCt = contentType.parse(ct); + const body = await getRawBody(req, { + limit: MAXIMUM_MESSAGE_SIZE, + encoding: parsedCt.parameters.charset ?? "utf-8", + }); + rawMessage = JSON.parse(body.toString()); + } - const rawMessage = JSON.parse(body.toString()); let messages: JSONRPCMessage[]; // handle batch and single messages From e9caa5a12f8cba52b734102cd3cd28836f6cc113 Mon Sep 17 00:00:00 2001 From: Geng Yan Date: Thu, 3 Apr 2025 08:44:29 +0800 Subject: [PATCH 05/17] =?UTF-8?q?=E2=9C=A8=20feat:=20add=20customHeaders?= =?UTF-8?q?=20options?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/server/streamable-http.test.ts | 143 ++++++++++++++++++++++++++++- src/server/streamable-http.ts | 19 +++- 2 files changed, 155 insertions(+), 7 deletions(-) diff --git a/src/server/streamable-http.test.ts b/src/server/streamable-http.test.ts index 13f70b59..74b92c1a 100644 --- a/src/server/streamable-http.test.ts +++ b/src/server/streamable-http.test.ts @@ -98,7 +98,7 @@ describe("StreamableHTTPServerTransport", () => { await transport.handleRequest(req, mockResponse); - expect(mockResponse.writeHead).toHaveBeenCalledWith(404); + expect(mockResponse.writeHead).toHaveBeenCalledWith(404, {}); // check if the error response is a valid JSON-RPC error format expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"jsonrpc":"2.0"')); expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"error"')); @@ -115,7 +115,7 @@ describe("StreamableHTTPServerTransport", () => { await transport.handleRequest(req, mockResponse); - expect(mockResponse.writeHead).toHaveBeenCalledWith(400); + expect(mockResponse.writeHead).toHaveBeenCalledWith(400, {}); expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"jsonrpc":"2.0"')); expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"message":"Bad Request: Mcp-Session-Id header is required"')); }); @@ -342,7 +342,7 @@ describe("StreamableHTTPServerTransport", () => { await transport.handleRequest(req, mockResponse); - expect(mockResponse.writeHead).toHaveBeenCalledWith(406); + expect(mockResponse.writeHead).toHaveBeenCalledWith(406, {}); expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"jsonrpc":"2.0"')); }); @@ -788,4 +788,141 @@ describe("StreamableHTTPServerTransport", () => { expect(onMessageMock).not.toHaveBeenCalledWith(requestBodyMessage); }); }); + + describe("Custom Headers", () => { + const customHeaders = { + "X-Custom-Header": "custom-value", + "X-API-Version": "1.0", + "Access-Control-Allow-Origin": "*" + }; + + let transportWithHeaders: StreamableHTTPServerTransport; + let mockResponse: jest.Mocked; + + beforeEach(() => { + transportWithHeaders = new StreamableHTTPServerTransport(endpoint, { customHeaders }); + mockResponse = createMockResponse(); + }); + + it("should include custom headers in SSE response", async () => { + const req = createMockRequest({ + method: "GET", + headers: { + accept: "text/event-stream", + "mcp-session-id": transportWithHeaders.sessionId + }, + }); + + await transportWithHeaders.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith( + 200, + expect.objectContaining({ + ...customHeaders, + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "mcp-session-id": transportWithHeaders.sessionId + }) + ); + }); + + it("should include custom headers in JSON response", async () => { + const message: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test", + params: {}, + id: 1, + }; + + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json", + "mcp-session-id": transportWithHeaders.sessionId + }, + body: JSON.stringify(message), + }); + + await transportWithHeaders.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith( + 200, + expect.objectContaining({ + ...customHeaders, + "Content-Type": "application/json", + "mcp-session-id": transportWithHeaders.sessionId + }) + ); + }); + + it("should include custom headers in error responses", async () => { + const req = createMockRequest({ + method: "GET", + headers: { + accept: "text/event-stream", + "mcp-session-id": "invalid-session-id" + }, + }); + + await transportWithHeaders.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith( + 404, + expect.objectContaining(customHeaders) + ); + }); + + it("should not override essential headers with custom headers", async () => { + const transportWithConflictingHeaders = new StreamableHTTPServerTransport(endpoint, { + customHeaders: { + "Content-Type": "text/plain", // 尝试覆盖必要的 Content-Type 头 + "X-Custom-Header": "custom-value" + } + }); + + const req = createMockRequest({ + method: "GET", + headers: { + accept: "text/event-stream", + "mcp-session-id": transportWithConflictingHeaders.sessionId + }, + }); + + await transportWithConflictingHeaders.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith( + 200, + expect.objectContaining({ + "Content-Type": "text/event-stream", // 应该保持原有的 Content-Type + "X-Custom-Header": "custom-value" + }) + ); + }); + + it("should work with empty custom headers", async () => { + const transportWithoutHeaders = new StreamableHTTPServerTransport(endpoint); + + const req = createMockRequest({ + method: "GET", + headers: { + accept: "text/event-stream", + "mcp-session-id": transportWithoutHeaders.sessionId + }, + }); + + await transportWithoutHeaders.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith( + 200, + expect.objectContaining({ + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "mcp-session-id": transportWithoutHeaders.sessionId + }) + ); + }); + }); }); \ No newline at end of file diff --git a/src/server/streamable-http.ts b/src/server/streamable-http.ts index bec0a1ad..13d28968 100644 --- a/src/server/streamable-http.ts +++ b/src/server/streamable-http.ts @@ -28,6 +28,12 @@ export interface StreamableHTTPServerTransportOptions { * @default true */ enableSessionManagement?: boolean; + + /** + * Custom headers to be included in all responses + * These headers will be added to both SSE and regular HTTP responses + */ + customHeaders?: Record; } /** @@ -72,6 +78,7 @@ export class StreamableHTTPServerTransport implements Transport { private _started: boolean = false; private _requestConnections: Map = new Map(); // request ID to connection ID mapping private _enableSessionManagement: boolean; + private _customHeaders: Record; onclose?: () => void; onerror?: (error: Error) => void; @@ -80,6 +87,7 @@ export class StreamableHTTPServerTransport implements Transport { constructor(private _endpoint: string, options?: StreamableHTTPServerTransportOptions) { this._sessionId = randomUUID(); this._enableSessionManagement = options?.enableSessionManagement !== false; + this._customHeaders = options?.customHeaders || {}; } /** @@ -111,7 +119,7 @@ export class StreamableHTTPServerTransport implements Transport { // Continue processing normally } else if (!sessionId) { // Non-initialization requests without a session ID should return 400 Bad Request - res.writeHead(400).end(JSON.stringify({ + res.writeHead(400, this._customHeaders).end(JSON.stringify({ jsonrpc: "2.0", error: { code: -32000, @@ -122,7 +130,7 @@ export class StreamableHTTPServerTransport implements Transport { return; } else if ((Array.isArray(sessionId) ? sessionId[0] : sessionId) !== this._sessionId) { // Reject requests with invalid session ID with 404 Not Found - res.writeHead(404).end(JSON.stringify({ + res.writeHead(404, this._customHeaders).end(JSON.stringify({ jsonrpc: "2.0", error: { code: -32001, @@ -141,7 +149,7 @@ export class StreamableHTTPServerTransport implements Transport { } else if (req.method === "DELETE") { await this.handleDeleteRequest(req, res); } else { - res.writeHead(405).end(JSON.stringify({ + res.writeHead(405, this._customHeaders).end(JSON.stringify({ jsonrpc: "2.0", error: { code: -32000, @@ -159,7 +167,7 @@ export class StreamableHTTPServerTransport implements Transport { // validate the Accept header const acceptHeader = req.headers.accept; if (!acceptHeader || !acceptHeader.includes("text/event-stream")) { - res.writeHead(406).end(JSON.stringify({ + res.writeHead(406, this._customHeaders).end(JSON.stringify({ jsonrpc: "2.0", error: { code: -32000, @@ -176,6 +184,7 @@ export class StreamableHTTPServerTransport implements Transport { // Prepare response headers const headers: Record = { + ...this._customHeaders, "Content-Type": "text/event-stream", "Cache-Control": "no-cache", Connection: "keep-alive", @@ -294,6 +303,7 @@ export class StreamableHTTPServerTransport implements Transport { if (useSSE) { const headers: Record = { + ...this._customHeaders, "Content-Type": "text/event-stream", "Cache-Control": "no-cache", Connection: "keep-alive", @@ -338,6 +348,7 @@ export class StreamableHTTPServerTransport implements Transport { } else { // use direct JSON response const headers: Record = { + ...this._customHeaders, "Content-Type": "application/json", }; From bafd9e74f3b681c0bff38664c4399d55a4d258c0 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Sat, 5 Apr 2025 19:52:29 +0100 Subject: [PATCH 06/17] Add ways to associate related requests and notifications --- src/server/mcp.test.ts | 8 +++-- src/server/mcp.ts | 16 +++++---- src/shared/protocol.ts | 80 ++++++++++++++++++++++++++++++----------- src/shared/transport.ts | 6 ++-- 4 files changed, 77 insertions(+), 33 deletions(-) diff --git a/src/server/mcp.test.ts b/src/server/mcp.test.ts index 2e91a568..ae11279b 100644 --- a/src/server/mcp.test.ts +++ b/src/server/mcp.test.ts @@ -85,6 +85,8 @@ describe("ResourceTemplate", () => { const abortController = new AbortController(); const result = await template.listCallback?.({ signal: abortController.signal, + sendRequest: () => { throw new Error("Not implemented") }, + sendNotification: () => { throw new Error("Not implemented") } }); expect(result?.resources).toHaveLength(1); expect(list).toHaveBeenCalled(); @@ -318,7 +320,7 @@ describe("tool()", () => { // This should succeed mcpServer.tool("tool1", () => ({ content: [] })); - + // This should also succeed and not throw about request handlers mcpServer.tool("tool2", () => ({ content: [] })); }); @@ -815,7 +817,7 @@ describe("resource()", () => { }, ], })); - + // This should also succeed and not throw about request handlers mcpServer.resource("resource2", "test://resource2", async () => ({ contents: [ @@ -1321,7 +1323,7 @@ describe("prompt()", () => { }, ], })); - + // This should also succeed and not throw about request handlers mcpServer.prompt("prompt2", async () => ({ messages: [ diff --git a/src/server/mcp.ts b/src/server/mcp.ts index 8f4a909c..484084fc 100644 --- a/src/server/mcp.ts +++ b/src/server/mcp.ts @@ -37,6 +37,8 @@ import { PromptArgument, GetPromptResult, ReadResourceResult, + ServerRequest, + ServerNotification, } from "../types.js"; import { Completable, CompletableDef } from "./completable.js"; import { UriTemplate, Variables } from "../shared/uriTemplate.js"; @@ -694,9 +696,9 @@ export type ToolCallback = Args extends ZodRawShape ? ( args: z.objectOutputType, - extra: RequestHandlerExtra, + extra: RequestHandlerExtra, ) => CallToolResult | Promise - : (extra: RequestHandlerExtra) => CallToolResult | Promise; + : (extra: RequestHandlerExtra) => CallToolResult | Promise; type RegisteredTool = { description?: string; @@ -717,7 +719,7 @@ export type ResourceMetadata = Omit; * Callback to list all resources matching a given template. */ export type ListResourcesCallback = ( - extra: RequestHandlerExtra, + extra: RequestHandlerExtra, ) => ListResourcesResult | Promise; /** @@ -725,7 +727,7 @@ export type ListResourcesCallback = ( */ export type ReadResourceCallback = ( uri: URL, - extra: RequestHandlerExtra, + extra: RequestHandlerExtra, ) => ReadResourceResult | Promise; type RegisteredResource = { @@ -740,7 +742,7 @@ type RegisteredResource = { export type ReadResourceTemplateCallback = ( uri: URL, variables: Variables, - extra: RequestHandlerExtra, + extra: RequestHandlerExtra, ) => ReadResourceResult | Promise; type RegisteredResourceTemplate = { @@ -760,9 +762,9 @@ export type PromptCallback< > = Args extends PromptArgsRawShape ? ( args: z.objectOutputType, - extra: RequestHandlerExtra, + extra: RequestHandlerExtra, ) => GetPromptResult | Promise - : (extra: RequestHandlerExtra) => GetPromptResult | Promise; + : (extra: RequestHandlerExtra) => GetPromptResult | Promise; type RegisteredPrompt = { description?: string; diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index a6e47184..b072e578 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -78,22 +78,52 @@ export type RequestOptions = { * If not specified, there is no maximum total timeout. */ maxTotalTimeout?: number; + + /** + * May be used to indicate to the transport which incoming request to associate this outgoing request with. + */ + relatedRequestId?: RequestId; }; /** - * Extra data given to request handlers. + * Options that can be given per notification. */ -export type RequestHandlerExtra = { +export type NotificationOptions = { /** - * An abort signal used to communicate if the request was cancelled from the sender's side. + * May be used to indicate to the transport which incoming request to associate this outgoing notification with. */ - signal: AbortSignal; + relatedRequestId?: RequestId; +} - /** - * The session ID from the transport, if available. - */ - sessionId?: string; -}; +/** + * Extra data given to request handlers. + */ +export type RequestHandlerExtra = { + /** + * An abort signal used to communicate if the request was cancelled from the sender's side. + */ + signal: AbortSignal; + + /** + * The session ID from the transport, if available. + */ + sessionId?: string; + + /** + * Sends a notification that relates to the current request being handled. + * + * This is used by certain transports to correctly associate related messages. + */ + sendNotification: (notification: SendNotificationT) => Promise; + + /** + * Sends a request that relates to the current request being handled. + * + * This is used by certain transports to correctly associate related messages. + */ + sendRequest: >(request: SendRequestT, resultSchema: U, options?: RequestOptions) => Promise>; + }; /** * Information about a request's timeout state @@ -122,7 +152,7 @@ export abstract class Protocol< string, ( request: JSONRPCRequest, - extra: RequestHandlerExtra, + extra: RequestHandlerExtra, ) => Promise > = new Map(); private _requestHandlerAbortControllers: Map = @@ -316,9 +346,14 @@ export abstract class Protocol< this._requestHandlerAbortControllers.set(request.id, abortController); // Create extra object with both abort signal and sessionId from transport - const extra: RequestHandlerExtra = { + const extra: RequestHandlerExtra = { signal: abortController.signal, sessionId: this._transport?.sessionId, + sendNotification: + (notification) => + this.notification(notification, { relatedRequestId: request.id }), + sendRequest: (r, resultSchema, options?) => + this.request(r, resultSchema, { ...options, relatedRequestId: request.id }) }; // Starting with Promise.resolve() puts any synchronous errors into the monad as well. @@ -364,7 +399,7 @@ export abstract class Protocol< private _onprogress(notification: ProgressNotification): void { const { progressToken, ...params } = notification.params; const messageId = Number(progressToken); - + const handler = this._progressHandlers.get(messageId); if (!handler) { this._onerror(new Error(`Received a progress notification for an unknown token: ${JSON.stringify(notification)}`)); @@ -373,7 +408,7 @@ export abstract class Protocol< const responseHandler = this._responseHandlers.get(messageId); const timeoutInfo = this._timeoutInfo.get(messageId); - + if (timeoutInfo && responseHandler && timeoutInfo.resetTimeoutOnProgress) { try { this._resetTimeout(messageId); @@ -460,6 +495,8 @@ export abstract class Protocol< resultSchema: T, options?: RequestOptions, ): Promise> { + const { relatedRequestId } = options ?? {}; + return new Promise((resolve, reject) => { if (!this._transport) { reject(new Error("Not connected")); @@ -500,7 +537,7 @@ export abstract class Protocol< requestId: messageId, reason: String(reason), }, - }) + }, { relatedRequestId }) .catch((error) => this._onerror(new Error(`Failed to send cancellation: ${error}`)), ); @@ -538,7 +575,7 @@ export abstract class Protocol< this._setupTimeout(messageId, timeout, options?.maxTotalTimeout, timeoutHandler, options?.resetTimeoutOnProgress ?? false); - this._transport.send(jsonrpcRequest).catch((error) => { + this._transport.send(jsonrpcRequest, { relatedRequestId }).catch((error) => { this._cleanupTimeout(messageId); reject(error); }); @@ -548,7 +585,7 @@ export abstract class Protocol< /** * Emits a notification, which is a one-way message that does not expect a response. */ - async notification(notification: SendNotificationT): Promise { + async notification(notification: SendNotificationT, options?: NotificationOptions): Promise { if (!this._transport) { throw new Error("Not connected"); } @@ -560,7 +597,7 @@ export abstract class Protocol< jsonrpc: "2.0", }; - await this._transport.send(jsonrpcNotification); + await this._transport.send(jsonrpcNotification, options); } /** @@ -576,14 +613,15 @@ export abstract class Protocol< requestSchema: T, handler: ( request: z.infer, - extra: RequestHandlerExtra, + extra: RequestHandlerExtra, ) => SendResultT | Promise, ): void { const method = requestSchema.shape.method.value; this.assertRequestHandlerCapability(method); - this._requestHandlers.set(method, (request, extra) => - Promise.resolve(handler(requestSchema.parse(request), extra)), - ); + + this._requestHandlers.set(method, (request, extra) => { + return Promise.resolve(handler(requestSchema.parse(request), extra)); + }); } /** diff --git a/src/shared/transport.ts b/src/shared/transport.ts index b80e2a51..84d2c829 100644 --- a/src/shared/transport.ts +++ b/src/shared/transport.ts @@ -1,4 +1,4 @@ -import { JSONRPCMessage } from "../types.js"; +import { JSONRPCMessage, RequestId } from "../types.js"; /** * Describes the minimal contract for a MCP transport that a client or server can communicate over. @@ -15,8 +15,10 @@ export interface Transport { /** * Sends a JSON-RPC message (request or response). + * + * If present, `relatedRequestId` is used to indicate to the transport which incoming request to associate this outgoing message with. */ - send(message: JSONRPCMessage): Promise; + send(message: JSONRPCMessage, options?: { relatedRequestId?: RequestId }): Promise; /** * Closes the connection. From b0697195a78a1f7b00f4e37f0e053985f5b4ec1d Mon Sep 17 00:00:00 2001 From: ihrpr Date: Sun, 6 Apr 2025 16:09:15 +0100 Subject: [PATCH 07/17] add test to cover nested logging withing a tool call --- src/server/mcp.test.ts | 65 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/src/server/mcp.test.ts b/src/server/mcp.test.ts index ae11279b..0c136e29 100644 --- a/src/server/mcp.test.ts +++ b/src/server/mcp.test.ts @@ -11,6 +11,7 @@ import { ListPromptsResultSchema, GetPromptResultSchema, CompleteResultSchema, + LoggingMessageNotificationSchema, } from "../types.js"; import { ResourceTemplate } from "./mcp.js"; import { completable } from "./completable.js"; @@ -378,6 +379,70 @@ describe("tool()", () => { expect(receivedSessionId).toBe("test-session-123"); }); + test("should provide sendNotification withing tool call", async () => { + const mcpServer = new McpServer( + { + name: "test server", + version: "1.0", + }, + { capabilities: { logging: {} } }, + ); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + tools: {}, + }, + }, + ); + + let receivedLogMessage: string | undefined; + + const loggingMessage = "hello here is log message 1" + + client.setNotificationHandler(LoggingMessageNotificationSchema, (notification) => { + receivedLogMessage = notification.params.data as string; + + }); + + mcpServer.tool("test-tool", async ({ sendNotification }) => { + await sendNotification({ method: "notifications/message", params: { level: "debug", data: loggingMessage } }); + return { + content: [ + { + type: "text", + text: "Test response", + }, + ], + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + // Set a test sessionId on the server transport + serverTransport.sessionId = "test-session-123"; + + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + await client.request( + { + method: "tools/call", + params: { + name: "test-tool", + }, + }, + CallToolResultSchema, + ); + expect(receivedLogMessage).toBe(loggingMessage); + }); + test("should allow client to call server tools", async () => { const mcpServer = new McpServer({ name: "test server", From e3bb99cd5c9f25e079f678b3d82c05765d02e12f Mon Sep 17 00:00:00 2001 From: ihrpr Date: Sun, 6 Apr 2025 17:27:00 +0100 Subject: [PATCH 08/17] remove the use of enableSessionManagement --- src/server/streamable-http.test.ts | 141 ++++++++++++----------------- src/server/streamable-http.ts | 90 +++++++++--------- 2 files changed, 104 insertions(+), 127 deletions(-) diff --git a/src/server/streamable-http.test.ts b/src/server/streamable-http.test.ts index 74b92c1a..b888b0ec 100644 --- a/src/server/streamable-http.test.ts +++ b/src/server/streamable-http.test.ts @@ -2,7 +2,7 @@ import { IncomingMessage, ServerResponse } from "node:http"; import { StreamableHTTPServerTransport } from "./streamable-http.js"; import { JSONRPCMessage } from "../types.js"; import { Readable } from "node:stream"; - +import { randomUUID } from "node:crypto"; // Mock IncomingMessage function createMockRequest(options: { method: string; @@ -10,7 +10,7 @@ function createMockRequest(options: { body?: string; }): IncomingMessage { const readable = new Readable(); - readable._read = () => {}; + readable._read = () => { }; if (options.body) { readable.push(options.body); readable.push(null); @@ -37,12 +37,13 @@ function createMockResponse(): jest.Mocked { } describe("StreamableHTTPServerTransport", () => { - const endpoint = "/mcp"; let transport: StreamableHTTPServerTransport; let mockResponse: jest.Mocked; beforeEach(() => { - transport = new StreamableHTTPServerTransport(endpoint); + transport = new StreamableHTTPServerTransport({ + sessionId: randomUUID(), + }); mockResponse = createMockResponse(); }); @@ -61,7 +62,7 @@ describe("StreamableHTTPServerTransport", () => { const initializeMessage: JSONRPCMessage = { jsonrpc: "2.0", method: "initialize", - params: { + params: { clientInfo: { name: "test-client", version: "1.0" }, protocolVersion: "2025-03-26" }, @@ -119,49 +120,13 @@ describe("StreamableHTTPServerTransport", () => { expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"jsonrpc":"2.0"')); expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"message":"Bad Request: Mcp-Session-Id header is required"')); }); - - it("should always include session ID in initialization response even in stateless mode", async () => { - // Create a stateless transport for this test - const statelessTransport = new StreamableHTTPServerTransport(endpoint, { enableSessionManagement: false }); - - // Create an initialization request - const initializeMessage: JSONRPCMessage = { - jsonrpc: "2.0", - method: "initialize", - params: { - clientInfo: { name: "test-client", version: "1.0" }, - protocolVersion: "2025-03-26" - }, - id: "init-1", - }; - - const req = createMockRequest({ - method: "POST", - headers: { - "content-type": "application/json", - "accept": "application/json", - }, - body: JSON.stringify(initializeMessage), - }); - - await statelessTransport.handleRequest(req, mockResponse); - - // In stateless mode, session ID should also be included for initialize responses - expect(mockResponse.writeHead).toHaveBeenCalledWith( - 200, - expect.objectContaining({ - "mcp-session-id": statelessTransport.sessionId, - }) - ); - }); }); - describe("Stateless Mode", () => { let statelessTransport: StreamableHTTPServerTransport; let mockResponse: jest.Mocked; beforeEach(() => { - statelessTransport = new StreamableHTTPServerTransport(endpoint, { enableSessionManagement: false }); + statelessTransport = new StreamableHTTPServerTransport({ sessionId: undefined }); mockResponse = createMockResponse(); }); @@ -268,7 +233,7 @@ describe("StreamableHTTPServerTransport", () => { }); await statelessTransport.handleRequest(req2, mockResponse); - + // Should still succeed expect(mockResponse.writeHead).toHaveBeenCalledWith( 200, @@ -278,12 +243,12 @@ describe("StreamableHTTPServerTransport", () => { ); }); - it("should handle initialization requests properly in both modes", async () => { + it("should handle initialization requests properly in statefull mode", async () => { // Initialize message that would typically be sent during initialization const initializeMessage: JSONRPCMessage = { jsonrpc: "2.0", method: "initialize", - params: { + params: { clientInfo: { name: "test-client", version: "1.0" }, protocolVersion: "2025-03-26" }, @@ -301,7 +266,7 @@ describe("StreamableHTTPServerTransport", () => { }); await transport.handleRequest(statefulReq, mockResponse); - + // In stateful mode, session ID should be included in the response header expect(mockResponse.writeHead).toHaveBeenCalledWith( 200, @@ -309,9 +274,19 @@ describe("StreamableHTTPServerTransport", () => { "mcp-session-id": transport.sessionId, }) ); + }); - // Reset mocks for stateless test - mockResponse.writeHead.mockClear(); + it("should handle initialization requests properly in stateless mode", async () => { + // Initialize message that would typically be sent during initialization + const initializeMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "initialize", + params: { + clientInfo: { name: "test-client", version: "1.0" }, + protocolVersion: "2025-03-26" + }, + id: "init-1", + }; // Test stateless transport const statelessReq = createMockRequest({ @@ -324,10 +299,11 @@ describe("StreamableHTTPServerTransport", () => { }); await statelessTransport.handleRequest(statelessReq, mockResponse); - + // In stateless mode, session ID should also be included for initialize responses const headers = mockResponse.writeHead.mock.calls[0][1]; - expect(headers).toHaveProperty("mcp-session-id", statelessTransport.sessionId); + expect(headers).not.toHaveProperty("mcp-session-id"); + }); }); @@ -519,14 +495,14 @@ describe("StreamableHTTPServerTransport", () => { // Send a message to first connection const message1: JSONRPCMessage = { - jsonrpc: "2.0", - method: "test1", - params: {}, + jsonrpc: "2.0", + method: "test1", + params: {}, id: 1 }; - + await transport.send(message1); - + // Get message ID (captured from write call) const writeCall = mockResponse.write.mock.calls[0][0] as string; const idMatch = writeCall.match(/id: ([a-f0-9-]+)/); @@ -550,12 +526,12 @@ describe("StreamableHTTPServerTransport", () => { // Send a second message const message2: JSONRPCMessage = { - jsonrpc: "2.0", - method: "test2", - params: {}, + jsonrpc: "2.0", + method: "test2", + params: {}, id: 2 }; - + await transport.send(message2); // Verify the second message was received by both connections @@ -596,7 +572,7 @@ describe("StreamableHTTPServerTransport", () => { params: {}, id: "test-id", }; - + const reqPost = createMockRequest({ method: "POST", headers: { @@ -605,28 +581,28 @@ describe("StreamableHTTPServerTransport", () => { }, body: JSON.stringify(requestMessage), }); - + await transport.handleRequest(reqPost, mockResponse1); - + // Send a response with matching ID const responseMessage: JSONRPCMessage = { jsonrpc: "2.0", result: { success: true }, id: "test-id", }; - + await transport.send(responseMessage); - + // Verify response was sent to the right connection expect(mockResponse1.write).toHaveBeenCalledWith( expect.stringContaining(JSON.stringify(responseMessage)) ); - + // Check if write was called with this exact message on the second connection - const writeCallsOnSecondConn = mockResponse2.write.mock.calls.filter(call => + const writeCallsOnSecondConn = mockResponse2.write.mock.calls.filter(call => typeof call[0] === 'string' && call[0].includes(JSON.stringify(responseMessage)) ); - + // Verify the response wasn't broadcast to all connections expect(writeCallsOnSecondConn.length).toBe(0); }); @@ -680,7 +656,7 @@ describe("StreamableHTTPServerTransport", () => { const message: JSONRPCMessage = { jsonrpc: "2.0", method: "initialize", - params: { + params: { clientInfo: { name: "test-client", version: "1.0" }, protocolVersion: "2025-03-26" }, @@ -715,17 +691,17 @@ describe("StreamableHTTPServerTransport", () => { it("should handle pre-parsed batch messages", async () => { const batchMessages: JSONRPCMessage[] = [ - { - jsonrpc: "2.0", - method: "method1", + { + jsonrpc: "2.0", + method: "method1", params: { data: "test1" }, - id: "batch1" + id: "batch1" }, - { - jsonrpc: "2.0", - method: "method2", + { + jsonrpc: "2.0", + method: "method2", params: { data: "test2" }, - id: "batch2" + id: "batch2" }, ]; @@ -800,7 +776,7 @@ describe("StreamableHTTPServerTransport", () => { let mockResponse: jest.Mocked; beforeEach(() => { - transportWithHeaders = new StreamableHTTPServerTransport(endpoint, { customHeaders }); + transportWithHeaders = new StreamableHTTPServerTransport({ sessionId: randomUUID(), customHeaders }); mockResponse = createMockResponse(); }); @@ -875,9 +851,10 @@ describe("StreamableHTTPServerTransport", () => { }); it("should not override essential headers with custom headers", async () => { - const transportWithConflictingHeaders = new StreamableHTTPServerTransport(endpoint, { + const transportWithConflictingHeaders = new StreamableHTTPServerTransport({ + sessionId: randomUUID(), customHeaders: { - "Content-Type": "text/plain", // 尝试覆盖必要的 Content-Type 头 + "Content-Type": "text/plain", "X-Custom-Header": "custom-value" } }); @@ -902,8 +879,10 @@ describe("StreamableHTTPServerTransport", () => { }); it("should work with empty custom headers", async () => { - const transportWithoutHeaders = new StreamableHTTPServerTransport(endpoint); - + const transportWithoutHeaders = new StreamableHTTPServerTransport({ + sessionId: randomUUID(), + }); + const req = createMockRequest({ method: "GET", headers: { diff --git a/src/server/streamable-http.ts b/src/server/streamable-http.ts index 13d28968..ef654f85 100644 --- a/src/server/streamable-http.ts +++ b/src/server/streamable-http.ts @@ -23,11 +23,11 @@ interface StreamConnection { */ export interface StreamableHTTPServerTransportOptions { /** - * Whether to enable session management through mcp-session-id headers - * When set to false, the transport operates in stateless mode without session validation - * @default true + * The session ID SHOULD be globally unique and cryptographically secure (e.g., a securely generated UUID, a JWT, or a cryptographic hash) + * + * When sessionId is not set, the transport will be in stateless mode. */ - enableSessionManagement?: boolean; + sessionId: string | undefined; /** * Custom headers to be included in all responses @@ -38,17 +38,19 @@ export interface StreamableHTTPServerTransportOptions { /** * Server transport for Streamable HTTP: this implements the MCP Streamable HTTP transport specification. - * It supports both SSE streaming and direct HTTP responses, with session management and message resumability. + * It supports both SSE streaming and direct HTTP responses. * * Usage example: * * ```typescript - * // Stateful mode (default) - with session management - * const statefulTransport = new StreamableHTTPServerTransport("/mcp"); + * // Stateful mode - server sets the session ID + * const statefulTransport = new StreamableHTTPServerTransport({ + * sessionId: randomUUID(), + * }); * - * // Stateless mode - without session management - * const statelessTransport = new StreamableHTTPServerTransport("/mcp", { - * enableSessionManagement: false + * // Stateless mode - explisitly set session ID to undefined + * const statelessTransport = new StreamableHTTPServerTransport({ + * sessionId: undefined, * }); * * // Using with pre-parsed request body @@ -70,23 +72,22 @@ export interface StreamableHTTPServerTransportOptions { */ export class StreamableHTTPServerTransport implements Transport { private _connections: Map = new Map(); - private _sessionId: string; + // when sessionID is not set, it means the transport is in stateless mode + private _sessionId: string | undefined; private _messageHistory: Map = new Map(); private _started: boolean = false; private _requestConnections: Map = new Map(); // request ID to connection ID mapping - private _enableSessionManagement: boolean; private _customHeaders: Record; onclose?: () => void; onerror?: (error: Error) => void; onmessage?: (message: JSONRPCMessage) => void; - constructor(private _endpoint: string, options?: StreamableHTTPServerTransportOptions) { - this._sessionId = randomUUID(); - this._enableSessionManagement = options?.enableSessionManagement !== false; + constructor(options: StreamableHTTPServerTransportOptions) { + this._sessionId = options?.sessionId; this._customHeaders = options?.customHeaders || {}; } @@ -106,13 +107,13 @@ export class StreamableHTTPServerTransport implements Transport { */ async handleRequest(req: IncomingMessage, res: ServerResponse, parsedBody?: unknown): Promise { // Only validate session ID for non-initialization requests when session management is enabled - if (this._enableSessionManagement) { + if (this._sessionId !== undefined) { const sessionId = req.headers["mcp-session-id"]; - + // Check if this might be an initialization request - const isInitializationRequest = req.method === "POST" && + const isInitializationRequest = req.method === "POST" && req.headers["content-type"]?.includes("application/json"); - + if (isInitializationRequest) { // For POST requests with JSON content, we need to check if it's an initialization request // This will be done in handlePostRequest, as we need to parse the body @@ -189,9 +190,9 @@ export class StreamableHTTPServerTransport implements Transport { "Cache-Control": "no-cache", Connection: "keep-alive", }; - - // Only include session ID header if session management is enabled - if (this._enableSessionManagement) { + + // Only include session ID header if session management is enabled by assigning a value to _sessionId + if (this._sessionId !== undefined) { headers["mcp-session-id"] = this._sessionId; } @@ -231,8 +232,8 @@ export class StreamableHTTPServerTransport implements Transport { try { // validate the Accept header const acceptHeader = req.headers.accept; - if (!acceptHeader || - (!acceptHeader.includes("application/json") && !acceptHeader.includes("text/event-stream"))) { + if (!acceptHeader || + (!acceptHeader.includes("application/json") && !acceptHeader.includes("text/event-stream"))) { res.writeHead(406).end(JSON.stringify({ jsonrpc: "2.0", error: { @@ -270,7 +271,7 @@ export class StreamableHTTPServerTransport implements Transport { } let messages: JSONRPCMessage[]; - + // handle batch and single messages if (Array.isArray(rawMessage)) { messages = rawMessage.map(msg => JSONRPCMessageSchema.parse(msg)); @@ -286,7 +287,7 @@ export class StreamableHTTPServerTransport implements Transport { // check if it contains requests const hasRequests = messages.some(msg => 'method' in msg && 'id' in msg); - const hasOnlyNotificationsOrResponses = messages.every(msg => + const hasOnlyNotificationsOrResponses = messages.every(msg => ('method' in msg && !('id' in msg)) || ('result' in msg || 'error' in msg)); if (hasOnlyNotificationsOrResponses) { @@ -300,7 +301,7 @@ export class StreamableHTTPServerTransport implements Transport { } else if (hasRequests) { // if it contains requests, you can choose to return an SSE stream or a JSON response const useSSE = acceptHeader.includes("text/event-stream"); - + if (useSSE) { const headers: Record = { ...this._customHeaders, @@ -308,10 +309,8 @@ export class StreamableHTTPServerTransport implements Transport { "Cache-Control": "no-cache", Connection: "keep-alive", }; - - // Only include session ID header if session management is enabled - // Always include session ID for initialization requests - if (this._enableSessionManagement || isInitializationRequest) { + + if (this._sessionId !== undefined) { headers["mcp-session-id"] = this._sessionId; } @@ -351,20 +350,19 @@ export class StreamableHTTPServerTransport implements Transport { ...this._customHeaders, "Content-Type": "application/json", }; - - // Only include session ID header if session management is enabled - // Always include session ID for initialization requests - if (this._enableSessionManagement || isInitializationRequest) { + + + if (this._sessionId !== undefined) { headers["mcp-session-id"] = this._sessionId; } res.writeHead(200, headers); - + // handle each message for (const message of messages) { this.onmessage?.(message); } - + res.end(); } } @@ -396,11 +394,11 @@ export class StreamableHTTPServerTransport implements Transport { */ private replayMessages(connectionId: string, lastEventId: string): void { if (!lastEventId) return; - + // only replay messages that should be sent on this connection const messages = Array.from(this._messageHistory.entries()) - .filter(([id, { connectionId: msgConnId }]) => - id > lastEventId && + .filter(([id, { connectionId: msgConnId }]) => + id > lastEventId && (!msgConnId || msgConnId === connectionId)) // only replay messages that are not specified to a connection or specified to the current connection .sort(([a], [b]) => a.localeCompare(b)); @@ -430,11 +428,11 @@ export class StreamableHTTPServerTransport implements Transport { } let targetConnectionId = ""; - + // if it is a response, find the corresponding request connection if ('id' in message && ('result' in message || 'error' in message)) { const connId = this._requestConnections.get(String(message.id)); - + // if the corresponding connection is not found, the connection may be disconnected if (!connId || !this._connections.has(connId)) { // select an available connection @@ -458,9 +456,9 @@ export class StreamableHTTPServerTransport implements Transport { } const messageId = randomUUID(); - this._messageHistory.set(messageId, { - message, - connectionId: targetConnectionId + this._messageHistory.set(messageId, { + message, + connectionId: targetConnectionId }); // keep the message history in a reasonable range @@ -490,7 +488,7 @@ export class StreamableHTTPServerTransport implements Transport { /** * Returns the session ID for this transport */ - get sessionId(): string { + get sessionId(): string | undefined { return this._sessionId; } } \ No newline at end of file From 00135707f6afb9c42fe406856e6c592d2a006922 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Sun, 6 Apr 2025 17:28:46 +0100 Subject: [PATCH 09/17] rename files to use conventioanal typescript names --- src/server/{streamable-http.test.ts => streamableHttp.test.ts} | 2 +- src/server/{streamable-http.ts => streamableHttp.ts} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename src/server/{streamable-http.test.ts => streamableHttp.test.ts} (99%) rename src/server/{streamable-http.ts => streamableHttp.ts} (100%) diff --git a/src/server/streamable-http.test.ts b/src/server/streamableHttp.test.ts similarity index 99% rename from src/server/streamable-http.test.ts rename to src/server/streamableHttp.test.ts index b888b0ec..baf4d80a 100644 --- a/src/server/streamable-http.test.ts +++ b/src/server/streamableHttp.test.ts @@ -1,5 +1,5 @@ import { IncomingMessage, ServerResponse } from "node:http"; -import { StreamableHTTPServerTransport } from "./streamable-http.js"; +import { StreamableHTTPServerTransport } from "./streamableHttp.js"; import { JSONRPCMessage } from "../types.js"; import { Readable } from "node:stream"; import { randomUUID } from "node:crypto"; diff --git a/src/server/streamable-http.ts b/src/server/streamableHttp.ts similarity index 100% rename from src/server/streamable-http.ts rename to src/server/streamableHttp.ts From 268c9f7469ccd3c89e5e5c92cc5e39a7f10e6328 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Sun, 6 Apr 2025 18:42:46 +0100 Subject: [PATCH 10/17] fix test --- src/client/index.test.ts | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/client/index.test.ts b/src/client/index.test.ts index 1209b60c..e153687c 100644 --- a/src/client/index.test.ts +++ b/src/client/index.test.ts @@ -66,6 +66,9 @@ test("should initialize with matching protocol version", async () => { protocolVersion: LATEST_PROTOCOL_VERSION, }), }), + expect.objectContaining({ + relatedRequestId: undefined, + }), ); // Should have the instructions returned From 6c1b9ba06a403f4498efa50b29a0cfa9db1f0f3f Mon Sep 17 00:00:00 2001 From: ihrpr Date: Mon, 7 Apr 2025 12:12:43 +0100 Subject: [PATCH 11/17] remove local storage for replaying and remove support for sse connection thought get --- src/server/streamableHttp.test.ts | 299 +++++++++++++++++++---------- src/server/streamableHttp.ts | 308 ++++++++++-------------------- src/shared/transport.ts | 2 +- 3 files changed, 297 insertions(+), 312 deletions(-) diff --git a/src/server/streamableHttp.test.ts b/src/server/streamableHttp.test.ts index baf4d80a..8fb55e79 100644 --- a/src/server/streamableHttp.test.ts +++ b/src/server/streamableHttp.test.ts @@ -52,7 +52,7 @@ describe("StreamableHTTPServerTransport", () => { }); describe("Session Management", () => { - it("should generate a valid session ID", () => { + it("should store a valid session ID", () => { expect(transport.sessionId).toBeTruthy(); expect(typeof transport.sessionId).toBe("string"); }); @@ -158,11 +158,18 @@ describe("StreamableHTTPServerTransport", () => { it("should not validate session ID in stateless mode", async () => { const req = createMockRequest({ - method: "GET", + method: "POST", headers: { - accept: "text/event-stream", + "content-type": "application/json", + "accept": "application/json", "mcp-session-id": "invalid-session-id", // This would cause a 404 in stateful mode }, + body: JSON.stringify({ + jsonrpc: "2.0", + method: "test", + params: {}, + id: 1 + }), }); await statelessTransport.handleRequest(req, mockResponse); @@ -206,10 +213,17 @@ describe("StreamableHTTPServerTransport", () => { it("should work with a mix of requests with and without session IDs in stateless mode", async () => { // First request without session ID const req1 = createMockRequest({ - method: "GET", + method: "POST", headers: { + "content-type": "application/json", accept: "text/event-stream", }, + body: JSON.stringify({ + jsonrpc: "2.0", + method: "test", + params: {}, + id: "test-id" + }) }); await statelessTransport.handleRequest(req1, mockResponse); @@ -225,11 +239,18 @@ describe("StreamableHTTPServerTransport", () => { // Second request with a session ID (which would be invalid in stateful mode) const req2 = createMockRequest({ - method: "GET", + method: "POST", headers: { + "content-type": "application/json", accept: "text/event-stream", "mcp-session-id": "some-random-session-id", }, + body: JSON.stringify({ + jsonrpc: "2.0", + method: "test2", + params: {}, + id: "test-id-2" + }) }); await statelessTransport.handleRequest(req2, mockResponse); @@ -308,21 +329,7 @@ describe("StreamableHTTPServerTransport", () => { }); describe("Request Handling", () => { - it("should reject GET requests without Accept: text/event-stream header", async () => { - const req = createMockRequest({ - method: "GET", - headers: { - "mcp-session-id": transport.sessionId, - }, - }); - - await transport.handleRequest(req, mockResponse); - - expect(mockResponse.writeHead).toHaveBeenCalledWith(406, {}); - expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"jsonrpc":"2.0"')); - }); - - it("should properly handle GET requests with Accept header and establish SSE connection", async () => { + it("should reject GET requests for SSE with 405 Method Not Allowed", async () => { const req = createMockRequest({ method: "GET", headers: { @@ -333,14 +340,11 @@ describe("StreamableHTTPServerTransport", () => { await transport.handleRequest(req, mockResponse); - expect(mockResponse.writeHead).toHaveBeenCalledWith( - 200, - expect.objectContaining({ - "Content-Type": "text/event-stream", - "Cache-Control": "no-cache", - Connection: "keep-alive", - }) - ); + expect(mockResponse.writeHead).toHaveBeenCalledWith(405, expect.objectContaining({ + "Allow": "POST, DELETE" + })); + expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"jsonrpc":"2.0"')); + expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('Server does not offer an SSE stream at this endpoint')); }); it("should reject POST requests without proper Accept header", async () => { @@ -421,7 +425,7 @@ describe("StreamableHTTPServerTransport", () => { expect(mockResponse.writeHead).toHaveBeenCalledWith(202); }); - it("should handle batch messages properly", async () => { + it("should handle batch notification messages properly with 202 response", async () => { const batchMessages: JSONRPCMessage[] = [ { jsonrpc: "2.0", method: "test1", params: {} }, { jsonrpc: "2.0", method: "test2", params: {} }, @@ -446,6 +450,39 @@ describe("StreamableHTTPServerTransport", () => { expect(mockResponse.writeHead).toHaveBeenCalledWith(202); }); + it("should handle batch request messages with SSE when Accept header includes text/event-stream", async () => { + const batchMessages: JSONRPCMessage[] = [ + { jsonrpc: "2.0", method: "test1", params: {}, id: "req1" }, + { jsonrpc: "2.0", method: "test2", params: {}, id: "req2" }, + ]; + + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "text/event-stream, application/json", + "mcp-session-id": transport.sessionId, + }, + body: JSON.stringify(batchMessages), + }); + + const onMessageMock = jest.fn(); + transport.onmessage = onMessageMock; + + await transport.handleRequest(req, mockResponse); + + // Should establish SSE connection + expect(mockResponse.writeHead).toHaveBeenCalledWith( + 200, + expect.objectContaining({ + "Content-Type": "text/event-stream" + }) + ); + expect(onMessageMock).toHaveBeenCalledTimes(2); + // Stream should remain open until responses are sent + expect(mockResponse.end).not.toHaveBeenCalled(); + }); + it("should reject unsupported Content-Type", async () => { const req = createMockRequest({ method: "POST", @@ -481,130 +518,170 @@ describe("StreamableHTTPServerTransport", () => { }); }); - describe("Message Replay", () => { - it("should replay messages after specified Last-Event-ID", async () => { - // Establish first connection with Accept header and session ID - const req1 = createMockRequest({ - method: "GET", + describe("SSE Response Handling", () => { + it("should send response messages as SSE events", async () => { + // Setup a POST request with JSON-RPC request that accepts SSE + const requestMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test", + params: {}, + id: "test-req-id" + }; + + const req = createMockRequest({ + method: "POST", headers: { + "content-type": "application/json", "accept": "text/event-stream", "mcp-session-id": transport.sessionId }, + body: JSON.stringify(requestMessage) }); - await transport.handleRequest(req1, mockResponse); - // Send a message to first connection - const message1: JSONRPCMessage = { + await transport.handleRequest(req, mockResponse); + + // Send a response to the request + const responseMessage: JSONRPCMessage = { jsonrpc: "2.0", - method: "test1", - params: {}, - id: 1 + result: { value: "test-result" }, + id: "test-req-id" }; - await transport.send(message1); + await transport.send(responseMessage, { relatedRequestId: "test-req-id" }); - // Get message ID (captured from write call) - const writeCall = mockResponse.write.mock.calls[0][0] as string; - const idMatch = writeCall.match(/id: ([a-f0-9-]+)/); - if (!idMatch) { - throw new Error("Message ID not found in write call"); - } - const lastEventId = idMatch[1]; + // Verify response was sent as SSE event + expect(mockResponse.write).toHaveBeenCalledWith( + expect.stringContaining(`event: message\ndata: ${JSON.stringify(responseMessage)}\n\n`) + ); - // Create a second connection with last-event-id - const mockResponse2 = createMockResponse(); - const req2 = createMockRequest({ - method: "GET", + // Stream should be closed after sending response + expect(mockResponse.end).toHaveBeenCalled(); + }); + + it("should keep stream open when sending intermediate notifications and requests", async () => { + // Setup a POST request with JSON-RPC request that accepts SSE + const requestMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test", + params: {}, + id: "test-req-id" + }; + + const req = createMockRequest({ + method: "POST", headers: { + "content-type": "application/json", "accept": "text/event-stream", - "last-event-id": lastEventId, "mcp-session-id": transport.sessionId }, + body: JSON.stringify(requestMessage) }); - await transport.handleRequest(req2, mockResponse2); + await transport.handleRequest(req, mockResponse); - // Send a second message - const message2: JSONRPCMessage = { + // Send an intermediate notification + const notification: JSONRPCMessage = { jsonrpc: "2.0", - method: "test2", - params: {}, - id: 2 + method: "progress", + params: { progress: "50%" } }; - await transport.send(message2); + await transport.send(notification, { relatedRequestId: "test-req-id" }); - // Verify the second message was received by both connections - expect(mockResponse.write).toHaveBeenCalledWith( - expect.stringContaining(JSON.stringify(message1)) - ); - expect(mockResponse2.write).toHaveBeenCalledWith( - expect.stringContaining(JSON.stringify(message2)) - ); + // Stream should remain open + expect(mockResponse.end).not.toHaveBeenCalled(); + + // Send the final response + const responseMessage: JSONRPCMessage = { + jsonrpc: "2.0", + result: { value: "test-result" }, + id: "test-req-id" + }; + + await transport.send(responseMessage, { relatedRequestId: "test-req-id" }); + + // Now stream should be closed + expect(mockResponse.end).toHaveBeenCalled(); }); }); describe("Message Targeting", () => { it("should send response messages to the connection that sent the request", async () => { - // Create two connections + // Create request with two separate connections + const requestMessage1: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test1", + params: {}, + id: "req-id-1", + }; + const mockResponse1 = createMockResponse(); const req1 = createMockRequest({ - method: "GET", + method: "POST", headers: { + "content-type": "application/json", "accept": "text/event-stream", + "mcp-session-id": transport.sessionId }, + body: JSON.stringify(requestMessage1), }); await transport.handleRequest(req1, mockResponse1); - const mockResponse2 = createMockResponse(); - const req2 = createMockRequest({ - method: "GET", - headers: { - "accept": "text/event-stream", - }, - }); - await transport.handleRequest(req2, mockResponse2); - - // Send a request through the first connection - const requestMessage: JSONRPCMessage = { + const requestMessage2: JSONRPCMessage = { jsonrpc: "2.0", - method: "test", + method: "test2", params: {}, - id: "test-id", + id: "req-id-2", }; - const reqPost = createMockRequest({ + const mockResponse2 = createMockResponse(); + const req2 = createMockRequest({ method: "POST", headers: { "content-type": "application/json", "accept": "text/event-stream", + "mcp-session-id": transport.sessionId }, - body: JSON.stringify(requestMessage), + body: JSON.stringify(requestMessage2), }); + await transport.handleRequest(req2, mockResponse2); + + // Send responses with matching IDs + const responseMessage1: JSONRPCMessage = { + jsonrpc: "2.0", + result: { success: true }, + id: "req-id-1", + }; - await transport.handleRequest(reqPost, mockResponse1); + await transport.send(responseMessage1, { relatedRequestId: "req-id-1" }); - // Send a response with matching ID - const responseMessage: JSONRPCMessage = { + const responseMessage2: JSONRPCMessage = { jsonrpc: "2.0", result: { success: true }, - id: "test-id", + id: "req-id-2", }; - await transport.send(responseMessage); + await transport.send(responseMessage2, { relatedRequestId: "req-id-2" }); - // Verify response was sent to the right connection + // Verify responses were sent to the right connections expect(mockResponse1.write).toHaveBeenCalledWith( - expect.stringContaining(JSON.stringify(responseMessage)) + expect.stringContaining(JSON.stringify(responseMessage1)) ); - // Check if write was called with this exact message on the second connection - const writeCallsOnSecondConn = mockResponse2.write.mock.calls.filter(call => - typeof call[0] === 'string' && call[0].includes(JSON.stringify(responseMessage)) + expect(mockResponse2.write).toHaveBeenCalledWith( + expect.stringContaining(JSON.stringify(responseMessage2)) ); - // Verify the response wasn't broadcast to all connections - expect(writeCallsOnSecondConn.length).toBe(0); + // Verify responses were not sent to the wrong connections + const resp1HasResp2 = mockResponse1.write.mock.calls.some(call => + typeof call[0] === 'string' && call[0].includes(JSON.stringify(responseMessage2)) + ); + expect(resp1HasResp2).toBe(false); + + const resp2HasResp1 = mockResponse2.write.mock.calls.some(call => + typeof call[0] === 'string' && call[0].includes(JSON.stringify(responseMessage1)) + ); + expect(resp2HasResp1).toBe(false); }); }); @@ -749,6 +826,7 @@ describe("StreamableHTTPServerTransport", () => { headers: { "content-type": "application/json", "accept": "application/json", + "mcp-session-id": transport.sessionId, }, body: JSON.stringify(requestBodyMessage), }); @@ -782,11 +860,18 @@ describe("StreamableHTTPServerTransport", () => { it("should include custom headers in SSE response", async () => { const req = createMockRequest({ - method: "GET", + method: "POST", headers: { + "content-type": "application/json", accept: "text/event-stream", "mcp-session-id": transportWithHeaders.sessionId }, + body: JSON.stringify({ + jsonrpc: "2.0", + method: "test", + params: {}, + id: "test-headers-id" + }) }); await transportWithHeaders.handleRequest(req, mockResponse); @@ -860,11 +945,18 @@ describe("StreamableHTTPServerTransport", () => { }); const req = createMockRequest({ - method: "GET", + method: "POST", headers: { + "content-type": "application/json", accept: "text/event-stream", "mcp-session-id": transportWithConflictingHeaders.sessionId }, + body: JSON.stringify({ + jsonrpc: "2.0", + method: "test", + params: {}, + id: "test-conflict-id" + }) }); await transportWithConflictingHeaders.handleRequest(req, mockResponse); @@ -872,7 +964,7 @@ describe("StreamableHTTPServerTransport", () => { expect(mockResponse.writeHead).toHaveBeenCalledWith( 200, expect.objectContaining({ - "Content-Type": "text/event-stream", // 应该保持原有的 Content-Type + "Content-Type": "text/event-stream", "X-Custom-Header": "custom-value" }) ); @@ -884,11 +976,18 @@ describe("StreamableHTTPServerTransport", () => { }); const req = createMockRequest({ - method: "GET", + method: "POST", headers: { + "content-type": "application/json", accept: "text/event-stream", "mcp-session-id": transportWithoutHeaders.sessionId }, + body: JSON.stringify({ + jsonrpc: "2.0", + method: "test", + params: {}, + id: "test-empty-headers-id" + }) }); await transportWithoutHeaders.handleRequest(req, mockResponse); diff --git a/src/server/streamableHttp.ts b/src/server/streamableHttp.ts index ef654f85..6afe826a 100644 --- a/src/server/streamableHttp.ts +++ b/src/server/streamableHttp.ts @@ -1,23 +1,11 @@ -import { randomUUID } from "node:crypto"; import { IncomingMessage, ServerResponse } from "node:http"; import { Transport } from "../shared/transport.js"; -import { JSONRPCMessage, JSONRPCMessageSchema } from "../types.js"; +import { JSONRPCMessage, JSONRPCMessageSchema, RequestId } from "../types.js"; import getRawBody from "raw-body"; import contentType from "content-type"; const MAXIMUM_MESSAGE_SIZE = "4mb"; -interface StreamConnection { - response: ServerResponse; - lastEventId?: string; - messages: Array<{ - id: string; - message: JSONRPCMessage; - }>; - // mark this connection as a response to a specific request - requestId?: string | null; -} - /** * Configuration options for StreamableHTTPServerTransport */ @@ -34,6 +22,7 @@ export interface StreamableHTTPServerTransportOptions { * These headers will be added to both SSE and regular HTTP responses */ customHeaders?: Record; + } /** @@ -71,16 +60,11 @@ export interface StreamableHTTPServerTransportOptions { * - No session validation is performed */ export class StreamableHTTPServerTransport implements Transport { - private _connections: Map = new Map(); // when sessionID is not set, it means the transport is in stateless mode private _sessionId: string | undefined; - private _messageHistory: Map = new Map(); private _started: boolean = false; - private _requestConnections: Map = new Map(); // request ID to connection ID mapping private _customHeaders: Record; + private _sseResponseMapping: Map = new Map(); onclose?: () => void; onerror?: (error: Error) => void; @@ -108,37 +92,10 @@ export class StreamableHTTPServerTransport implements Transport { async handleRequest(req: IncomingMessage, res: ServerResponse, parsedBody?: unknown): Promise { // Only validate session ID for non-initialization requests when session management is enabled if (this._sessionId !== undefined) { - const sessionId = req.headers["mcp-session-id"]; - - // Check if this might be an initialization request const isInitializationRequest = req.method === "POST" && req.headers["content-type"]?.includes("application/json"); - if (isInitializationRequest) { - // For POST requests with JSON content, we need to check if it's an initialization request - // This will be done in handlePostRequest, as we need to parse the body - // Continue processing normally - } else if (!sessionId) { - // Non-initialization requests without a session ID should return 400 Bad Request - res.writeHead(400, this._customHeaders).end(JSON.stringify({ - jsonrpc: "2.0", - error: { - code: -32000, - message: "Bad Request: Mcp-Session-Id header is required" - }, - id: null - })); - return; - } else if ((Array.isArray(sessionId) ? sessionId[0] : sessionId) !== this._sessionId) { - // Reject requests with invalid session ID with 404 Not Found - res.writeHead(404, this._customHeaders).end(JSON.stringify({ - jsonrpc: "2.0", - error: { - code: -32001, - message: "Session not found" - }, - id: null - })); + if (!isInitializationRequest && !this.validateSession(req, res)) { return; } } @@ -163,66 +120,22 @@ export class StreamableHTTPServerTransport implements Transport { /** * Handles GET requests to establish SSE connections + * According to the MCP Streamable HTTP transport spec, the server MUST either return SSE or 405. + * We choose to return 405 Method Not Allowed as we don't support GET SSE connections yet. */ private async handleGetRequest(req: IncomingMessage, res: ServerResponse): Promise { - // validate the Accept header - const acceptHeader = req.headers.accept; - if (!acceptHeader || !acceptHeader.includes("text/event-stream")) { - res.writeHead(406, this._customHeaders).end(JSON.stringify({ - jsonrpc: "2.0", - error: { - code: -32000, - message: "Not Acceptable: Client must accept text/event-stream" - }, - id: null - })); - return; - } - - const connectionId = randomUUID(); - const lastEventId = req.headers["last-event-id"]; - const lastEventIdStr = Array.isArray(lastEventId) ? lastEventId[0] : lastEventId; - - // Prepare response headers - const headers: Record = { + // Return 405 Method Not Allowed + res.writeHead(405, { ...this._customHeaders, - "Content-Type": "text/event-stream", - "Cache-Control": "no-cache", - Connection: "keep-alive", - }; - - // Only include session ID header if session management is enabled by assigning a value to _sessionId - if (this._sessionId !== undefined) { - headers["mcp-session-id"] = this._sessionId; - } - - res.writeHead(200, headers); - - const connection: StreamConnection = { - response: res, - lastEventId: lastEventIdStr, - messages: [], - }; - - this._connections.set(connectionId, connection); - - // if there is a Last-Event-ID, replay messages on this connection - if (lastEventIdStr) { - this.replayMessages(connectionId, lastEventIdStr); - } - - res.on("close", () => { - this._connections.delete(connectionId); - // remove all request mappings associated with this connection - for (const [reqId, connId] of this._requestConnections.entries()) { - if (connId === connectionId) { - this._requestConnections.delete(reqId); - } - } - if (this._connections.size === 0) { - this.onclose?.(); - } - }); + "Allow": "POST, DELETE" + }).end(JSON.stringify({ + jsonrpc: "2.0", + error: { + code: -32000, + message: "Method not allowed: Server does not offer an SSE stream at this endpoint" + }, + id: null + })); } /** @@ -279,11 +192,18 @@ export class StreamableHTTPServerTransport implements Transport { messages = [JSONRPCMessageSchema.parse(rawMessage)]; } - // Check if this is an initialization request - // https://spec.modelcontextprotocol.io/specification/2025-03-26/basic/lifecycle/ - const isInitializationRequest = messages.some( - msg => 'method' in msg && msg.method === 'initialize' && 'id' in msg - ); + if (this._sessionId !== undefined) { + // Check if this is an initialization request + // https://spec.modelcontextprotocol.io/specification/2025-03-26/basic/lifecycle/ + const isInitializationRequest = messages.some( + msg => 'method' in msg && msg.method === 'initialize' && 'id' in msg + ); + + if (!isInitializationRequest && !this.validateSession(req, res)) { + return; + } + } + // check if it contains requests const hasRequests = messages.some(msg => 'method' in msg && 'id' in msg); @@ -310,40 +230,29 @@ export class StreamableHTTPServerTransport implements Transport { Connection: "keep-alive", }; + // For initialization requests, always include the session ID if we have one + // even if we're in stateless mode if (this._sessionId !== undefined) { headers["mcp-session-id"] = this._sessionId; } res.writeHead(200, headers); - const connectionId = randomUUID(); - const connection: StreamConnection = { - response: res, - messages: [], - }; - - this._connections.set(connectionId, connection); - - // map each request to a connection ID + // Store the response for this request to send messages back through this connection + // We need to track by request ID to maintain the connection for (const message of messages) { if ('method' in message && 'id' in message) { - this._requestConnections.set(String(message.id), connectionId); + this._sseResponseMapping.set(message.id, res); } + } + + // handle each message + for (const message of messages) { this.onmessage?.(message); } - res.on("close", () => { - this._connections.delete(connectionId); - // remove all request mappings associated with this connection - for (const [reqId, connId] of this._requestConnections.entries()) { - if (connId === connectionId) { - this._requestConnections.delete(reqId); - } - } - if (this._connections.size === 0) { - this.onclose?.(); - } - }); + // The server SHOULD NOT close the SSE stream before sending all JSON-RPC responses + // This will be handled by the send() method when responses are ready } else { // use direct JSON response const headers: Record = { @@ -351,7 +260,8 @@ export class StreamableHTTPServerTransport implements Transport { "Content-Type": "application/json", }; - + // For initialization requests, always include the session ID if we have one + // even if we're in stateless mode if (this._sessionId !== undefined) { headers["mcp-session-id"] = this._sessionId; } @@ -390,97 +300,73 @@ export class StreamableHTTPServerTransport implements Transport { } /** - * Replays messages after the specified event ID for a specific connection + * Validates session ID for non-initialization requests when session management is enabled + * Returns true if the session is valid, false otherwise */ - private replayMessages(connectionId: string, lastEventId: string): void { - if (!lastEventId) return; - - // only replay messages that should be sent on this connection - const messages = Array.from(this._messageHistory.entries()) - .filter(([id, { connectionId: msgConnId }]) => - id > lastEventId && - (!msgConnId || msgConnId === connectionId)) // only replay messages that are not specified to a connection or specified to the current connection - .sort(([a], [b]) => a.localeCompare(b)); - - const connection = this._connections.get(connectionId); - if (!connection) return; - - for (const [id, { message }] of messages) { - connection.response.write( - `id: ${id}\nevent: message\ndata: ${JSON.stringify(message)}\n\n` - ); + private validateSession(req: IncomingMessage, res: ServerResponse): boolean { + const sessionId = req.headers["mcp-session-id"]; + + if (!sessionId) { + // Non-initialization requests without a session ID should return 400 Bad Request + res.writeHead(400, this._customHeaders).end(JSON.stringify({ + jsonrpc: "2.0", + error: { + code: -32000, + message: "Bad Request: Mcp-Session-Id header is required" + }, + id: null + })); + return false; + } else if ((Array.isArray(sessionId) ? sessionId[0] : sessionId) !== this._sessionId) { + // Reject requests with invalid session ID with 404 Not Found + res.writeHead(404, this._customHeaders).end(JSON.stringify({ + jsonrpc: "2.0", + error: { + code: -32001, + message: "Session not found" + }, + id: null + })); + return false; } + + return true; } + async close(): Promise { - for (const connection of this._connections.values()) { - connection.response.end(); - } - this._connections.clear(); - this._messageHistory.clear(); - this._requestConnections.clear(); + // Close all SSE connections + this._sseResponseMapping.forEach((response) => { + response.end(); + }); + this._sseResponseMapping.clear(); this.onclose?.(); } - async send(message: JSONRPCMessage): Promise { - if (this._connections.size === 0) { - throw new Error("No active connections"); + async send(message: JSONRPCMessage, options?: { relatedRequestId?: RequestId }): Promise { + const relatedRequestId = options?.relatedRequestId; + if (relatedRequestId === undefined) { + throw new Error("relatedRequestId is required"); } - let targetConnectionId = ""; - - // if it is a response, find the corresponding request connection - if ('id' in message && ('result' in message || 'error' in message)) { - const connId = this._requestConnections.get(String(message.id)); - - // if the corresponding connection is not found, the connection may be disconnected - if (!connId || !this._connections.has(connId)) { - // select an available connection - const firstConnId = this._connections.keys().next().value; - if (firstConnId) { - targetConnectionId = firstConnId; - } else { - throw new Error("No available connections"); - } - } else { - targetConnectionId = connId; - } - } else { - // for other messages, select an available connection - const firstConnId = this._connections.keys().next().value; - if (firstConnId) { - targetConnectionId = firstConnId; - } else { - throw new Error("No available connections"); - } - } - - const messageId = randomUUID(); - this._messageHistory.set(messageId, { - message, - connectionId: targetConnectionId - }); - - // keep the message history in a reasonable range - if (this._messageHistory.size > 1000) { - const oldestKey = Array.from(this._messageHistory.keys())[0]; - this._messageHistory.delete(oldestKey); + const sseResponse = this._sseResponseMapping.get(relatedRequestId); + if (!sseResponse) { + throw new Error("No SSE connection established"); } - // send the message to all active connections - for (const [connId, connection] of this._connections.entries()) { - // if it is a response message, only send to the target connection - if ('id' in message && ('result' in message || 'error' in message)) { - if (connId === targetConnectionId) { - connection.response.write( - `id: ${messageId}\nevent: message\ndata: ${JSON.stringify(message)}\n\n` - ); - } - } else { - // for other messages, send to all connections - connection.response.write( - `id: ${messageId}\nevent: message\ndata: ${JSON.stringify(message)}\n\n` - ); + // Send the message as an SSE event + sseResponse.write( + `event: message\ndata: ${JSON.stringify(message)}\n\n`, + ); + + // If this is a response message with the same ID as the request, we can check + // if we need to close the stream after sending the response + if ('result' in message || 'error' in message) { + if (message.id === relatedRequestId) { + // This is a response to the original request, we can close the stream + // after sending all related responses + this._sseResponseMapping.delete(relatedRequestId); + sseResponse.end(); } } } diff --git a/src/shared/transport.ts b/src/shared/transport.ts index 72db1da8..d855ceab 100644 --- a/src/shared/transport.ts +++ b/src/shared/transport.ts @@ -16,7 +16,7 @@ export interface Transport { /** * Sends a JSON-RPC message (request or response). * - * If present, `relatedRequestId` is used to indicate to the transport which incoming request to associate this outgoing message with. + * If present, `relatedRequestId` is used to indicate to the transport which incoming request to associate this outgoing message with. */ send(message: JSONRPCMessage, options?: { relatedRequestId?: RequestId }): Promise; From 4068e6f9a3da7f5e00dec06901fddee01abc5a65 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Mon, 7 Apr 2025 14:22:11 +0100 Subject: [PATCH 12/17] cleanup initialization --- src/server/streamableHttp.test.ts | 90 ++++++------------ src/server/streamableHttp.ts | 146 ++++++++++++++---------------- 2 files changed, 97 insertions(+), 139 deletions(-) diff --git a/src/server/streamableHttp.test.ts b/src/server/streamableHttp.test.ts index 8fb55e79..129b499e 100644 --- a/src/server/streamableHttp.test.ts +++ b/src/server/streamableHttp.test.ts @@ -73,7 +73,7 @@ describe("StreamableHTTPServerTransport", () => { method: "POST", headers: { "content-type": "application/json", - "accept": "application/json", + "accept": "application/json, text/event-stream", }, body: JSON.stringify(initializeMessage), }); @@ -93,7 +93,7 @@ describe("StreamableHTTPServerTransport", () => { method: "GET", headers: { "mcp-session-id": "invalid-session-id", - "accept": "text/event-stream" + "accept": "application/json, text/event-stream" }, }); @@ -109,7 +109,7 @@ describe("StreamableHTTPServerTransport", () => { const req = createMockRequest({ method: "GET", headers: { - accept: "text/event-stream", + accept: "application/json, text/event-stream", // No mcp-session-id header }, }); @@ -143,7 +143,7 @@ describe("StreamableHTTPServerTransport", () => { method: "POST", headers: { "content-type": "application/json", - "accept": "application/json", + "accept": "application/json, text/event-stream", }, body: JSON.stringify(message), }); @@ -161,7 +161,7 @@ describe("StreamableHTTPServerTransport", () => { method: "POST", headers: { "content-type": "application/json", - "accept": "application/json", + "accept": "application/json, text/event-stream", "mcp-session-id": "invalid-session-id", // This would cause a 404 in stateful mode }, body: JSON.stringify({ @@ -195,7 +195,7 @@ describe("StreamableHTTPServerTransport", () => { method: "POST", headers: { "content-type": "application/json", - "accept": "application/json", + "accept": "application/json, text/event-stream", "mcp-session-id": "non-existent-session-id", // This would be rejected in stateful mode }, body: JSON.stringify(message), @@ -216,7 +216,7 @@ describe("StreamableHTTPServerTransport", () => { method: "POST", headers: { "content-type": "application/json", - accept: "text/event-stream", + accept: "application/json, text/event-stream", }, body: JSON.stringify({ jsonrpc: "2.0", @@ -242,7 +242,7 @@ describe("StreamableHTTPServerTransport", () => { method: "POST", headers: { "content-type": "application/json", - accept: "text/event-stream", + accept: "application/json, text/event-stream", "mcp-session-id": "some-random-session-id", }, body: JSON.stringify({ @@ -264,7 +264,7 @@ describe("StreamableHTTPServerTransport", () => { ); }); - it("should handle initialization requests properly in statefull mode", async () => { + it("should handle initialization requests properly in stateful mode", async () => { // Initialize message that would typically be sent during initialization const initializeMessage: JSONRPCMessage = { jsonrpc: "2.0", @@ -281,7 +281,7 @@ describe("StreamableHTTPServerTransport", () => { method: "POST", headers: { "content-type": "application/json", - "accept": "application/json", + "accept": "application/json, text/event-stream", }, body: JSON.stringify(initializeMessage), }); @@ -314,7 +314,7 @@ describe("StreamableHTTPServerTransport", () => { method: "POST", headers: { "content-type": "application/json", - "accept": "application/json", + "accept": "application/json, text/event-stream", }, body: JSON.stringify(initializeMessage), }); @@ -333,7 +333,7 @@ describe("StreamableHTTPServerTransport", () => { const req = createMockRequest({ method: "GET", headers: { - "accept": "text/event-stream", + "accept": "application/json, text/event-stream", "mcp-session-id": transport.sessionId, }, }); @@ -380,7 +380,7 @@ describe("StreamableHTTPServerTransport", () => { method: "POST", headers: { "content-type": "application/json", - "accept": "text/event-stream", + "accept": "application/json, text/event-stream", }, body: JSON.stringify(message), }); @@ -394,7 +394,7 @@ describe("StreamableHTTPServerTransport", () => { expect(mockResponse.writeHead).toHaveBeenCalledWith( 200, expect.objectContaining({ - "Content-Type": "text/event-stream", + "mcp-session-id": transport.sessionId, }) ); }); @@ -435,7 +435,7 @@ describe("StreamableHTTPServerTransport", () => { method: "POST", headers: { "content-type": "application/json", - "accept": "application/json", + "accept": "application/json, text/event-stream", "mcp-session-id": transport.sessionId, }, body: JSON.stringify(batchMessages), @@ -488,7 +488,7 @@ describe("StreamableHTTPServerTransport", () => { method: "POST", headers: { "content-type": "text/plain", - "accept": "application/json", + "accept": "application/json, text/event-stream", "mcp-session-id": transport.sessionId, }, body: "test", @@ -532,7 +532,7 @@ describe("StreamableHTTPServerTransport", () => { method: "POST", headers: { "content-type": "application/json", - "accept": "text/event-stream", + "accept": "application/json, text/event-stream", "mcp-session-id": transport.sessionId }, body: JSON.stringify(requestMessage) @@ -571,7 +571,7 @@ describe("StreamableHTTPServerTransport", () => { method: "POST", headers: { "content-type": "application/json", - "accept": "text/event-stream", + "accept": "application/json, text/event-stream", "mcp-session-id": transport.sessionId }, body: JSON.stringify(requestMessage) @@ -620,7 +620,7 @@ describe("StreamableHTTPServerTransport", () => { method: "POST", headers: { "content-type": "application/json", - "accept": "text/event-stream", + "accept": "application/json, text/event-stream", "mcp-session-id": transport.sessionId }, body: JSON.stringify(requestMessage1), @@ -639,7 +639,7 @@ describe("StreamableHTTPServerTransport", () => { method: "POST", headers: { "content-type": "application/json", - "accept": "text/event-stream", + "accept": "application/json, text/event-stream", "mcp-session-id": transport.sessionId }, body: JSON.stringify(requestMessage2), @@ -691,7 +691,7 @@ describe("StreamableHTTPServerTransport", () => { method: "POST", headers: { "content-type": "application/json", - "accept": "application/json", + "accept": "application/json, text/event-stream", }, body: "invalid json", }); @@ -712,7 +712,7 @@ describe("StreamableHTTPServerTransport", () => { method: "POST", headers: { "content-type": "application/json", - "accept": "application/json", + "accept": "application/json, text/event-stream", }, body: JSON.stringify({ invalid: "message" }), }); @@ -745,7 +745,7 @@ describe("StreamableHTTPServerTransport", () => { method: "POST", headers: { "content-type": "application/json", - "accept": "application/json", + "accept": "application/json, text/event-stream", }, // No body provided here - it will be passed as parsedBody }); @@ -761,7 +761,7 @@ describe("StreamableHTTPServerTransport", () => { expect(mockResponse.writeHead).toHaveBeenCalledWith( 200, expect.objectContaining({ - "Content-Type": "application/json", + "mcp-session-id": transport.sessionId, }) ); }); @@ -787,7 +787,7 @@ describe("StreamableHTTPServerTransport", () => { method: "POST", headers: { "content-type": "application/json", - "accept": "text/event-stream", + "accept": "application/json, text/event-stream", "mcp-session-id": transport.sessionId, }, // No body provided here - it will be passed as parsedBody @@ -825,7 +825,7 @@ describe("StreamableHTTPServerTransport", () => { method: "POST", headers: { "content-type": "application/json", - "accept": "application/json", + "accept": "application/json, text/event-stream", "mcp-session-id": transport.sessionId, }, body: JSON.stringify(requestBodyMessage), @@ -863,7 +863,7 @@ describe("StreamableHTTPServerTransport", () => { method: "POST", headers: { "content-type": "application/json", - accept: "text/event-stream", + accept: "application/json, text/event-stream", "mcp-session-id": transportWithHeaders.sessionId }, body: JSON.stringify({ @@ -888,41 +888,11 @@ describe("StreamableHTTPServerTransport", () => { ); }); - it("should include custom headers in JSON response", async () => { - const message: JSONRPCMessage = { - jsonrpc: "2.0", - method: "test", - params: {}, - id: 1, - }; - - const req = createMockRequest({ - method: "POST", - headers: { - "content-type": "application/json", - "accept": "application/json", - "mcp-session-id": transportWithHeaders.sessionId - }, - body: JSON.stringify(message), - }); - - await transportWithHeaders.handleRequest(req, mockResponse); - - expect(mockResponse.writeHead).toHaveBeenCalledWith( - 200, - expect.objectContaining({ - ...customHeaders, - "Content-Type": "application/json", - "mcp-session-id": transportWithHeaders.sessionId - }) - ); - }); - it("should include custom headers in error responses", async () => { const req = createMockRequest({ method: "GET", headers: { - accept: "text/event-stream", + accept: "application/json, text/event-stream", "mcp-session-id": "invalid-session-id" }, }); @@ -948,7 +918,7 @@ describe("StreamableHTTPServerTransport", () => { method: "POST", headers: { "content-type": "application/json", - accept: "text/event-stream", + accept: "application/json, text/event-stream", "mcp-session-id": transportWithConflictingHeaders.sessionId }, body: JSON.stringify({ @@ -979,7 +949,7 @@ describe("StreamableHTTPServerTransport", () => { method: "POST", headers: { "content-type": "application/json", - accept: "text/event-stream", + accept: "application/json, text/event-stream", "mcp-session-id": transportWithoutHeaders.sessionId }, body: JSON.stringify({ diff --git a/src/server/streamableHttp.ts b/src/server/streamableHttp.ts index 6afe826a..984b2c19 100644 --- a/src/server/streamableHttp.ts +++ b/src/server/streamableHttp.ts @@ -37,7 +37,7 @@ export interface StreamableHTTPServerTransportOptions { * sessionId: randomUUID(), * }); * - * // Stateless mode - explisitly set session ID to undefined + * // Stateless mode - explicitly set session ID to undefined * const statelessTransport = new StreamableHTTPServerTransport({ * sessionId: undefined, * }); @@ -60,7 +60,7 @@ export interface StreamableHTTPServerTransportOptions { * - No session validation is performed */ export class StreamableHTTPServerTransport implements Transport { - // when sessionID is not set, it means the transport is in stateless mode + // when sessionId is not set (undefined), it means the transport is in stateless mode private _sessionId: string | undefined; private _started: boolean = false; private _customHeaders: Record; @@ -71,8 +71,8 @@ export class StreamableHTTPServerTransport implements Transport { onmessage?: (message: JSONRPCMessage) => void; constructor(options: StreamableHTTPServerTransportOptions) { - this._sessionId = options?.sessionId; - this._customHeaders = options?.customHeaders || {}; + this._sessionId = options.sessionId; + this._customHeaders = options.customHeaders || {}; } /** @@ -90,16 +90,6 @@ export class StreamableHTTPServerTransport implements Transport { * Handles an incoming HTTP request, whether GET or POST */ async handleRequest(req: IncomingMessage, res: ServerResponse, parsedBody?: unknown): Promise { - // Only validate session ID for non-initialization requests when session management is enabled - if (this._sessionId !== undefined) { - const isInitializationRequest = req.method === "POST" && - req.headers["content-type"]?.includes("application/json"); - - if (!isInitializationRequest && !this.validateSession(req, res)) { - return; - } - } - if (req.method === "GET") { await this.handleGetRequest(req, res); } else if (req.method === "POST") { @@ -124,7 +114,11 @@ export class StreamableHTTPServerTransport implements Transport { * We choose to return 405 Method Not Allowed as we don't support GET SSE connections yet. */ private async handleGetRequest(req: IncomingMessage, res: ServerResponse): Promise { - // Return 405 Method Not Allowed + // Check session validity first for GET requests when session management is enabled + if (this._sessionId !== undefined && !this.validateSession(req, res)) { + return; + } + res.writeHead(405, { ...this._customHeaders, "Allow": "POST, DELETE" @@ -143,15 +137,16 @@ export class StreamableHTTPServerTransport implements Transport { */ private async handlePostRequest(req: IncomingMessage, res: ServerResponse, parsedBody?: unknown): Promise { try { - // validate the Accept header + // Validate the Accept header const acceptHeader = req.headers.accept; + // The client MUST include an Accept header, listing both application/json and text/event-stream as supported content types. if (!acceptHeader || - (!acceptHeader.includes("application/json") && !acceptHeader.includes("text/event-stream"))) { + !acceptHeader.includes("application/json") || !acceptHeader.includes("text/event-stream")) { res.writeHead(406).end(JSON.stringify({ jsonrpc: "2.0", error: { code: -32000, - message: "Not Acceptable: Client must accept application/json and/or text/event-stream" + message: "Not Acceptable: Client must accept both application/json and text/event-stream" }, id: null })); @@ -192,16 +187,33 @@ export class StreamableHTTPServerTransport implements Transport { messages = [JSONRPCMessageSchema.parse(rawMessage)]; } - if (this._sessionId !== undefined) { - // Check if this is an initialization request - // https://spec.modelcontextprotocol.io/specification/2025-03-26/basic/lifecycle/ - const isInitializationRequest = messages.some( - msg => 'method' in msg && msg.method === 'initialize' && 'id' in msg - ); + // Check if this is an initialization request + // https://spec.modelcontextprotocol.io/specification/2025-03-26/basic/lifecycle/ + const isInitializationRequest = messages.some( + msg => 'method' in msg && msg.method === 'initialize' + ); + if (isInitializationRequest) { + const headers: Record = { + ...this._customHeaders + }; + + if (this._sessionId !== undefined) { + headers["mcp-session-id"] = this._sessionId; + } - if (!isInitializationRequest && !this.validateSession(req, res)) { - return; + // Process initialization messages before responding + for (const message of messages) { + this.onmessage?.(message); } + + res.writeHead(200, headers).end(); + return; + } + // If an Mcp-Session-Id is returned by the server during initialization, + // clients using the Streamable HTTP transport MUST include it + // in the Mcp-Session-Id header on all of their subsequent HTTP requests. + if (this._sessionId !== undefined && !isInitializationRequest && !this.validateSession(req, res)) { + return; } @@ -219,62 +231,35 @@ export class StreamableHTTPServerTransport implements Transport { this.onmessage?.(message); } } else if (hasRequests) { - // if it contains requests, you can choose to return an SSE stream or a JSON response - const useSSE = acceptHeader.includes("text/event-stream"); - - if (useSSE) { - const headers: Record = { - ...this._customHeaders, - "Content-Type": "text/event-stream", - "Cache-Control": "no-cache", - Connection: "keep-alive", - }; - - // For initialization requests, always include the session ID if we have one - // even if we're in stateless mode - if (this._sessionId !== undefined) { - headers["mcp-session-id"] = this._sessionId; - } - - res.writeHead(200, headers); - - // Store the response for this request to send messages back through this connection - // We need to track by request ID to maintain the connection - for (const message of messages) { - if ('method' in message && 'id' in message) { - this._sseResponseMapping.set(message.id, res); - } - } - - // handle each message - for (const message of messages) { - this.onmessage?.(message); - } - - // The server SHOULD NOT close the SSE stream before sending all JSON-RPC responses - // This will be handled by the send() method when responses are ready - } else { - // use direct JSON response - const headers: Record = { - ...this._customHeaders, - "Content-Type": "application/json", - }; - - // For initialization requests, always include the session ID if we have one - // even if we're in stateless mode - if (this._sessionId !== undefined) { - headers["mcp-session-id"] = this._sessionId; - } + const headers: Record = { + ...this._customHeaders, + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + Connection: "keep-alive", + }; + + // For initialization requests, always include the session ID if we have one + // even if we're in stateless mode + if (this._sessionId !== undefined) { + headers["mcp-session-id"] = this._sessionId; + } - res.writeHead(200, headers); + res.writeHead(200, headers); - // handle each message - for (const message of messages) { - this.onmessage?.(message); + // Store the response for this request to send messages back through this connection + // We need to track by request ID to maintain the connection + for (const message of messages) { + if ('method' in message && 'id' in message) { + this._sseResponseMapping.set(message.id, res); } + } - res.end(); + // handle each message + for (const message of messages) { + this.onmessage?.(message); } + // The server SHOULD NOT close the SSE stream before sending all JSON-RPC responses + // This will be handled by the send() method when responses are ready } } catch (error) { // return JSON-RPC formatted error @@ -295,6 +280,9 @@ export class StreamableHTTPServerTransport implements Transport { * Handles DELETE requests to terminate sessions */ private async handleDeleteRequest(req: IncomingMessage, res: ServerResponse): Promise { + if (!this.validateSession(req, res)) { + return; + } await this.close(); res.writeHead(200).end(); } @@ -346,12 +334,12 @@ export class StreamableHTTPServerTransport implements Transport { async send(message: JSONRPCMessage, options?: { relatedRequestId?: RequestId }): Promise { const relatedRequestId = options?.relatedRequestId; if (relatedRequestId === undefined) { - throw new Error("relatedRequestId is required"); + throw new Error("relatedRequestId is required for Streamable HTTP transport"); } const sseResponse = this._sseResponseMapping.get(relatedRequestId); if (!sseResponse) { - throw new Error("No SSE connection established"); + throw new Error(`No SSE connection established for request ID: ${String(relatedRequestId)}`); } // Send the message as an SSE event From af64a0e18cccc536cc3c35061ff1d0955cc102c8 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Mon, 7 Apr 2025 16:03:31 +0100 Subject: [PATCH 13/17] remove custom headers --- src/server/mcp.test.ts | 11 +-- src/server/streamableHttp.test.ts | 135 +----------------------------- src/server/streamableHttp.ts | 25 ++---- 3 files changed, 11 insertions(+), 160 deletions(-) diff --git a/src/server/mcp.test.ts b/src/server/mcp.test.ts index 0c136e29..f33c669f 100644 --- a/src/server/mcp.test.ts +++ b/src/server/mcp.test.ts @@ -379,7 +379,7 @@ describe("tool()", () => { expect(receivedSessionId).toBe("test-session-123"); }); - test("should provide sendNotification withing tool call", async () => { + test("should provide sendNotification within tool call", async () => { const mcpServer = new McpServer( { name: "test server", @@ -401,12 +401,10 @@ describe("tool()", () => { ); let receivedLogMessage: string | undefined; - - const loggingMessage = "hello here is log message 1" + const loggingMessage = "hello here is log message 1"; client.setNotificationHandler(LoggingMessageNotificationSchema, (notification) => { receivedLogMessage = notification.params.data as string; - }); mcpServer.tool("test-tool", async ({ sendNotification }) => { @@ -422,15 +420,10 @@ describe("tool()", () => { }); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - // Set a test sessionId on the server transport - serverTransport.sessionId = "test-session-123"; - - await Promise.all([ client.connect(clientTransport), mcpServer.server.connect(serverTransport), ]); - await client.request( { method: "tools/call", diff --git a/src/server/streamableHttp.test.ts b/src/server/streamableHttp.test.ts index 129b499e..6c5c8705 100644 --- a/src/server/streamableHttp.test.ts +++ b/src/server/streamableHttp.test.ts @@ -99,7 +99,7 @@ describe("StreamableHTTPServerTransport", () => { await transport.handleRequest(req, mockResponse); - expect(mockResponse.writeHead).toHaveBeenCalledWith(404, {}); + expect(mockResponse.writeHead).toHaveBeenCalledWith(404); // check if the error response is a valid JSON-RPC error format expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"jsonrpc":"2.0"')); expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"error"')); @@ -116,7 +116,7 @@ describe("StreamableHTTPServerTransport", () => { await transport.handleRequest(req, mockResponse); - expect(mockResponse.writeHead).toHaveBeenCalledWith(400, {}); + expect(mockResponse.writeHead).toHaveBeenCalledWith(400); expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"jsonrpc":"2.0"')); expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"message":"Bad Request: Mcp-Session-Id header is required"')); }); @@ -842,135 +842,4 @@ describe("StreamableHTTPServerTransport", () => { expect(onMessageMock).not.toHaveBeenCalledWith(requestBodyMessage); }); }); - - describe("Custom Headers", () => { - const customHeaders = { - "X-Custom-Header": "custom-value", - "X-API-Version": "1.0", - "Access-Control-Allow-Origin": "*" - }; - - let transportWithHeaders: StreamableHTTPServerTransport; - let mockResponse: jest.Mocked; - - beforeEach(() => { - transportWithHeaders = new StreamableHTTPServerTransport({ sessionId: randomUUID(), customHeaders }); - mockResponse = createMockResponse(); - }); - - it("should include custom headers in SSE response", async () => { - const req = createMockRequest({ - method: "POST", - headers: { - "content-type": "application/json", - accept: "application/json, text/event-stream", - "mcp-session-id": transportWithHeaders.sessionId - }, - body: JSON.stringify({ - jsonrpc: "2.0", - method: "test", - params: {}, - id: "test-headers-id" - }) - }); - - await transportWithHeaders.handleRequest(req, mockResponse); - - expect(mockResponse.writeHead).toHaveBeenCalledWith( - 200, - expect.objectContaining({ - ...customHeaders, - "Content-Type": "text/event-stream", - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "mcp-session-id": transportWithHeaders.sessionId - }) - ); - }); - - it("should include custom headers in error responses", async () => { - const req = createMockRequest({ - method: "GET", - headers: { - accept: "application/json, text/event-stream", - "mcp-session-id": "invalid-session-id" - }, - }); - - await transportWithHeaders.handleRequest(req, mockResponse); - - expect(mockResponse.writeHead).toHaveBeenCalledWith( - 404, - expect.objectContaining(customHeaders) - ); - }); - - it("should not override essential headers with custom headers", async () => { - const transportWithConflictingHeaders = new StreamableHTTPServerTransport({ - sessionId: randomUUID(), - customHeaders: { - "Content-Type": "text/plain", - "X-Custom-Header": "custom-value" - } - }); - - const req = createMockRequest({ - method: "POST", - headers: { - "content-type": "application/json", - accept: "application/json, text/event-stream", - "mcp-session-id": transportWithConflictingHeaders.sessionId - }, - body: JSON.stringify({ - jsonrpc: "2.0", - method: "test", - params: {}, - id: "test-conflict-id" - }) - }); - - await transportWithConflictingHeaders.handleRequest(req, mockResponse); - - expect(mockResponse.writeHead).toHaveBeenCalledWith( - 200, - expect.objectContaining({ - "Content-Type": "text/event-stream", - "X-Custom-Header": "custom-value" - }) - ); - }); - - it("should work with empty custom headers", async () => { - const transportWithoutHeaders = new StreamableHTTPServerTransport({ - sessionId: randomUUID(), - }); - - const req = createMockRequest({ - method: "POST", - headers: { - "content-type": "application/json", - accept: "application/json, text/event-stream", - "mcp-session-id": transportWithoutHeaders.sessionId - }, - body: JSON.stringify({ - jsonrpc: "2.0", - method: "test", - params: {}, - id: "test-empty-headers-id" - }) - }); - - await transportWithoutHeaders.handleRequest(req, mockResponse); - - expect(mockResponse.writeHead).toHaveBeenCalledWith( - 200, - expect.objectContaining({ - "Content-Type": "text/event-stream", - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "mcp-session-id": transportWithoutHeaders.sessionId - }) - ); - }); - }); }); \ No newline at end of file diff --git a/src/server/streamableHttp.ts b/src/server/streamableHttp.ts index 984b2c19..d7f44160 100644 --- a/src/server/streamableHttp.ts +++ b/src/server/streamableHttp.ts @@ -13,15 +13,11 @@ export interface StreamableHTTPServerTransportOptions { /** * The session ID SHOULD be globally unique and cryptographically secure (e.g., a securely generated UUID, a JWT, or a cryptographic hash) * - * When sessionId is not set, the transport will be in stateless mode. + * When there is no sessionId, the transport will not perform session management. */ sessionId: string | undefined; - /** - * Custom headers to be included in all responses - * These headers will be added to both SSE and regular HTTP responses - */ - customHeaders?: Record; + } @@ -63,7 +59,6 @@ export class StreamableHTTPServerTransport implements Transport { // when sessionId is not set (undefined), it means the transport is in stateless mode private _sessionId: string | undefined; private _started: boolean = false; - private _customHeaders: Record; private _sseResponseMapping: Map = new Map(); onclose?: () => void; @@ -72,7 +67,6 @@ export class StreamableHTTPServerTransport implements Transport { constructor(options: StreamableHTTPServerTransportOptions) { this._sessionId = options.sessionId; - this._customHeaders = options.customHeaders || {}; } /** @@ -97,7 +91,7 @@ export class StreamableHTTPServerTransport implements Transport { } else if (req.method === "DELETE") { await this.handleDeleteRequest(req, res); } else { - res.writeHead(405, this._customHeaders).end(JSON.stringify({ + res.writeHead(405).end(JSON.stringify({ jsonrpc: "2.0", error: { code: -32000, @@ -120,7 +114,6 @@ export class StreamableHTTPServerTransport implements Transport { } res.writeHead(405, { - ...this._customHeaders, "Allow": "POST, DELETE" }).end(JSON.stringify({ jsonrpc: "2.0", @@ -193,9 +186,7 @@ export class StreamableHTTPServerTransport implements Transport { msg => 'method' in msg && msg.method === 'initialize' ); if (isInitializationRequest) { - const headers: Record = { - ...this._customHeaders - }; + const headers: Record = {}; if (this._sessionId !== undefined) { headers["mcp-session-id"] = this._sessionId; @@ -232,14 +223,12 @@ export class StreamableHTTPServerTransport implements Transport { } } else if (hasRequests) { const headers: Record = { - ...this._customHeaders, "Content-Type": "text/event-stream", "Cache-Control": "no-cache", Connection: "keep-alive", }; - // For initialization requests, always include the session ID if we have one - // even if we're in stateless mode + // After initialization, always include the session ID if we have one if (this._sessionId !== undefined) { headers["mcp-session-id"] = this._sessionId; } @@ -296,7 +285,7 @@ export class StreamableHTTPServerTransport implements Transport { if (!sessionId) { // Non-initialization requests without a session ID should return 400 Bad Request - res.writeHead(400, this._customHeaders).end(JSON.stringify({ + res.writeHead(400).end(JSON.stringify({ jsonrpc: "2.0", error: { code: -32000, @@ -307,7 +296,7 @@ export class StreamableHTTPServerTransport implements Transport { return false; } else if ((Array.isArray(sessionId) ? sessionId[0] : sessionId) !== this._sessionId) { // Reject requests with invalid session ID with 404 Not Found - res.writeHead(404, this._customHeaders).end(JSON.stringify({ + res.writeHead(404).end(JSON.stringify({ jsonrpc: "2.0", error: { code: -32001, From 8c2086ec975a94b778662ca792c6ef7ecc9c616b Mon Sep 17 00:00:00 2001 From: ihrpr Date: Mon, 7 Apr 2025 16:15:09 +0100 Subject: [PATCH 14/17] refactor handle request types --- src/server/streamableHttp.test.ts | 2 +- src/server/streamableHttp.ts | 25 ++++++++----------------- 2 files changed, 9 insertions(+), 18 deletions(-) diff --git a/src/server/streamableHttp.test.ts b/src/server/streamableHttp.test.ts index 6c5c8705..981ae17b 100644 --- a/src/server/streamableHttp.test.ts +++ b/src/server/streamableHttp.test.ts @@ -344,7 +344,7 @@ describe("StreamableHTTPServerTransport", () => { "Allow": "POST, DELETE" })); expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"jsonrpc":"2.0"')); - expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('Server does not offer an SSE stream at this endpoint')); + expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('Method not allowed')); }); it("should reject POST requests without proper Accept header", async () => { diff --git a/src/server/streamableHttp.ts b/src/server/streamableHttp.ts index d7f44160..a58370f1 100644 --- a/src/server/streamableHttp.ts +++ b/src/server/streamableHttp.ts @@ -84,31 +84,20 @@ export class StreamableHTTPServerTransport implements Transport { * Handles an incoming HTTP request, whether GET or POST */ async handleRequest(req: IncomingMessage, res: ServerResponse, parsedBody?: unknown): Promise { - if (req.method === "GET") { - await this.handleGetRequest(req, res); - } else if (req.method === "POST") { + if (req.method === "POST") { await this.handlePostRequest(req, res, parsedBody); } else if (req.method === "DELETE") { await this.handleDeleteRequest(req, res); } else { - res.writeHead(405).end(JSON.stringify({ - jsonrpc: "2.0", - error: { - code: -32000, - message: "Method not allowed" - }, - id: null - })); + await this.handleUnsupportedRequest(req, res); } } /** - * Handles GET requests to establish SSE connections - * According to the MCP Streamable HTTP transport spec, the server MUST either return SSE or 405. - * We choose to return 405 Method Not Allowed as we don't support GET SSE connections yet. + * Handles unsupported requests (GET, PUT, PATCH, etc.) + * For now we support only POST and DELETE requests. Support for GET for SSE connections will be added later. */ - private async handleGetRequest(req: IncomingMessage, res: ServerResponse): Promise { - // Check session validity first for GET requests when session management is enabled + private async handleUnsupportedRequest(req: IncomingMessage, res: ServerResponse): Promise { if (this._sessionId !== undefined && !this.validateSession(req, res)) { return; } @@ -119,7 +108,7 @@ export class StreamableHTTPServerTransport implements Transport { jsonrpc: "2.0", error: { code: -32000, - message: "Method not allowed: Server does not offer an SSE stream at this endpoint" + message: "Method not allowed." }, id: null })); @@ -322,6 +311,8 @@ export class StreamableHTTPServerTransport implements Transport { async send(message: JSONRPCMessage, options?: { relatedRequestId?: RequestId }): Promise { const relatedRequestId = options?.relatedRequestId; + // SSE connections are established per POST request, for now we don't support it through the GET + // this will be changed when we implement the GET SSE connection if (relatedRequestId === undefined) { throw new Error("relatedRequestId is required for Streamable HTTP transport"); } From c275a0d24eaed78eba86f013c5e5463ff4ce882a Mon Sep 17 00:00:00 2001 From: ihrpr Date: Mon, 7 Apr 2025 16:27:28 +0100 Subject: [PATCH 15/17] remove session validation in handleUnsupportedRequest --- src/server/streamableHttp.test.ts | 19 +++++++++++++++---- src/server/streamableHttp.ts | 3 --- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/src/server/streamableHttp.test.ts b/src/server/streamableHttp.test.ts index 981ae17b..6062a6c9 100644 --- a/src/server/streamableHttp.test.ts +++ b/src/server/streamableHttp.test.ts @@ -39,12 +39,19 @@ function createMockResponse(): jest.Mocked { describe("StreamableHTTPServerTransport", () => { let transport: StreamableHTTPServerTransport; let mockResponse: jest.Mocked; + let mockRequest: string; beforeEach(() => { transport = new StreamableHTTPServerTransport({ sessionId: randomUUID(), }); mockResponse = createMockResponse(); + mockRequest = JSON.stringify({ + jsonrpc: "2.0", + method: "test", + params: {}, + id: 1, + }); }); afterEach(() => { @@ -90,11 +97,13 @@ describe("StreamableHTTPServerTransport", () => { it("should reject invalid session ID", async () => { const req = createMockRequest({ - method: "GET", + method: "POST", headers: { "mcp-session-id": "invalid-session-id", - "accept": "application/json, text/event-stream" + "accept": "application/json, text/event-stream", + "content-type": "application/json", }, + body: mockRequest, }); await transport.handleRequest(req, mockResponse); @@ -107,11 +116,13 @@ describe("StreamableHTTPServerTransport", () => { it("should reject non-initialization requests without session ID with 400 Bad Request", async () => { const req = createMockRequest({ - method: "GET", + method: "POST", headers: { - accept: "application/json, text/event-stream", + "accept": "application/json, text/event-stream", + "content-type": "application/json", // No mcp-session-id header }, + body: mockRequest }); await transport.handleRequest(req, mockResponse); diff --git a/src/server/streamableHttp.ts b/src/server/streamableHttp.ts index a58370f1..d0d75a65 100644 --- a/src/server/streamableHttp.ts +++ b/src/server/streamableHttp.ts @@ -98,9 +98,6 @@ export class StreamableHTTPServerTransport implements Transport { * For now we support only POST and DELETE requests. Support for GET for SSE connections will be added later. */ private async handleUnsupportedRequest(req: IncomingMessage, res: ServerResponse): Promise { - if (this._sessionId !== undefined && !this.validateSession(req, res)) { - return; - } res.writeHead(405, { "Allow": "POST, DELETE" From f29cbe7e06153092dc8afdfd37169a2553b51dfb Mon Sep 17 00:00:00 2001 From: ihrpr Date: Mon, 7 Apr 2025 16:49:46 +0100 Subject: [PATCH 16/17] reject session as an array --- src/server/streamableHttp.ts | 39 +++++++++++++++++++++++++++++------- src/shared/protocol.ts | 8 ++++---- 2 files changed, 36 insertions(+), 11 deletions(-) diff --git a/src/server/streamableHttp.ts b/src/server/streamableHttp.ts index d0d75a65..8ad821de 100644 --- a/src/server/streamableHttp.ts +++ b/src/server/streamableHttp.ts @@ -89,7 +89,7 @@ export class StreamableHTTPServerTransport implements Transport { } else if (req.method === "DELETE") { await this.handleDeleteRequest(req, res); } else { - await this.handleUnsupportedRequest(req, res); + await this.handleUnsupportedRequest(res); } } @@ -97,8 +97,7 @@ export class StreamableHTTPServerTransport implements Transport { * Handles unsupported requests (GET, PUT, PATCH, etc.) * For now we support only POST and DELETE requests. Support for GET for SSE connections will be added later. */ - private async handleUnsupportedRequest(req: IncomingMessage, res: ServerResponse): Promise { - + private async handleUnsupportedRequest(res: ServerResponse): Promise { res.writeHead(405, { "Allow": "POST, DELETE" }).end(JSON.stringify({ @@ -119,8 +118,7 @@ export class StreamableHTTPServerTransport implements Transport { // Validate the Accept header const acceptHeader = req.headers.accept; // The client MUST include an Accept header, listing both application/json and text/event-stream as supported content types. - if (!acceptHeader || - !acceptHeader.includes("application/json") || !acceptHeader.includes("text/event-stream")) { + if (!acceptHeader?.includes("application/json") || !acceptHeader.includes("text/event-stream")) { res.writeHead(406).end(JSON.stringify({ jsonrpc: "2.0", error: { @@ -172,6 +170,17 @@ export class StreamableHTTPServerTransport implements Transport { msg => 'method' in msg && msg.method === 'initialize' ); if (isInitializationRequest) { + if (messages.length > 1) { + res.writeHead(400).end(JSON.stringify({ + jsonrpc: "2.0", + error: { + code: -32600, + message: "Invalid Request: Only one initialization request is allowed" + }, + id: null + })); + return; + } const headers: Record = {}; if (this._sessionId !== undefined) { @@ -280,7 +289,18 @@ export class StreamableHTTPServerTransport implements Transport { id: null })); return false; - } else if ((Array.isArray(sessionId) ? sessionId[0] : sessionId) !== this._sessionId) { + } else if (Array.isArray(sessionId)) { + res.writeHead(400).end(JSON.stringify({ + jsonrpc: "2.0", + error: { + code: -32000, + message: "Bad Request: Mcp-Session-Id header must be a single value" + }, + id: null + })); + return false; + } + else if (sessionId !== this._sessionId) { // Reject requests with invalid session ID with 404 Not Found res.writeHead(404).end(JSON.stringify({ jsonrpc: "2.0", @@ -331,7 +351,12 @@ export class StreamableHTTPServerTransport implements Transport { // This is a response to the original request, we can close the stream // after sending all related responses this._sseResponseMapping.delete(relatedRequestId); - sseResponse.end(); + + // Only close the connection if it's not needed by other requests + const canCloseConnection = ![...this._sseResponseMapping.entries()].some(([id, res]) => res === sseResponse && id !== relatedRequestId); + if (canCloseConnection) { + sseResponse.end(); + } } } } diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 0a760fec..b072e578 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -111,10 +111,10 @@ export type RequestHandlerExtra Promise; /** From 88dc565f372fa8fbcdd7c692708272b9f6639429 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Mon, 7 Apr 2025 17:29:49 +0100 Subject: [PATCH 17/17] pass sessionIdGenerator instead of sessionId --- src/server/streamableHttp.test.ts | 490 ++++++++++++++++++++++++++---- src/server/streamableHttp.ts | 61 ++-- 2 files changed, 473 insertions(+), 78 deletions(-) diff --git a/src/server/streamableHttp.test.ts b/src/server/streamableHttp.test.ts index 6062a6c9..aff9e511 100644 --- a/src/server/streamableHttp.test.ts +++ b/src/server/streamableHttp.test.ts @@ -43,7 +43,7 @@ describe("StreamableHTTPServerTransport", () => { beforeEach(() => { transport = new StreamableHTTPServerTransport({ - sessionId: randomUUID(), + sessionIdGenerator: () => randomUUID(), }); mockResponse = createMockResponse(); mockRequest = JSON.stringify({ @@ -59,13 +59,7 @@ describe("StreamableHTTPServerTransport", () => { }); describe("Session Management", () => { - it("should store a valid session ID", () => { - expect(transport.sessionId).toBeTruthy(); - expect(typeof transport.sessionId).toBe("string"); - }); - - it("should include session ID in response headers", async () => { - // Use POST with initialize method to avoid session ID requirement + it("should generate session ID during initialization", async () => { const initializeMessage: JSONRPCMessage = { jsonrpc: "2.0", method: "initialize", @@ -85,8 +79,13 @@ describe("StreamableHTTPServerTransport", () => { body: JSON.stringify(initializeMessage), }); + expect(transport.sessionId).toBeUndefined(); + expect(transport["_initialized"]).toBe(false); + await transport.handleRequest(req, mockResponse); + expect(transport.sessionId).toBeDefined(); + expect(transport["_initialized"]).toBe(true); expect(mockResponse.writeHead).toHaveBeenCalledWith( 200, expect.objectContaining({ @@ -95,7 +94,122 @@ describe("StreamableHTTPServerTransport", () => { ); }); + it("should reject second initialization request", async () => { + // First initialize + const initMessage1: JSONRPCMessage = { + jsonrpc: "2.0", + method: "initialize", + params: { + clientInfo: { name: "test-client", version: "1.0" }, + protocolVersion: "2025-03-26" + }, + id: "init-1", + }; + + const req1 = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + }, + body: JSON.stringify(initMessage1), + }); + + await transport.handleRequest(req1, mockResponse); + expect(transport["_initialized"]).toBe(true); + + // Reset mock for second request + mockResponse.writeHead.mockClear(); + mockResponse.end.mockClear(); + + // Try second initialize + const initMessage2: JSONRPCMessage = { + jsonrpc: "2.0", + method: "initialize", + params: { + clientInfo: { name: "test-client", version: "1.0" }, + protocolVersion: "2025-03-26" + }, + id: "init-2", + }; + + const req2 = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + }, + body: JSON.stringify(initMessage2), + }); + + await transport.handleRequest(req2, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith(400); + expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"message":"Invalid Request: Server already initialized"')); + }); + + it("should reject batch initialize request", async () => { + const batchInitialize: JSONRPCMessage[] = [ + { + jsonrpc: "2.0", + method: "initialize", + params: { + clientInfo: { name: "test-client", version: "1.0" }, + protocolVersion: "2025-03-26" + }, + id: "init-1", + }, + { + jsonrpc: "2.0", + method: "initialize", + params: { + clientInfo: { name: "test-client-2", version: "1.0" }, + protocolVersion: "2025-03-26" + }, + id: "init-2", + } + ]; + + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + }, + body: JSON.stringify(batchInitialize), + }); + + await transport.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith(400); + expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"message":"Invalid Request: Only one initialization request is allowed"')); + }); + it("should reject invalid session ID", async () => { + // First initialize the transport + const initMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "initialize", + params: { + clientInfo: { name: "test-client", version: "1.0" }, + protocolVersion: "2025-03-26" + }, + id: "init-1", + }; + + const initReq = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + }, + body: JSON.stringify(initMessage), + }); + + await transport.handleRequest(initReq, mockResponse); + mockResponse.writeHead.mockClear(); + + // Now try with an invalid session ID const req = createMockRequest({ method: "POST", headers: { @@ -109,12 +223,35 @@ describe("StreamableHTTPServerTransport", () => { await transport.handleRequest(req, mockResponse); expect(mockResponse.writeHead).toHaveBeenCalledWith(404); - // check if the error response is a valid JSON-RPC error format expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"jsonrpc":"2.0"')); - expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"error"')); + expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"message":"Session not found"')); }); it("should reject non-initialization requests without session ID with 400 Bad Request", async () => { + // First initialize the transport + const initMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "initialize", + params: { + clientInfo: { name: "test-client", version: "1.0" }, + protocolVersion: "2025-03-26" + }, + id: "init-1", + }; + + const initReq = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + }, + body: JSON.stringify(initMessage), + }); + + await transport.handleRequest(initReq, mockResponse); + mockResponse.writeHead.mockClear(); + + // Now try without session ID const req = createMockRequest({ method: "POST", headers: { @@ -131,17 +268,103 @@ describe("StreamableHTTPServerTransport", () => { expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"jsonrpc":"2.0"')); expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"message":"Bad Request: Mcp-Session-Id header is required"')); }); + + it("should reject requests to uninitialized server", async () => { + // Create a new transport that hasn't been initialized + const uninitializedTransport = new StreamableHTTPServerTransport({ + sessionIdGenerator: () => randomUUID(), + }); + + const req = createMockRequest({ + method: "POST", + headers: { + "accept": "application/json, text/event-stream", + "content-type": "application/json", + "mcp-session-id": "any-session-id", + }, + body: mockRequest + }); + + await uninitializedTransport.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith(400); + expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"message":"Bad Request: Server not initialized"')); + }); + + it("should reject session ID as array", async () => { + // First initialize the transport + const initMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "initialize", + params: { + clientInfo: { name: "test-client", version: "1.0" }, + protocolVersion: "2025-03-26" + }, + id: "init-1", + }; + + const initReq = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + }, + body: JSON.stringify(initMessage), + }); + + await transport.handleRequest(initReq, mockResponse); + mockResponse.writeHead.mockClear(); + + // Now try with an array session ID + const req = createMockRequest({ + method: "POST", + headers: { + "mcp-session-id": ["session1", "session2"], + "accept": "application/json, text/event-stream", + "content-type": "application/json", + }, + body: mockRequest, + }); + + await transport.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith(400); + expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"message":"Bad Request: Mcp-Session-Id header must be a single value"')); + }); }); - describe("Stateless Mode", () => { - let statelessTransport: StreamableHTTPServerTransport; + describe("Mode without state management", () => { + let transportWithoutState: StreamableHTTPServerTransport; let mockResponse: jest.Mocked; - beforeEach(() => { - statelessTransport = new StreamableHTTPServerTransport({ sessionId: undefined }); + beforeEach(async () => { + transportWithoutState = new StreamableHTTPServerTransport({ sessionIdGenerator: () => undefined }); mockResponse = createMockResponse(); + + // Initialize the transport for each test + const initMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "initialize", + params: { + clientInfo: { name: "test-client", version: "1.0" }, + protocolVersion: "2025-03-26" + }, + id: "init-1", + }; + + const initReq = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + }, + body: JSON.stringify(initMessage), + }); + + await transportWithoutState.handleRequest(initReq, mockResponse); + mockResponse.writeHead.mockClear(); }); - it("should not include session ID in response headers when in stateless mode", async () => { + it("should not include session ID in response headers when in mode without state management", async () => { // Use a non-initialization request const message: JSONRPCMessage = { jsonrpc: "2.0", @@ -159,7 +382,7 @@ describe("StreamableHTTPServerTransport", () => { body: JSON.stringify(message), }); - await statelessTransport.handleRequest(req, mockResponse); + await transportWithoutState.handleRequest(req, mockResponse); expect(mockResponse.writeHead).toHaveBeenCalled(); // Extract the headers from writeHead call @@ -167,13 +390,13 @@ describe("StreamableHTTPServerTransport", () => { expect(headers).not.toHaveProperty("mcp-session-id"); }); - it("should not validate session ID in stateless mode", async () => { + it("should not validate session ID in mode without state management", async () => { const req = createMockRequest({ method: "POST", headers: { "content-type": "application/json", "accept": "application/json, text/event-stream", - "mcp-session-id": "invalid-session-id", // This would cause a 404 in stateful mode + "mcp-session-id": "invalid-session-id", // This would cause a 404 in mode with state management }, body: JSON.stringify({ jsonrpc: "2.0", @@ -183,7 +406,7 @@ describe("StreamableHTTPServerTransport", () => { }), }); - await statelessTransport.handleRequest(req, mockResponse); + await transportWithoutState.handleRequest(req, mockResponse); // Should still get 200 OK, not 404 Not Found expect(mockResponse.writeHead).toHaveBeenCalledWith( @@ -194,7 +417,7 @@ describe("StreamableHTTPServerTransport", () => { ); }); - it("should handle POST requests without session validation in stateless mode", async () => { + it("should handle POST requests without session validation in mode without state management", async () => { const message: JSONRPCMessage = { jsonrpc: "2.0", method: "test", @@ -207,21 +430,21 @@ describe("StreamableHTTPServerTransport", () => { headers: { "content-type": "application/json", "accept": "application/json, text/event-stream", - "mcp-session-id": "non-existent-session-id", // This would be rejected in stateful mode + "mcp-session-id": "non-existent-session-id", // This would be rejected in mode with state management }, body: JSON.stringify(message), }); const onMessageMock = jest.fn(); - statelessTransport.onmessage = onMessageMock; + transportWithoutState.onmessage = onMessageMock; - await statelessTransport.handleRequest(req, mockResponse); + await transportWithoutState.handleRequest(req, mockResponse); // Message should be processed despite invalid session ID expect(onMessageMock).toHaveBeenCalledWith(message); }); - it("should work with a mix of requests with and without session IDs in stateless mode", async () => { + it("should work with a mix of requests with and without session IDs in mode without state management", async () => { // First request without session ID const req1 = createMockRequest({ method: "POST", @@ -237,7 +460,7 @@ describe("StreamableHTTPServerTransport", () => { }) }); - await statelessTransport.handleRequest(req1, mockResponse); + await transportWithoutState.handleRequest(req1, mockResponse); expect(mockResponse.writeHead).toHaveBeenCalledWith( 200, expect.objectContaining({ @@ -248,7 +471,7 @@ describe("StreamableHTTPServerTransport", () => { // Reset mock for second request mockResponse.writeHead.mockClear(); - // Second request with a session ID (which would be invalid in stateful mode) + // Second request with a session ID (which would be invalid in mode with state management) const req2 = createMockRequest({ method: "POST", headers: { @@ -264,7 +487,7 @@ describe("StreamableHTTPServerTransport", () => { }) }); - await statelessTransport.handleRequest(req2, mockResponse); + await transportWithoutState.handleRequest(req2, mockResponse); // Should still succeed expect(mockResponse.writeHead).toHaveBeenCalledWith( @@ -275,8 +498,10 @@ describe("StreamableHTTPServerTransport", () => { ); }); - it("should handle initialization requests properly in stateful mode", async () => { - // Initialize message that would typically be sent during initialization + it("should handle initialization in mode without state management", async () => { + const transportWithoutState = new StreamableHTTPServerTransport({ sessionIdGenerator: () => undefined }); + + // Initialize message const initializeMessage: JSONRPCMessage = { jsonrpc: "2.0", method: "initialize", @@ -287,8 +512,10 @@ describe("StreamableHTTPServerTransport", () => { id: "init-1", }; - // Test stateful transport (default) - const statefulReq = createMockRequest({ + expect(transportWithoutState.sessionId).toBeUndefined(); + expect(transportWithoutState["_initialized"]).toBe(false); + + const req = createMockRequest({ method: "POST", headers: { "content-type": "application/json", @@ -297,20 +524,24 @@ describe("StreamableHTTPServerTransport", () => { body: JSON.stringify(initializeMessage), }); - await transport.handleRequest(statefulReq, mockResponse); + const newResponse = createMockResponse(); + await transportWithoutState.handleRequest(req, newResponse); - // In stateful mode, session ID should be included in the response header - expect(mockResponse.writeHead).toHaveBeenCalledWith( - 200, - expect.objectContaining({ - "mcp-session-id": transport.sessionId, - }) - ); + // After initialization, the sessionId should still be undefined + expect(transportWithoutState.sessionId).toBeUndefined(); + expect(transportWithoutState["_initialized"]).toBe(true); + + // Headers should NOT include session ID in mode without state management + const headers = newResponse.writeHead.mock.calls[0][1]; + expect(headers).not.toHaveProperty("mcp-session-id"); }); + }); - it("should handle initialization requests properly in stateless mode", async () => { - // Initialize message that would typically be sent during initialization - const initializeMessage: JSONRPCMessage = { + describe("Request Handling", () => { + // Initialize the transport before tests that need initialization + beforeEach(async () => { + // For tests that need initialization, initialize here + const initMessage: JSONRPCMessage = { jsonrpc: "2.0", method: "initialize", params: { @@ -320,26 +551,19 @@ describe("StreamableHTTPServerTransport", () => { id: "init-1", }; - // Test stateless transport - const statelessReq = createMockRequest({ + const initReq = createMockRequest({ method: "POST", headers: { "content-type": "application/json", "accept": "application/json, text/event-stream", }, - body: JSON.stringify(initializeMessage), + body: JSON.stringify(initMessage), }); - await statelessTransport.handleRequest(statelessReq, mockResponse); - - // In stateless mode, session ID should also be included for initialize responses - const headers = mockResponse.writeHead.mock.calls[0][1]; - expect(headers).not.toHaveProperty("mcp-session-id"); - + await transport.handleRequest(initReq, mockResponse); + mockResponse.writeHead.mockClear(); }); - }); - describe("Request Handling", () => { it("should reject GET requests for SSE with 405 Method Not Allowed", async () => { const req = createMockRequest({ method: "GET", @@ -361,7 +585,7 @@ describe("StreamableHTTPServerTransport", () => { it("should reject POST requests without proper Accept header", async () => { const message: JSONRPCMessage = { jsonrpc: "2.0", - method: "initialize", // Use initialize to bypass session ID check + method: "test", params: {}, id: 1, }; @@ -370,6 +594,7 @@ describe("StreamableHTTPServerTransport", () => { method: "POST", headers: { "content-type": "application/json", + "mcp-session-id": transport.sessionId, }, body: JSON.stringify(message), }); @@ -382,7 +607,7 @@ describe("StreamableHTTPServerTransport", () => { it("should properly handle JSON-RPC request messages in POST requests", async () => { const message: JSONRPCMessage = { jsonrpc: "2.0", - method: "initialize", // Use initialize to bypass session ID check + method: "test", params: {}, id: 1, }; @@ -392,6 +617,7 @@ describe("StreamableHTTPServerTransport", () => { headers: { "content-type": "application/json", "accept": "application/json, text/event-stream", + "mcp-session-id": transport.sessionId, }, body: JSON.stringify(message), }); @@ -405,7 +631,7 @@ describe("StreamableHTTPServerTransport", () => { expect(mockResponse.writeHead).toHaveBeenCalledWith( 200, expect.objectContaining({ - "mcp-session-id": transport.sessionId, + "Content-Type": "text/event-stream", }) ); }); @@ -480,6 +706,7 @@ describe("StreamableHTTPServerTransport", () => { const onMessageMock = jest.fn(); transport.onmessage = onMessageMock; + mockResponse = createMockResponse(); // Create fresh mock await transport.handleRequest(req, mockResponse); // Should establish SSE connection @@ -512,6 +739,30 @@ describe("StreamableHTTPServerTransport", () => { }); it("should properly handle DELETE requests and close session", async () => { + // First initialize the transport + const initMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "initialize", + params: { + clientInfo: { name: "test-client", version: "1.0" }, + protocolVersion: "2025-03-26" + }, + id: "init-1", + }; + + const initReq = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + }, + body: JSON.stringify(initMessage), + }); + + await transport.handleRequest(initReq, mockResponse); + mockResponse.writeHead.mockClear(); + + // Now try DELETE with proper session ID const req = createMockRequest({ method: "DELETE", headers: { @@ -527,9 +778,75 @@ describe("StreamableHTTPServerTransport", () => { expect(mockResponse.writeHead).toHaveBeenCalledWith(200); expect(onCloseMock).toHaveBeenCalled(); }); + + it("should reject DELETE requests with invalid session ID", async () => { + // First initialize the transport + const initMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "initialize", + params: { + clientInfo: { name: "test-client", version: "1.0" }, + protocolVersion: "2025-03-26" + }, + id: "init-1", + }; + + const initReq = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + }, + body: JSON.stringify(initMessage), + }); + + await transport.handleRequest(initReq, mockResponse); + mockResponse.writeHead.mockClear(); + + // Now try DELETE with invalid session ID + const req = createMockRequest({ + method: "DELETE", + headers: { + "mcp-session-id": "invalid-session-id", + }, + }); + + const onCloseMock = jest.fn(); + transport.onclose = onCloseMock; + + await transport.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith(404); + expect(onCloseMock).not.toHaveBeenCalled(); + }); }); describe("SSE Response Handling", () => { + beforeEach(async () => { + // Initialize the transport + const initMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "initialize", + params: { + clientInfo: { name: "test-client", version: "1.0" }, + protocolVersion: "2025-03-26" + }, + id: "init-1", + }; + + const initReq = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + }, + body: JSON.stringify(initMessage), + }); + + await transport.handleRequest(initReq, mockResponse); + mockResponse.writeHead.mockClear(); + }); + it("should send response messages as SSE events", async () => { // Setup a POST request with JSON-RPC request that accepts SSE const requestMessage: JSONRPCMessage = { @@ -578,6 +895,9 @@ describe("StreamableHTTPServerTransport", () => { id: "test-req-id" }; + // Create fresh response for this test + mockResponse = createMockResponse(); + const req = createMockRequest({ method: "POST", headers: { @@ -617,6 +937,31 @@ describe("StreamableHTTPServerTransport", () => { }); describe("Message Targeting", () => { + beforeEach(async () => { + // Initialize the transport + const initMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "initialize", + params: { + clientInfo: { name: "test-client", version: "1.0" }, + protocolVersion: "2025-03-26" + }, + id: "init-1", + }; + + const initReq = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + }, + body: JSON.stringify(initMessage), + }); + + await transport.handleRequest(initReq, mockResponse); + mockResponse.writeHead.mockClear(); + }); + it("should send response messages to the connection that sent the request", async () => { // Create request with two separate connections const requestMessage1: JSONRPCMessage = { @@ -697,7 +1042,7 @@ describe("StreamableHTTPServerTransport", () => { }); describe("Error Handling", () => { - it("should handle invalid JSON data", async () => { + it("should return 400 error for invalid JSON data", async () => { const req = createMockRequest({ method: "POST", headers: { @@ -718,7 +1063,7 @@ describe("StreamableHTTPServerTransport", () => { expect(onErrorMock).toHaveBeenCalled(); }); - it("should handle invalid JSON-RPC messages", async () => { + it("should return 400 error for invalid JSON-RPC messages", async () => { const req = createMockRequest({ method: "POST", headers: { @@ -740,14 +1085,36 @@ describe("StreamableHTTPServerTransport", () => { }); describe("Handling Pre-Parsed Body", () => { - it("should accept pre-parsed request body", async () => { - const message: JSONRPCMessage = { + beforeEach(async () => { + // Initialize the transport + const initMessage: JSONRPCMessage = { jsonrpc: "2.0", method: "initialize", params: { clientInfo: { name: "test-client", version: "1.0" }, protocolVersion: "2025-03-26" }, + id: "init-1", + }; + + const initReq = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + }, + body: JSON.stringify(initMessage), + }); + + await transport.handleRequest(initReq, mockResponse); + mockResponse.writeHead.mockClear(); + }); + + it("should accept pre-parsed request body", async () => { + const message: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test", + params: {}, id: "pre-parsed-test", }; @@ -757,6 +1124,7 @@ describe("StreamableHTTPServerTransport", () => { headers: { "content-type": "application/json", "accept": "application/json, text/event-stream", + "mcp-session-id": transport.sessionId, }, // No body provided here - it will be passed as parsedBody }); @@ -772,7 +1140,7 @@ describe("StreamableHTTPServerTransport", () => { expect(mockResponse.writeHead).toHaveBeenCalledWith( 200, expect.objectContaining({ - "mcp-session-id": transport.sessionId, + "Content-Type": "text/event-stream", }) ); }); diff --git a/src/server/streamableHttp.ts b/src/server/streamableHttp.ts index 8ad821de..34b4fd95 100644 --- a/src/server/streamableHttp.ts +++ b/src/server/streamableHttp.ts @@ -11,11 +11,12 @@ const MAXIMUM_MESSAGE_SIZE = "4mb"; */ export interface StreamableHTTPServerTransportOptions { /** + * Function that generates a session ID for the transport. * The session ID SHOULD be globally unique and cryptographically secure (e.g., a securely generated UUID, a JWT, or a cryptographic hash) * - * When there is no sessionId, the transport will not perform session management. + * Return undefined to disable session management. */ - sessionId: string | undefined; + sessionIdGenerator: () => string | undefined; @@ -57,16 +58,18 @@ export interface StreamableHTTPServerTransportOptions { */ export class StreamableHTTPServerTransport implements Transport { // when sessionId is not set (undefined), it means the transport is in stateless mode - private _sessionId: string | undefined; + private sessionIdGenerator: () => string | undefined; private _started: boolean = false; private _sseResponseMapping: Map = new Map(); + private _initialized: boolean = false; + sessionId?: string | undefined; onclose?: () => void; onerror?: (error: Error) => void; onmessage?: (message: JSONRPCMessage) => void; constructor(options: StreamableHTTPServerTransportOptions) { - this._sessionId = options.sessionId; + this.sessionIdGenerator = options.sessionIdGenerator; } /** @@ -170,6 +173,17 @@ export class StreamableHTTPServerTransport implements Transport { msg => 'method' in msg && msg.method === 'initialize' ); if (isInitializationRequest) { + if (this._initialized) { + res.writeHead(400).end(JSON.stringify({ + jsonrpc: "2.0", + error: { + code: -32600, + message: "Invalid Request: Server already initialized" + }, + id: null + })); + return; + } if (messages.length > 1) { res.writeHead(400).end(JSON.stringify({ jsonrpc: "2.0", @@ -181,10 +195,12 @@ export class StreamableHTTPServerTransport implements Transport { })); return; } + this.sessionId = this.sessionIdGenerator(); + this._initialized = true; const headers: Record = {}; - if (this._sessionId !== undefined) { - headers["mcp-session-id"] = this._sessionId; + if (this.sessionId !== undefined) { + headers["mcp-session-id"] = this.sessionId; } // Process initialization messages before responding @@ -198,7 +214,7 @@ export class StreamableHTTPServerTransport implements Transport { // If an Mcp-Session-Id is returned by the server during initialization, // clients using the Streamable HTTP transport MUST include it // in the Mcp-Session-Id header on all of their subsequent HTTP requests. - if (this._sessionId !== undefined && !isInitializationRequest && !this.validateSession(req, res)) { + if (!isInitializationRequest && !this.validateSession(req, res)) { return; } @@ -224,8 +240,8 @@ export class StreamableHTTPServerTransport implements Transport { }; // After initialization, always include the session ID if we have one - if (this._sessionId !== undefined) { - headers["mcp-session-id"] = this._sessionId; + if (this.sessionId !== undefined) { + headers["mcp-session-id"] = this.sessionId; } res.writeHead(200, headers); @@ -272,10 +288,27 @@ export class StreamableHTTPServerTransport implements Transport { } /** - * Validates session ID for non-initialization requests when session management is enabled + * Validates session ID for non-initialization requests * Returns true if the session is valid, false otherwise */ private validateSession(req: IncomingMessage, res: ServerResponse): boolean { + if (!this._initialized) { + // If the server has not been initialized yet, reject all requests + res.writeHead(400).end(JSON.stringify({ + jsonrpc: "2.0", + error: { + code: -32000, + message: "Bad Request: Server not initialized" + }, + id: null + })); + return false; + } + if (this.sessionId === undefined) { + // If the session ID is not set, the session management is disabled + // and we don't need to validate the session ID + return true; + } const sessionId = req.headers["mcp-session-id"]; if (!sessionId) { @@ -300,7 +333,7 @@ export class StreamableHTTPServerTransport implements Transport { })); return false; } - else if (sessionId !== this._sessionId) { + else if (sessionId !== this.sessionId) { // Reject requests with invalid session ID with 404 Not Found res.writeHead(404).end(JSON.stringify({ jsonrpc: "2.0", @@ -361,10 +394,4 @@ export class StreamableHTTPServerTransport implements Transport { } } - /** - * Returns the session ID for this transport - */ - get sessionId(): string | undefined { - return this._sessionId; - } } \ No newline at end of file