diff --git a/src/server/sse.test.ts b/src/server/sse.test.ts new file mode 100644 index 00000000..2fd2c042 --- /dev/null +++ b/src/server/sse.test.ts @@ -0,0 +1,109 @@ +import http from 'http'; +import { jest } from '@jest/globals'; +import { SSEServerTransport } from './sse.js'; + +const createMockResponse = () => { + const res = { + writeHead: jest.fn(), + write: jest.fn().mockReturnValue(true), + on: jest.fn(), + }; + res.writeHead.mockReturnThis(); + res.on.mockReturnThis(); + + return res as unknown as http.ServerResponse; +}; + +describe('SSEServerTransport', () => { + describe('start method', () => { + it('should correctly append sessionId to a simple relative endpoint', async () => { + const mockRes = createMockResponse(); + const endpoint = '/messages'; + const transport = new SSEServerTransport(endpoint, mockRes); + const expectedSessionId = transport.sessionId; + + await transport.start(); + + expect(mockRes.writeHead).toHaveBeenCalledWith(200, expect.any(Object)); + expect(mockRes.write).toHaveBeenCalledTimes(1); + expect(mockRes.write).toHaveBeenCalledWith( + `event: endpoint\ndata: /messages?sessionId=${expectedSessionId}\n\n` + ); + }); + + it('should correctly append sessionId to an endpoint with existing query parameters', async () => { + const mockRes = createMockResponse(); + const endpoint = '/messages?foo=bar&baz=qux'; + const transport = new SSEServerTransport(endpoint, mockRes); + const expectedSessionId = transport.sessionId; + + await transport.start(); + + expect(mockRes.writeHead).toHaveBeenCalledWith(200, expect.any(Object)); + expect(mockRes.write).toHaveBeenCalledTimes(1); + expect(mockRes.write).toHaveBeenCalledWith( + `event: endpoint\ndata: /messages?foo=bar&baz=qux&sessionId=${expectedSessionId}\n\n` + ); + }); + + it('should correctly append sessionId to an endpoint with a hash fragment', async () => { + const mockRes = createMockResponse(); + const endpoint = '/messages#section1'; + const transport = new SSEServerTransport(endpoint, mockRes); + const expectedSessionId = transport.sessionId; + + await transport.start(); + + expect(mockRes.writeHead).toHaveBeenCalledWith(200, expect.any(Object)); + expect(mockRes.write).toHaveBeenCalledTimes(1); + expect(mockRes.write).toHaveBeenCalledWith( + `event: endpoint\ndata: /messages?sessionId=${expectedSessionId}#section1\n\n` + ); + }); + + it('should correctly append sessionId to an endpoint with query parameters and a hash fragment', async () => { + const mockRes = createMockResponse(); + const endpoint = '/messages?key=value#section2'; + const transport = new SSEServerTransport(endpoint, mockRes); + const expectedSessionId = transport.sessionId; + + await transport.start(); + + expect(mockRes.writeHead).toHaveBeenCalledWith(200, expect.any(Object)); + expect(mockRes.write).toHaveBeenCalledTimes(1); + expect(mockRes.write).toHaveBeenCalledWith( + `event: endpoint\ndata: /messages?key=value&sessionId=${expectedSessionId}#section2\n\n` + ); + }); + + it('should correctly handle the root path endpoint "/"', async () => { + const mockRes = createMockResponse(); + const endpoint = '/'; + const transport = new SSEServerTransport(endpoint, mockRes); + const expectedSessionId = transport.sessionId; + + await transport.start(); + + expect(mockRes.writeHead).toHaveBeenCalledWith(200, expect.any(Object)); + expect(mockRes.write).toHaveBeenCalledTimes(1); + expect(mockRes.write).toHaveBeenCalledWith( + `event: endpoint\ndata: /?sessionId=${expectedSessionId}\n\n` + ); + }); + + it('should correctly handle an empty string endpoint ""', async () => { + const mockRes = createMockResponse(); + const endpoint = ''; + const transport = new SSEServerTransport(endpoint, mockRes); + const expectedSessionId = transport.sessionId; + + await transport.start(); + + expect(mockRes.writeHead).toHaveBeenCalledWith(200, expect.any(Object)); + expect(mockRes.write).toHaveBeenCalledTimes(1); + expect(mockRes.write).toHaveBeenCalledWith( + `event: endpoint\ndata: /?sessionId=${expectedSessionId}\n\n` + ); + }); + }); +}); diff --git a/src/server/sse.ts b/src/server/sse.ts index 84c1cbb9..701b9e91 100644 --- a/src/server/sse.ts +++ b/src/server/sse.ts @@ -4,6 +4,7 @@ import { Transport } from "../shared/transport.js"; import { JSONRPCMessage, JSONRPCMessageSchema } from "../types.js"; import getRawBody from "raw-body"; import contentType from "content-type"; +import { URL } from 'url'; const MAXIMUM_MESSAGE_SIZE = "4mb"; @@ -49,8 +50,17 @@ export class SSEServerTransport implements Transport { }); // Send the endpoint event + // Use a dummy base URL because this._endpoint is relative. + // This allows using URL/URLSearchParams for robust parameter handling. + const dummyBase = 'http://localhost'; // Any valid base works + const endpointUrl = new URL(this._endpoint, dummyBase); + endpointUrl.searchParams.set('sessionId', this._sessionId); + + // Reconstruct the relative URL string (pathname + search + hash) + const relativeUrlWithSession = endpointUrl.pathname + endpointUrl.search + endpointUrl.hash; + this.res.write( - `event: endpoint\ndata: ${encodeURI(this._endpoint)}?sessionId=${this._sessionId}\n\n`, + `event: endpoint\ndata: ${relativeUrlWithSession}\n\n`, ); this._sseResponse = this.res;