From d9dea6c979c42909e5195c1f0bbb3250efcf95a5 Mon Sep 17 00:00:00 2001 From: Geng Yan Date: Fri, 28 Mar 2025 09:00:07 +0800 Subject: [PATCH 1/5] =?UTF-8?q?=E2=9C=A8=20feat:=20StreamableHTTPServerTra?= =?UTF-8?q?nsport=20implement?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/server/streamable-http.test.ts | 427 +++++++++++++++++++++++++++++ src/server/streamable-http.ts | 388 ++++++++++++++++++++++++++ 2 files changed, 815 insertions(+) create mode 100644 src/server/streamable-http.test.ts create mode 100644 src/server/streamable-http.ts diff --git a/src/server/streamable-http.test.ts b/src/server/streamable-http.test.ts new file mode 100644 index 00000000..8dd5dea7 --- /dev/null +++ b/src/server/streamable-http.test.ts @@ -0,0 +1,427 @@ +import { IncomingMessage, ServerResponse } from "node:http"; +import { StreamableHTTPServerTransport } from "./streamable-http.js"; +import { JSONRPCMessage } from "../types.js"; +import { Readable } from "node:stream"; + +// Mock IncomingMessage +function createMockRequest(options: { + method: string; + headers: Record; + body?: string; +}): IncomingMessage { + const readable = new Readable(); + readable._read = () => {}; + if (options.body) { + readable.push(options.body); + readable.push(null); + } + + return Object.assign(readable, { + method: options.method, + headers: options.headers, + }) as IncomingMessage; +} + +// Mock ServerResponse +function createMockResponse(): jest.Mocked { + const response = { + writeHead: jest.fn().mockReturnThis(), + write: jest.fn().mockReturnThis(), + end: jest.fn().mockReturnThis(), + on: jest.fn().mockReturnThis(), + emit: jest.fn().mockReturnThis(), + getHeader: jest.fn(), + setHeader: jest.fn(), + } as unknown as jest.Mocked; + return response; +} + +describe("StreamableHTTPServerTransport", () => { + const endpoint = "/mcp"; + let transport: StreamableHTTPServerTransport; + let mockResponse: jest.Mocked; + + beforeEach(() => { + transport = new StreamableHTTPServerTransport(endpoint); + mockResponse = createMockResponse(); + }); + + afterEach(() => { + jest.clearAllMocks(); + }); + + describe("Session Management", () => { + it("should generate a valid session ID", () => { + expect(transport.sessionId).toBeTruthy(); + expect(typeof transport.sessionId).toBe("string"); + }); + + it("should include session ID in response headers", async () => { + const req = createMockRequest({ + method: "GET", + headers: { + accept: "text/event-stream" + }, + }); + + await transport.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith( + 200, + expect.objectContaining({ + "Mcp-Session-Id": transport.sessionId, + }) + ); + }); + + it("should reject invalid session ID", async () => { + const req = createMockRequest({ + method: "GET", + headers: { + "mcp-session-id": "invalid-session-id", + }, + }); + + await transport.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith(404); + // check if the error response is a valid JSON-RPC error format + expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"jsonrpc":"2.0"')); + expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"error"')); + }); + }); + + describe("Request Handling", () => { + it("should reject GET requests without Accept: text/event-stream header", async () => { + const req = createMockRequest({ + method: "GET", + headers: {}, + }); + + await transport.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith(406); + expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"jsonrpc":"2.0"')); + }); + + it("should properly handle GET requests with Accept header and establish SSE connection", async () => { + const req = createMockRequest({ + method: "GET", + headers: { + accept: "text/event-stream", + }, + }); + + await transport.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith( + 200, + expect.objectContaining({ + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + Connection: "keep-alive", + }) + ); + }); + + it("should reject POST requests without proper Accept header", async () => { + const message: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test", + params: {}, + id: 1, + }; + + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + }, + body: JSON.stringify(message), + }); + + await transport.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith(406); + }); + + it("should properly handle JSON-RPC request messages in POST requests", async () => { + const message: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test", + params: {}, + id: 1, + }; + + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "text/event-stream", + }, + body: JSON.stringify(message), + }); + + const onMessageMock = jest.fn(); + transport.onmessage = onMessageMock; + + await transport.handleRequest(req, mockResponse); + + expect(onMessageMock).toHaveBeenCalledWith(message); + expect(mockResponse.writeHead).toHaveBeenCalledWith( + 200, + expect.objectContaining({ + "Content-Type": "text/event-stream", + }) + ); + }); + + it("should properly handle JSON-RPC notification or response messages in POST requests", async () => { + const notification: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test", + params: {}, + }; + + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + }, + body: JSON.stringify(notification), + }); + + const onMessageMock = jest.fn(); + transport.onmessage = onMessageMock; + + await transport.handleRequest(req, mockResponse); + + expect(onMessageMock).toHaveBeenCalledWith(notification); + expect(mockResponse.writeHead).toHaveBeenCalledWith(202); + }); + + it("should handle batch messages properly", async () => { + const batchMessages: JSONRPCMessage[] = [ + { jsonrpc: "2.0", method: "test1", params: {} }, + { jsonrpc: "2.0", method: "test2", params: {} }, + ]; + + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json", + }, + body: JSON.stringify(batchMessages), + }); + + const onMessageMock = jest.fn(); + transport.onmessage = onMessageMock; + + await transport.handleRequest(req, mockResponse); + + expect(onMessageMock).toHaveBeenCalledTimes(2); + expect(mockResponse.writeHead).toHaveBeenCalledWith(202); + }); + + it("should reject unsupported Content-Type", async () => { + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "text/plain", + "accept": "application/json", + }, + body: "test", + }); + + await transport.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith(415); + expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"jsonrpc":"2.0"')); + }); + + it("should properly handle DELETE requests and close session", async () => { + const req = createMockRequest({ + method: "DELETE", + headers: {}, + }); + + const onCloseMock = jest.fn(); + transport.onclose = onCloseMock; + + await transport.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith(200); + expect(onCloseMock).toHaveBeenCalled(); + }); + }); + + describe("Message Replay", () => { + it("should replay messages after specified Last-Event-ID", async () => { + // Establish first connection with Accept header + const req1 = createMockRequest({ + method: "GET", + headers: { + "accept": "text/event-stream" + }, + }); + await transport.handleRequest(req1, mockResponse); + + // Send a message to first connection + const message1: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test1", + params: {}, + id: 1 + }; + + await transport.send(message1); + + // Get message ID (captured from write call) + const writeCall = mockResponse.write.mock.calls[0][0] as string; + const idMatch = writeCall.match(/id: ([a-f0-9-]+)/); + if (!idMatch) { + throw new Error("Message ID not found in write call"); + } + const lastEventId = idMatch[1]; + + // Create a second connection with last-event-id + const mockResponse2 = createMockResponse(); + const req2 = createMockRequest({ + method: "GET", + headers: { + "accept": "text/event-stream", + "last-event-id": lastEventId, + }, + }); + + await transport.handleRequest(req2, mockResponse2); + + // Send a second message + const message2: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test2", + params: {}, + id: 2 + }; + + await transport.send(message2); + + // Verify the second message was received by both connections + expect(mockResponse.write).toHaveBeenCalledWith( + expect.stringContaining(JSON.stringify(message1)) + ); + expect(mockResponse2.write).toHaveBeenCalledWith( + expect.stringContaining(JSON.stringify(message2)) + ); + }); + }); + + describe("Message Targeting", () => { + it("should send response messages to the connection that sent the request", async () => { + // Create two connections + const mockResponse1 = createMockResponse(); + const req1 = createMockRequest({ + method: "GET", + headers: { + "accept": "text/event-stream", + }, + }); + await transport.handleRequest(req1, mockResponse1); + + const mockResponse2 = createMockResponse(); + const req2 = createMockRequest({ + method: "GET", + headers: { + "accept": "text/event-stream", + }, + }); + await transport.handleRequest(req2, mockResponse2); + + // Send a request through the first connection + const requestMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test", + params: {}, + id: "test-id", + }; + + const reqPost = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "text/event-stream", + }, + body: JSON.stringify(requestMessage), + }); + + await transport.handleRequest(reqPost, mockResponse1); + + // Send a response with matching ID + const responseMessage: JSONRPCMessage = { + jsonrpc: "2.0", + result: { success: true }, + id: "test-id", + }; + + await transport.send(responseMessage); + + // Verify response was sent to the right connection + expect(mockResponse1.write).toHaveBeenCalledWith( + expect.stringContaining(JSON.stringify(responseMessage)) + ); + + // Check if write was called with this exact message on the second connection + const writeCallsOnSecondConn = mockResponse2.write.mock.calls.filter(call => + typeof call[0] === 'string' && call[0].includes(JSON.stringify(responseMessage)) + ); + + // Verify the response wasn't broadcast to all connections + expect(writeCallsOnSecondConn.length).toBe(0); + }); + }); + + describe("Error Handling", () => { + it("should handle invalid JSON data", async () => { + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json", + }, + body: "invalid json", + }); + + const onErrorMock = jest.fn(); + transport.onerror = onErrorMock; + + await transport.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith(400); + expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"jsonrpc":"2.0"')); + expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"code":-32700')); + expect(onErrorMock).toHaveBeenCalled(); + }); + + it("should handle invalid JSON-RPC messages", async () => { + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json", + }, + body: JSON.stringify({ invalid: "message" }), + }); + + const onErrorMock = jest.fn(); + transport.onerror = onErrorMock; + + await transport.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith(400); + expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"jsonrpc":"2.0"')); + expect(onErrorMock).toHaveBeenCalled(); + }); + }); +}); \ No newline at end of file diff --git a/src/server/streamable-http.ts b/src/server/streamable-http.ts new file mode 100644 index 00000000..2567240c --- /dev/null +++ b/src/server/streamable-http.ts @@ -0,0 +1,388 @@ +import { randomUUID } from "node:crypto"; +import { IncomingMessage, ServerResponse } from "node:http"; +import { Transport } from "../shared/transport.js"; +import { JSONRPCMessage, JSONRPCMessageSchema } from "../types.js"; +import getRawBody from "raw-body"; +import contentType from "content-type"; + +const MAXIMUM_MESSAGE_SIZE = "4mb"; + +interface StreamConnection { + response: ServerResponse; + lastEventId?: string; + messages: Array<{ + id: string; + message: JSONRPCMessage; + }>; + // mark this connection as a response to a specific request + requestId?: string | null; +} + +/** + * Server transport for Streamable HTTP: this implements the MCP Streamable HTTP transport specification. + * It supports both SSE streaming and direct HTTP responses, with session management and message resumability. + */ +export class StreamableHTTPServerTransport implements Transport { + private _connections: Map = new Map(); + private _sessionId: string; + private _messageHistory: Map = new Map(); + private _started: boolean = false; + private _requestConnections: Map = new Map(); // request ID to connection ID mapping + + onclose?: () => void; + onerror?: (error: Error) => void; + onmessage?: (message: JSONRPCMessage) => void; + + constructor(private _endpoint: string) { + this._sessionId = randomUUID(); + } + + /** + * Starts the transport. This is required by the Transport interface but is a no-op + * for the Streamable HTTP transport as connections are managed per-request. + */ + async start(): Promise { + if (this._started) { + throw new Error("Transport already started"); + } + this._started = true; + } + + /** + * Handles an incoming HTTP request, whether GET or POST + */ + async handleRequest(req: IncomingMessage, res: ServerResponse): Promise { + // validate the session ID + const sessionId = req.headers["mcp-session-id"]; + if (sessionId && (Array.isArray(sessionId) ? sessionId[0] : sessionId) !== this._sessionId) { + res.writeHead(404).end(JSON.stringify({ + jsonrpc: "2.0", + error: { + code: -32001, + message: "Session not found" + }, + id: null + })); + return; + } + + if (req.method === "GET") { + await this.handleGetRequest(req, res); + } else if (req.method === "POST") { + await this.handlePostRequest(req, res); + } else if (req.method === "DELETE") { + await this.handleDeleteRequest(req, res); + } else { + res.writeHead(405).end(JSON.stringify({ + jsonrpc: "2.0", + error: { + code: -32000, + message: "Method not allowed" + }, + id: null + })); + } + } + + /** + * Handles GET requests to establish SSE connections + */ + private async handleGetRequest(req: IncomingMessage, res: ServerResponse): Promise { + // validate the Accept header + const acceptHeader = req.headers.accept; + if (!acceptHeader || !acceptHeader.includes("text/event-stream")) { + res.writeHead(406).end(JSON.stringify({ + jsonrpc: "2.0", + error: { + code: -32000, + message: "Not Acceptable: Client must accept text/event-stream" + }, + id: null + })); + return; + } + + const connectionId = randomUUID(); + const lastEventId = req.headers["last-event-id"]; + const lastEventIdStr = Array.isArray(lastEventId) ? lastEventId[0] : lastEventId; + + res.writeHead(200, { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + Connection: "keep-alive", + "Mcp-Session-Id": this._sessionId, + }); + + const connection: StreamConnection = { + response: res, + lastEventId: lastEventIdStr, + messages: [], + }; + + this._connections.set(connectionId, connection); + + // if there is a Last-Event-ID, replay messages on this connection + if (lastEventIdStr) { + this.replayMessages(connectionId, lastEventIdStr); + } + + res.on("close", () => { + this._connections.delete(connectionId); + // remove all request mappings associated with this connection + for (const [reqId, connId] of this._requestConnections.entries()) { + if (connId === connectionId) { + this._requestConnections.delete(reqId); + } + } + if (this._connections.size === 0) { + this.onclose?.(); + } + }); + } + + /** + * Handles POST requests containing JSON-RPC messages + */ + private async handlePostRequest(req: IncomingMessage, res: ServerResponse): Promise { + try { + // validate the Accept header + const acceptHeader = req.headers.accept; + if (!acceptHeader || + (!acceptHeader.includes("application/json") && !acceptHeader.includes("text/event-stream"))) { + res.writeHead(406).end(JSON.stringify({ + jsonrpc: "2.0", + error: { + code: -32000, + message: "Not Acceptable: Client must accept application/json and/or text/event-stream" + }, + id: null + })); + return; + } + + const ct = req.headers["content-type"]; + if (!ct || !ct.includes("application/json")) { + res.writeHead(415).end(JSON.stringify({ + jsonrpc: "2.0", + error: { + code: -32000, + message: "Unsupported Media Type: Content-Type must be application/json" + }, + id: null + })); + return; + } + + const parsedCt = contentType.parse(ct); + const body = await getRawBody(req, { + limit: MAXIMUM_MESSAGE_SIZE, + encoding: parsedCt.parameters.charset ?? "utf-8", + }); + + const rawMessage = JSON.parse(body.toString()); + let messages: JSONRPCMessage[]; + + // handle batch and single messages + if (Array.isArray(rawMessage)) { + messages = rawMessage.map(msg => JSONRPCMessageSchema.parse(msg)); + } else { + messages = [JSONRPCMessageSchema.parse(rawMessage)]; + } + + // check if it contains requests + const hasRequests = messages.some(msg => 'method' in msg && 'id' in msg); + const hasOnlyNotificationsOrResponses = messages.every(msg => + ('method' in msg && !('id' in msg)) || ('result' in msg || 'error' in msg)); + + if (hasOnlyNotificationsOrResponses) { + // if it only contains notifications or responses, return 202 + res.writeHead(202).end(); + + // handle each message + for (const message of messages) { + this.onmessage?.(message); + } + } else if (hasRequests) { + // if it contains requests, you can choose to return an SSE stream or a JSON response + const useSSE = acceptHeader.includes("text/event-stream"); + + if (useSSE) { + res.writeHead(200, { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + Connection: "keep-alive", + "Mcp-Session-Id": this._sessionId, + }); + + const connectionId = randomUUID(); + const connection: StreamConnection = { + response: res, + messages: [], + }; + + this._connections.set(connectionId, connection); + + // map each request to a connection ID + for (const message of messages) { + if ('method' in message && 'id' in message) { + this._requestConnections.set(String(message.id), connectionId); + } + this.onmessage?.(message); + } + + res.on("close", () => { + this._connections.delete(connectionId); + // remove all request mappings associated with this connection + for (const [reqId, connId] of this._requestConnections.entries()) { + if (connId === connectionId) { + this._requestConnections.delete(reqId); + } + } + if (this._connections.size === 0) { + this.onclose?.(); + } + }); + } else { + // use direct JSON response + res.writeHead(200, { + "Content-Type": "application/json", + "Mcp-Session-Id": this._sessionId, + }); + + // handle each message + for (const message of messages) { + this.onmessage?.(message); + } + + res.end(); + } + } + } catch (error) { + // return JSON-RPC formatted error + res.writeHead(400).end(JSON.stringify({ + jsonrpc: "2.0", + error: { + code: -32700, + message: "Parse error", + data: String(error) + }, + id: null + })); + this.onerror?.(error as Error); + } + } + + /** + * Handles DELETE requests to terminate sessions + */ + private async handleDeleteRequest(req: IncomingMessage, res: ServerResponse): Promise { + await this.close(); + res.writeHead(200).end(); + } + + /** + * Replays messages after the specified event ID for a specific connection + */ + private replayMessages(connectionId: string, lastEventId: string): void { + if (!lastEventId) return; + + // only replay messages that should be sent on this connection + const messages = Array.from(this._messageHistory.entries()) + .filter(([id, { connectionId: msgConnId }]) => + id > lastEventId && + (!msgConnId || msgConnId === connectionId)) // only replay messages that are not specified to a connection or specified to the current connection + .sort(([a], [b]) => a.localeCompare(b)); + + const connection = this._connections.get(connectionId); + if (!connection) return; + + for (const [id, { message }] of messages) { + connection.response.write( + `id: ${id}\nevent: message\ndata: ${JSON.stringify(message)}\n\n` + ); + } + } + + async close(): Promise { + for (const connection of this._connections.values()) { + connection.response.end(); + } + this._connections.clear(); + this._messageHistory.clear(); + this._requestConnections.clear(); + this.onclose?.(); + } + + async send(message: JSONRPCMessage): Promise { + if (this._connections.size === 0) { + throw new Error("No active connections"); + } + + let targetConnectionId = ""; + + // if it is a response, find the corresponding request connection + if ('id' in message && ('result' in message || 'error' in message)) { + const connId = this._requestConnections.get(String(message.id)); + + // if the corresponding connection is not found, the connection may be disconnected + if (!connId || !this._connections.has(connId)) { + // select an available connection + const firstConnId = this._connections.keys().next().value; + if (firstConnId) { + targetConnectionId = firstConnId; + } else { + throw new Error("No available connections"); + } + } else { + targetConnectionId = connId; + } + } else { + // for other messages, select an available connection + const firstConnId = this._connections.keys().next().value; + if (firstConnId) { + targetConnectionId = firstConnId; + } else { + throw new Error("No available connections"); + } + } + + const messageId = randomUUID(); + this._messageHistory.set(messageId, { + message, + connectionId: targetConnectionId + }); + + // keep the message history in a reasonable range + if (this._messageHistory.size > 1000) { + const oldestKey = Array.from(this._messageHistory.keys())[0]; + this._messageHistory.delete(oldestKey); + } + + // send the message to all active connections + for (const [connId, connection] of this._connections.entries()) { + // if it is a response message, only send to the target connection + if ('id' in message && ('result' in message || 'error' in message)) { + if (connId === targetConnectionId) { + connection.response.write( + `id: ${messageId}\nevent: message\ndata: ${JSON.stringify(message)}\n\n` + ); + } + } else { + // for other messages, send to all connections + connection.response.write( + `id: ${messageId}\nevent: message\ndata: ${JSON.stringify(message)}\n\n` + ); + } + } + } + + /** + * Returns the session ID for this transport + */ + get sessionId(): string { + return this._sessionId; + } +} \ No newline at end of file From 9eea6b225429938b802a5c53bf544001652a6f89 Mon Sep 17 00:00:00 2001 From: Geng Yan Date: Tue, 1 Apr 2025 10:58:37 +0800 Subject: [PATCH 2/5] fix: Use lowercase header names --- src/server/streamable-http.test.ts | 2 +- src/server/streamable-http.ts | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/server/streamable-http.test.ts b/src/server/streamable-http.test.ts index 8dd5dea7..fa5140c0 100644 --- a/src/server/streamable-http.test.ts +++ b/src/server/streamable-http.test.ts @@ -69,7 +69,7 @@ describe("StreamableHTTPServerTransport", () => { expect(mockResponse.writeHead).toHaveBeenCalledWith( 200, expect.objectContaining({ - "Mcp-Session-Id": transport.sessionId, + "mcp-session-id": transport.sessionId, }) ); }); diff --git a/src/server/streamable-http.ts b/src/server/streamable-http.ts index 2567240c..bf0096d1 100644 --- a/src/server/streamable-http.ts +++ b/src/server/streamable-http.ts @@ -113,7 +113,7 @@ export class StreamableHTTPServerTransport implements Transport { "Content-Type": "text/event-stream", "Cache-Control": "no-cache", Connection: "keep-alive", - "Mcp-Session-Id": this._sessionId, + "mcp-session-id": this._sessionId, }); const connection: StreamConnection = { @@ -214,7 +214,7 @@ export class StreamableHTTPServerTransport implements Transport { "Content-Type": "text/event-stream", "Cache-Control": "no-cache", Connection: "keep-alive", - "Mcp-Session-Id": this._sessionId, + "mcp-session-id": this._sessionId, }); const connectionId = randomUUID(); @@ -249,7 +249,7 @@ export class StreamableHTTPServerTransport implements Transport { // use direct JSON response res.writeHead(200, { "Content-Type": "application/json", - "Mcp-Session-Id": this._sessionId, + "mcp-session-id": this._sessionId, }); // handle each message From ed66f8a5d3196ae38965656601106c4aa7f22643 Mon Sep 17 00:00:00 2001 From: Geng Yan Date: Wed, 2 Apr 2025 08:47:35 +0800 Subject: [PATCH 3/5] =?UTF-8?q?=E2=9C=A8=20feat:=20try=20add=20stateless?= =?UTF-8?q?=20mod?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/server/streamable-http.test.ts | 268 ++++++++++++++++++++++++++++- src/server/streamable-http.ts | 131 +++++++++++--- 2 files changed, 368 insertions(+), 31 deletions(-) diff --git a/src/server/streamable-http.test.ts b/src/server/streamable-http.test.ts index fa5140c0..960d3800 100644 --- a/src/server/streamable-http.test.ts +++ b/src/server/streamable-http.test.ts @@ -57,11 +57,24 @@ describe("StreamableHTTPServerTransport", () => { }); it("should include session ID in response headers", async () => { + // Use POST with initialize method to avoid session ID requirement + const initializeMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "initialize", + params: { + clientInfo: { name: "test-client", version: "1.0" }, + protocolVersion: "2025-03-26" + }, + id: "init-1", + }; + const req = createMockRequest({ - method: "GET", + method: "POST", headers: { - accept: "text/event-stream" + "content-type": "application/json", + "accept": "application/json", }, + body: JSON.stringify(initializeMessage), }); await transport.handleRequest(req, mockResponse); @@ -79,6 +92,7 @@ describe("StreamableHTTPServerTransport", () => { method: "GET", headers: { "mcp-session-id": "invalid-session-id", + "accept": "text/event-stream" }, }); @@ -89,13 +103,241 @@ describe("StreamableHTTPServerTransport", () => { expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"jsonrpc":"2.0"')); expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"error"')); }); + + it("should reject non-initialization requests without session ID with 400 Bad Request", async () => { + const req = createMockRequest({ + method: "GET", + headers: { + accept: "text/event-stream", + // No mcp-session-id header + }, + }); + + await transport.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith(400); + expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"jsonrpc":"2.0"')); + expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"message":"Bad Request: Mcp-Session-Id header is required"')); + }); + + it("should always include session ID in initialization response even in stateless mode", async () => { + // Create a stateless transport for this test + const statelessTransport = new StreamableHTTPServerTransport(endpoint, { enableSessionManagement: false }); + + // Create an initialization request + const initializeMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "initialize", + params: { + clientInfo: { name: "test-client", version: "1.0" }, + protocolVersion: "2025-03-26" + }, + id: "init-1", + }; + + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json", + }, + body: JSON.stringify(initializeMessage), + }); + + await statelessTransport.handleRequest(req, mockResponse); + + // In stateless mode, session ID should also be included for initialize responses + expect(mockResponse.writeHead).toHaveBeenCalledWith( + 200, + expect.objectContaining({ + "mcp-session-id": statelessTransport.sessionId, + }) + ); + }); + }); + + describe("Stateless Mode", () => { + let statelessTransport: StreamableHTTPServerTransport; + let mockResponse: jest.Mocked; + + beforeEach(() => { + statelessTransport = new StreamableHTTPServerTransport(endpoint, { enableSessionManagement: false }); + mockResponse = createMockResponse(); + }); + + it("should not include session ID in response headers when in stateless mode", async () => { + // Use a non-initialization request + const message: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test", + params: {}, + id: 1, + }; + + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json", + }, + body: JSON.stringify(message), + }); + + await statelessTransport.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalled(); + // Extract the headers from writeHead call + const headers = mockResponse.writeHead.mock.calls[0][1]; + expect(headers).not.toHaveProperty("mcp-session-id"); + }); + + it("should not validate session ID in stateless mode", async () => { + const req = createMockRequest({ + method: "GET", + headers: { + accept: "text/event-stream", + "mcp-session-id": "invalid-session-id", // This would cause a 404 in stateful mode + }, + }); + + await statelessTransport.handleRequest(req, mockResponse); + + // Should still get 200 OK, not 404 Not Found + expect(mockResponse.writeHead).toHaveBeenCalledWith( + 200, + expect.not.objectContaining({ + "mcp-session-id": expect.anything(), + }) + ); + }); + + it("should handle POST requests without session validation in stateless mode", async () => { + const message: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test", + params: {}, + id: 1, + }; + + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json", + "mcp-session-id": "non-existent-session-id", // This would be rejected in stateful mode + }, + body: JSON.stringify(message), + }); + + const onMessageMock = jest.fn(); + statelessTransport.onmessage = onMessageMock; + + await statelessTransport.handleRequest(req, mockResponse); + + // Message should be processed despite invalid session ID + expect(onMessageMock).toHaveBeenCalledWith(message); + }); + + it("should work with a mix of requests with and without session IDs in stateless mode", async () => { + // First request without session ID + const req1 = createMockRequest({ + method: "GET", + headers: { + accept: "text/event-stream", + }, + }); + + await statelessTransport.handleRequest(req1, mockResponse); + expect(mockResponse.writeHead).toHaveBeenCalledWith( + 200, + expect.objectContaining({ + "Content-Type": "text/event-stream", + }) + ); + + // Reset mock for second request + mockResponse.writeHead.mockClear(); + + // Second request with a session ID (which would be invalid in stateful mode) + const req2 = createMockRequest({ + method: "GET", + headers: { + accept: "text/event-stream", + "mcp-session-id": "some-random-session-id", + }, + }); + + await statelessTransport.handleRequest(req2, mockResponse); + + // Should still succeed + expect(mockResponse.writeHead).toHaveBeenCalledWith( + 200, + expect.objectContaining({ + "Content-Type": "text/event-stream", + }) + ); + }); + + it("should handle initialization requests properly in both modes", async () => { + // Initialize message that would typically be sent during initialization + const initializeMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "initialize", + params: { + clientInfo: { name: "test-client", version: "1.0" }, + protocolVersion: "2025-03-26" + }, + id: "init-1", + }; + + // Test stateful transport (default) + const statefulReq = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json", + }, + body: JSON.stringify(initializeMessage), + }); + + await transport.handleRequest(statefulReq, mockResponse); + + // In stateful mode, session ID should be included in the response header + expect(mockResponse.writeHead).toHaveBeenCalledWith( + 200, + expect.objectContaining({ + "mcp-session-id": transport.sessionId, + }) + ); + + // Reset mocks for stateless test + mockResponse.writeHead.mockClear(); + + // Test stateless transport + const statelessReq = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json", + }, + body: JSON.stringify(initializeMessage), + }); + + await statelessTransport.handleRequest(statelessReq, mockResponse); + + // In stateless mode, session ID should also be included for initialize responses + const headers = mockResponse.writeHead.mock.calls[0][1]; + expect(headers).toHaveProperty("mcp-session-id", statelessTransport.sessionId); + }); }); describe("Request Handling", () => { it("should reject GET requests without Accept: text/event-stream header", async () => { const req = createMockRequest({ method: "GET", - headers: {}, + headers: { + "mcp-session-id": transport.sessionId, + }, }); await transport.handleRequest(req, mockResponse); @@ -108,7 +350,8 @@ describe("StreamableHTTPServerTransport", () => { const req = createMockRequest({ method: "GET", headers: { - accept: "text/event-stream", + "accept": "text/event-stream", + "mcp-session-id": transport.sessionId, }, }); @@ -127,7 +370,7 @@ describe("StreamableHTTPServerTransport", () => { it("should reject POST requests without proper Accept header", async () => { const message: JSONRPCMessage = { jsonrpc: "2.0", - method: "test", + method: "initialize", // Use initialize to bypass session ID check params: {}, id: 1, }; @@ -148,7 +391,7 @@ describe("StreamableHTTPServerTransport", () => { it("should properly handle JSON-RPC request messages in POST requests", async () => { const message: JSONRPCMessage = { jsonrpc: "2.0", - method: "test", + method: "initialize", // Use initialize to bypass session ID check params: {}, id: 1, }; @@ -188,6 +431,7 @@ describe("StreamableHTTPServerTransport", () => { headers: { "content-type": "application/json", "accept": "application/json, text/event-stream", + "mcp-session-id": transport.sessionId, }, body: JSON.stringify(notification), }); @@ -212,6 +456,7 @@ describe("StreamableHTTPServerTransport", () => { headers: { "content-type": "application/json", "accept": "application/json", + "mcp-session-id": transport.sessionId, }, body: JSON.stringify(batchMessages), }); @@ -231,6 +476,7 @@ describe("StreamableHTTPServerTransport", () => { headers: { "content-type": "text/plain", "accept": "application/json", + "mcp-session-id": transport.sessionId, }, body: "test", }); @@ -244,7 +490,9 @@ describe("StreamableHTTPServerTransport", () => { it("should properly handle DELETE requests and close session", async () => { const req = createMockRequest({ method: "DELETE", - headers: {}, + headers: { + "mcp-session-id": transport.sessionId, + }, }); const onCloseMock = jest.fn(); @@ -259,11 +507,12 @@ describe("StreamableHTTPServerTransport", () => { describe("Message Replay", () => { it("should replay messages after specified Last-Event-ID", async () => { - // Establish first connection with Accept header + // Establish first connection with Accept header and session ID const req1 = createMockRequest({ method: "GET", headers: { - "accept": "text/event-stream" + "accept": "text/event-stream", + "mcp-session-id": transport.sessionId }, }); await transport.handleRequest(req1, mockResponse); @@ -293,6 +542,7 @@ describe("StreamableHTTPServerTransport", () => { headers: { "accept": "text/event-stream", "last-event-id": lastEventId, + "mcp-session-id": transport.sessionId }, }); diff --git a/src/server/streamable-http.ts b/src/server/streamable-http.ts index bf0096d1..73eae32b 100644 --- a/src/server/streamable-http.ts +++ b/src/server/streamable-http.ts @@ -18,9 +18,44 @@ interface StreamConnection { requestId?: string | null; } +/** + * Configuration options for StreamableHTTPServerTransport + */ +export interface StreamableHTTPServerTransportOptions { + /** + * Whether to enable session management through mcp-session-id headers + * When set to false, the transport operates in stateless mode without session validation + * @default true + */ + enableSessionManagement?: boolean; +} + /** * Server transport for Streamable HTTP: this implements the MCP Streamable HTTP transport specification. * It supports both SSE streaming and direct HTTP responses, with session management and message resumability. + * + * Usage example: + * + * ```typescript + * // Stateful mode (default) - with session management + * const statefulTransport = new StreamableHTTPServerTransport("/mcp"); + * + * // Stateless mode - without session management + * const statelessTransport = new StreamableHTTPServerTransport("/mcp", { + * enableSessionManagement: false + * }); + * ``` + * + * In stateful mode: + * - Session ID is generated and included in response headers + * - Session ID is always included in initialization responses + * - Requests with invalid session IDs are rejected with 404 Not Found + * - Non-initialization requests without a session ID are rejected with 400 Bad Request + * - State is maintained in-memory (connections, message history) + * + * In stateless mode: + * - Session ID is only included in initialization responses + * - No session validation is performed */ export class StreamableHTTPServerTransport implements Transport { private _connections: Map = new Map(); @@ -31,13 +66,15 @@ export class StreamableHTTPServerTransport implements Transport { }> = new Map(); private _started: boolean = false; private _requestConnections: Map = new Map(); // request ID to connection ID mapping + private _enableSessionManagement: boolean; onclose?: () => void; onerror?: (error: Error) => void; onmessage?: (message: JSONRPCMessage) => void; - constructor(private _endpoint: string) { + constructor(private _endpoint: string, options?: StreamableHTTPServerTransportOptions) { this._sessionId = randomUUID(); + this._enableSessionManagement = options?.enableSessionManagement !== false; } /** @@ -55,18 +92,41 @@ export class StreamableHTTPServerTransport implements Transport { * Handles an incoming HTTP request, whether GET or POST */ async handleRequest(req: IncomingMessage, res: ServerResponse): Promise { - // validate the session ID - const sessionId = req.headers["mcp-session-id"]; - if (sessionId && (Array.isArray(sessionId) ? sessionId[0] : sessionId) !== this._sessionId) { - res.writeHead(404).end(JSON.stringify({ - jsonrpc: "2.0", - error: { - code: -32001, - message: "Session not found" - }, - id: null - })); - return; + // Only validate session ID for non-initialization requests when session management is enabled + if (this._enableSessionManagement) { + const sessionId = req.headers["mcp-session-id"]; + + // Check if this might be an initialization request + const isInitializationRequest = req.method === "POST" && + req.headers["content-type"]?.includes("application/json"); + + if (isInitializationRequest) { + // For POST requests with JSON content, we need to check if it's an initialization request + // This will be done in handlePostRequest, as we need to parse the body + // Continue processing normally + } else if (!sessionId) { + // Non-initialization requests without a session ID should return 400 Bad Request + res.writeHead(400).end(JSON.stringify({ + jsonrpc: "2.0", + error: { + code: -32000, + message: "Bad Request: Mcp-Session-Id header is required" + }, + id: null + })); + return; + } else if ((Array.isArray(sessionId) ? sessionId[0] : sessionId) !== this._sessionId) { + // Reject requests with invalid session ID with 404 Not Found + res.writeHead(404).end(JSON.stringify({ + jsonrpc: "2.0", + error: { + code: -32001, + message: "Session not found" + }, + id: null + })); + return; + } } if (req.method === "GET") { @@ -109,12 +169,19 @@ export class StreamableHTTPServerTransport implements Transport { const lastEventId = req.headers["last-event-id"]; const lastEventIdStr = Array.isArray(lastEventId) ? lastEventId[0] : lastEventId; - res.writeHead(200, { + // Prepare response headers + const headers: Record = { "Content-Type": "text/event-stream", "Cache-Control": "no-cache", Connection: "keep-alive", - "mcp-session-id": this._sessionId, - }); + }; + + // Only include session ID header if session management is enabled + if (this._enableSessionManagement) { + headers["mcp-session-id"] = this._sessionId; + } + + res.writeHead(200, headers); const connection: StreamConnection = { response: res, @@ -192,6 +259,12 @@ export class StreamableHTTPServerTransport implements Transport { messages = [JSONRPCMessageSchema.parse(rawMessage)]; } + // Check if this is an initialization request + // https://spec.modelcontextprotocol.io/specification/2025-03-26/basic/lifecycle/ + const isInitializationRequest = messages.some( + msg => 'method' in msg && msg.method === 'initialize' && 'id' in msg + ); + // check if it contains requests const hasRequests = messages.some(msg => 'method' in msg && 'id' in msg); const hasOnlyNotificationsOrResponses = messages.every(msg => @@ -210,12 +283,19 @@ export class StreamableHTTPServerTransport implements Transport { const useSSE = acceptHeader.includes("text/event-stream"); if (useSSE) { - res.writeHead(200, { + const headers: Record = { "Content-Type": "text/event-stream", "Cache-Control": "no-cache", Connection: "keep-alive", - "mcp-session-id": this._sessionId, - }); + }; + + // Only include session ID header if session management is enabled + // Always include session ID for initialization requests + if (this._enableSessionManagement || isInitializationRequest) { + headers["mcp-session-id"] = this._sessionId; + } + + res.writeHead(200, headers); const connectionId = randomUUID(); const connection: StreamConnection = { @@ -247,10 +327,17 @@ export class StreamableHTTPServerTransport implements Transport { }); } else { // use direct JSON response - res.writeHead(200, { + const headers: Record = { "Content-Type": "application/json", - "mcp-session-id": this._sessionId, - }); + }; + + // Only include session ID header if session management is enabled + // Always include session ID for initialization requests + if (this._enableSessionManagement || isInitializationRequest) { + headers["mcp-session-id"] = this._sessionId; + } + + res.writeHead(200, headers); // handle each message for (const message of messages) { From b8ad23fd7b16c848029baed8032e1748ee18c70b Mon Sep 17 00:00:00 2001 From: Geng Yan Date: Wed, 2 Apr 2025 08:57:21 +0800 Subject: [PATCH 4/5] =?UTF-8?q?=E2=9C=A8=20feat:=20add=20pre-rarsed=20body?= =?UTF-8?q?=20=20support?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/server/streamable-http.test.ts | 114 +++++++++++++++++++++++++++++ src/server/streamable-http.ts | 28 ++++--- 2 files changed, 133 insertions(+), 9 deletions(-) diff --git a/src/server/streamable-http.test.ts b/src/server/streamable-http.test.ts index 960d3800..13f70b59 100644 --- a/src/server/streamable-http.test.ts +++ b/src/server/streamable-http.test.ts @@ -674,4 +674,118 @@ describe("StreamableHTTPServerTransport", () => { expect(onErrorMock).toHaveBeenCalled(); }); }); + + describe("Handling Pre-Parsed Body", () => { + it("should accept pre-parsed request body", async () => { + const message: JSONRPCMessage = { + jsonrpc: "2.0", + method: "initialize", + params: { + clientInfo: { name: "test-client", version: "1.0" }, + protocolVersion: "2025-03-26" + }, + id: "pre-parsed-test", + }; + + // Create a request without actual body content + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json", + }, + // No body provided here - it will be passed as parsedBody + }); + + const onMessageMock = jest.fn(); + transport.onmessage = onMessageMock; + + // Pass the pre-parsed body directly + await transport.handleRequest(req, mockResponse, message); + + // Verify the message was processed correctly + expect(onMessageMock).toHaveBeenCalledWith(message); + expect(mockResponse.writeHead).toHaveBeenCalledWith( + 200, + expect.objectContaining({ + "Content-Type": "application/json", + }) + ); + }); + + it("should handle pre-parsed batch messages", async () => { + const batchMessages: JSONRPCMessage[] = [ + { + jsonrpc: "2.0", + method: "method1", + params: { data: "test1" }, + id: "batch1" + }, + { + jsonrpc: "2.0", + method: "method2", + params: { data: "test2" }, + id: "batch2" + }, + ]; + + // Create a request without actual body content + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "text/event-stream", + "mcp-session-id": transport.sessionId, + }, + // No body provided here - it will be passed as parsedBody + }); + + const onMessageMock = jest.fn(); + transport.onmessage = onMessageMock; + + // Pass the pre-parsed body directly + await transport.handleRequest(req, mockResponse, batchMessages); + + // Should be called for each message in the batch + expect(onMessageMock).toHaveBeenCalledTimes(2); + expect(onMessageMock).toHaveBeenCalledWith(batchMessages[0]); + expect(onMessageMock).toHaveBeenCalledWith(batchMessages[1]); + }); + + it("should prefer pre-parsed body over request body", async () => { + const requestBodyMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "fromRequestBody", + params: {}, + id: "request-body", + }; + + const parsedBodyMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "fromParsedBody", + params: {}, + id: "parsed-body", + }; + + // Create a request with actual body content + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json", + }, + body: JSON.stringify(requestBodyMessage), + }); + + const onMessageMock = jest.fn(); + transport.onmessage = onMessageMock; + + // Pass the pre-parsed body directly + await transport.handleRequest(req, mockResponse, parsedBodyMessage); + + // Should use the parsed body instead of the request body + expect(onMessageMock).toHaveBeenCalledWith(parsedBodyMessage); + expect(onMessageMock).not.toHaveBeenCalledWith(requestBodyMessage); + }); + }); }); \ No newline at end of file diff --git a/src/server/streamable-http.ts b/src/server/streamable-http.ts index 73eae32b..bec0a1ad 100644 --- a/src/server/streamable-http.ts +++ b/src/server/streamable-http.ts @@ -44,6 +44,11 @@ export interface StreamableHTTPServerTransportOptions { * const statelessTransport = new StreamableHTTPServerTransport("/mcp", { * enableSessionManagement: false * }); + * + * // Using with pre-parsed request body + * app.post('/mcp', (req, res) => { + * transport.handleRequest(req, res, req.body); + * }); * ``` * * In stateful mode: @@ -91,7 +96,7 @@ export class StreamableHTTPServerTransport implements Transport { /** * Handles an incoming HTTP request, whether GET or POST */ - async handleRequest(req: IncomingMessage, res: ServerResponse): Promise { + async handleRequest(req: IncomingMessage, res: ServerResponse, parsedBody?: unknown): Promise { // Only validate session ID for non-initialization requests when session management is enabled if (this._enableSessionManagement) { const sessionId = req.headers["mcp-session-id"]; @@ -132,7 +137,7 @@ export class StreamableHTTPServerTransport implements Transport { if (req.method === "GET") { await this.handleGetRequest(req, res); } else if (req.method === "POST") { - await this.handlePostRequest(req, res); + await this.handlePostRequest(req, res, parsedBody); } else if (req.method === "DELETE") { await this.handleDeleteRequest(req, res); } else { @@ -213,7 +218,7 @@ export class StreamableHTTPServerTransport implements Transport { /** * Handles POST requests containing JSON-RPC messages */ - private async handlePostRequest(req: IncomingMessage, res: ServerResponse): Promise { + private async handlePostRequest(req: IncomingMessage, res: ServerResponse, parsedBody?: unknown): Promise { try { // validate the Accept header const acceptHeader = req.headers.accept; @@ -243,13 +248,18 @@ export class StreamableHTTPServerTransport implements Transport { return; } - const parsedCt = contentType.parse(ct); - const body = await getRawBody(req, { - limit: MAXIMUM_MESSAGE_SIZE, - encoding: parsedCt.parameters.charset ?? "utf-8", - }); + let rawMessage; + if (parsedBody !== undefined) { + rawMessage = parsedBody; + } else { + const parsedCt = contentType.parse(ct); + const body = await getRawBody(req, { + limit: MAXIMUM_MESSAGE_SIZE, + encoding: parsedCt.parameters.charset ?? "utf-8", + }); + rawMessage = JSON.parse(body.toString()); + } - const rawMessage = JSON.parse(body.toString()); let messages: JSONRPCMessage[]; // handle batch and single messages From 614e3d211b8d5dd7251798d562bdc513a125dda2 Mon Sep 17 00:00:00 2001 From: Geng Yan Date: Thu, 3 Apr 2025 08:44:29 +0800 Subject: [PATCH 5/5] =?UTF-8?q?=E2=9C=A8=20feat:=20add=20customHeaders=20o?= =?UTF-8?q?ptions?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/server/streamable-http.test.ts | 143 ++++++++++++++++++++++++++++- src/server/streamable-http.ts | 19 +++- 2 files changed, 155 insertions(+), 7 deletions(-) diff --git a/src/server/streamable-http.test.ts b/src/server/streamable-http.test.ts index 13f70b59..74b92c1a 100644 --- a/src/server/streamable-http.test.ts +++ b/src/server/streamable-http.test.ts @@ -98,7 +98,7 @@ describe("StreamableHTTPServerTransport", () => { await transport.handleRequest(req, mockResponse); - expect(mockResponse.writeHead).toHaveBeenCalledWith(404); + expect(mockResponse.writeHead).toHaveBeenCalledWith(404, {}); // check if the error response is a valid JSON-RPC error format expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"jsonrpc":"2.0"')); expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"error"')); @@ -115,7 +115,7 @@ describe("StreamableHTTPServerTransport", () => { await transport.handleRequest(req, mockResponse); - expect(mockResponse.writeHead).toHaveBeenCalledWith(400); + expect(mockResponse.writeHead).toHaveBeenCalledWith(400, {}); expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"jsonrpc":"2.0"')); expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"message":"Bad Request: Mcp-Session-Id header is required"')); }); @@ -342,7 +342,7 @@ describe("StreamableHTTPServerTransport", () => { await transport.handleRequest(req, mockResponse); - expect(mockResponse.writeHead).toHaveBeenCalledWith(406); + expect(mockResponse.writeHead).toHaveBeenCalledWith(406, {}); expect(mockResponse.end).toHaveBeenCalledWith(expect.stringContaining('"jsonrpc":"2.0"')); }); @@ -788,4 +788,141 @@ describe("StreamableHTTPServerTransport", () => { expect(onMessageMock).not.toHaveBeenCalledWith(requestBodyMessage); }); }); + + describe("Custom Headers", () => { + const customHeaders = { + "X-Custom-Header": "custom-value", + "X-API-Version": "1.0", + "Access-Control-Allow-Origin": "*" + }; + + let transportWithHeaders: StreamableHTTPServerTransport; + let mockResponse: jest.Mocked; + + beforeEach(() => { + transportWithHeaders = new StreamableHTTPServerTransport(endpoint, { customHeaders }); + mockResponse = createMockResponse(); + }); + + it("should include custom headers in SSE response", async () => { + const req = createMockRequest({ + method: "GET", + headers: { + accept: "text/event-stream", + "mcp-session-id": transportWithHeaders.sessionId + }, + }); + + await transportWithHeaders.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith( + 200, + expect.objectContaining({ + ...customHeaders, + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "mcp-session-id": transportWithHeaders.sessionId + }) + ); + }); + + it("should include custom headers in JSON response", async () => { + const message: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test", + params: {}, + id: 1, + }; + + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json", + "mcp-session-id": transportWithHeaders.sessionId + }, + body: JSON.stringify(message), + }); + + await transportWithHeaders.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith( + 200, + expect.objectContaining({ + ...customHeaders, + "Content-Type": "application/json", + "mcp-session-id": transportWithHeaders.sessionId + }) + ); + }); + + it("should include custom headers in error responses", async () => { + const req = createMockRequest({ + method: "GET", + headers: { + accept: "text/event-stream", + "mcp-session-id": "invalid-session-id" + }, + }); + + await transportWithHeaders.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith( + 404, + expect.objectContaining(customHeaders) + ); + }); + + it("should not override essential headers with custom headers", async () => { + const transportWithConflictingHeaders = new StreamableHTTPServerTransport(endpoint, { + customHeaders: { + "Content-Type": "text/plain", // 尝试覆盖必要的 Content-Type 头 + "X-Custom-Header": "custom-value" + } + }); + + const req = createMockRequest({ + method: "GET", + headers: { + accept: "text/event-stream", + "mcp-session-id": transportWithConflictingHeaders.sessionId + }, + }); + + await transportWithConflictingHeaders.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith( + 200, + expect.objectContaining({ + "Content-Type": "text/event-stream", // 应该保持原有的 Content-Type + "X-Custom-Header": "custom-value" + }) + ); + }); + + it("should work with empty custom headers", async () => { + const transportWithoutHeaders = new StreamableHTTPServerTransport(endpoint); + + const req = createMockRequest({ + method: "GET", + headers: { + accept: "text/event-stream", + "mcp-session-id": transportWithoutHeaders.sessionId + }, + }); + + await transportWithoutHeaders.handleRequest(req, mockResponse); + + expect(mockResponse.writeHead).toHaveBeenCalledWith( + 200, + expect.objectContaining({ + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "mcp-session-id": transportWithoutHeaders.sessionId + }) + ); + }); + }); }); \ No newline at end of file diff --git a/src/server/streamable-http.ts b/src/server/streamable-http.ts index bec0a1ad..13d28968 100644 --- a/src/server/streamable-http.ts +++ b/src/server/streamable-http.ts @@ -28,6 +28,12 @@ export interface StreamableHTTPServerTransportOptions { * @default true */ enableSessionManagement?: boolean; + + /** + * Custom headers to be included in all responses + * These headers will be added to both SSE and regular HTTP responses + */ + customHeaders?: Record; } /** @@ -72,6 +78,7 @@ export class StreamableHTTPServerTransport implements Transport { private _started: boolean = false; private _requestConnections: Map = new Map(); // request ID to connection ID mapping private _enableSessionManagement: boolean; + private _customHeaders: Record; onclose?: () => void; onerror?: (error: Error) => void; @@ -80,6 +87,7 @@ export class StreamableHTTPServerTransport implements Transport { constructor(private _endpoint: string, options?: StreamableHTTPServerTransportOptions) { this._sessionId = randomUUID(); this._enableSessionManagement = options?.enableSessionManagement !== false; + this._customHeaders = options?.customHeaders || {}; } /** @@ -111,7 +119,7 @@ export class StreamableHTTPServerTransport implements Transport { // Continue processing normally } else if (!sessionId) { // Non-initialization requests without a session ID should return 400 Bad Request - res.writeHead(400).end(JSON.stringify({ + res.writeHead(400, this._customHeaders).end(JSON.stringify({ jsonrpc: "2.0", error: { code: -32000, @@ -122,7 +130,7 @@ export class StreamableHTTPServerTransport implements Transport { return; } else if ((Array.isArray(sessionId) ? sessionId[0] : sessionId) !== this._sessionId) { // Reject requests with invalid session ID with 404 Not Found - res.writeHead(404).end(JSON.stringify({ + res.writeHead(404, this._customHeaders).end(JSON.stringify({ jsonrpc: "2.0", error: { code: -32001, @@ -141,7 +149,7 @@ export class StreamableHTTPServerTransport implements Transport { } else if (req.method === "DELETE") { await this.handleDeleteRequest(req, res); } else { - res.writeHead(405).end(JSON.stringify({ + res.writeHead(405, this._customHeaders).end(JSON.stringify({ jsonrpc: "2.0", error: { code: -32000, @@ -159,7 +167,7 @@ export class StreamableHTTPServerTransport implements Transport { // validate the Accept header const acceptHeader = req.headers.accept; if (!acceptHeader || !acceptHeader.includes("text/event-stream")) { - res.writeHead(406).end(JSON.stringify({ + res.writeHead(406, this._customHeaders).end(JSON.stringify({ jsonrpc: "2.0", error: { code: -32000, @@ -176,6 +184,7 @@ export class StreamableHTTPServerTransport implements Transport { // Prepare response headers const headers: Record = { + ...this._customHeaders, "Content-Type": "text/event-stream", "Cache-Control": "no-cache", Connection: "keep-alive", @@ -294,6 +303,7 @@ export class StreamableHTTPServerTransport implements Transport { if (useSSE) { const headers: Record = { + ...this._customHeaders, "Content-Type": "text/event-stream", "Cache-Control": "no-cache", Connection: "keep-alive", @@ -338,6 +348,7 @@ export class StreamableHTTPServerTransport implements Transport { } else { // use direct JSON response const headers: Record = { + ...this._customHeaders, "Content-Type": "application/json", };