Skip to content

Improve SSE endpoint sessionId parameter handling #177

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions src/server/sse.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import http from 'http';
import { jest } from '@jest/globals';
import { SSEServerTransport } from './sse.js';

const createMockResponse = () => {
const res = {
writeHead: jest.fn<http.ServerResponse['writeHead']>(),
write: jest.fn<http.ServerResponse['write']>().mockReturnValue(true),
on: jest.fn<http.ServerResponse['on']>(),
};
res.writeHead.mockReturnThis();
res.on.mockReturnThis();

return res as unknown as http.ServerResponse;
};

describe('SSEServerTransport', () => {
describe('start method', () => {
it('should correctly append sessionId to a simple relative endpoint', async () => {
const mockRes = createMockResponse();
const endpoint = '/messages';
const transport = new SSEServerTransport(endpoint, mockRes);
const expectedSessionId = transport.sessionId;

await transport.start();

expect(mockRes.writeHead).toHaveBeenCalledWith(200, expect.any(Object));
expect(mockRes.write).toHaveBeenCalledTimes(1);
expect(mockRes.write).toHaveBeenCalledWith(
`event: endpoint\ndata: /messages?sessionId=${expectedSessionId}\n\n`
);
});

it('should correctly append sessionId to an endpoint with existing query parameters', async () => {
const mockRes = createMockResponse();
const endpoint = '/messages?foo=bar&baz=qux';
const transport = new SSEServerTransport(endpoint, mockRes);
const expectedSessionId = transport.sessionId;

await transport.start();

expect(mockRes.writeHead).toHaveBeenCalledWith(200, expect.any(Object));
expect(mockRes.write).toHaveBeenCalledTimes(1);
expect(mockRes.write).toHaveBeenCalledWith(
`event: endpoint\ndata: /messages?foo=bar&baz=qux&sessionId=${expectedSessionId}\n\n`
);
});

it('should correctly append sessionId to an endpoint with a hash fragment', async () => {
const mockRes = createMockResponse();
const endpoint = '/messages#section1';
const transport = new SSEServerTransport(endpoint, mockRes);
const expectedSessionId = transport.sessionId;

await transport.start();

expect(mockRes.writeHead).toHaveBeenCalledWith(200, expect.any(Object));
expect(mockRes.write).toHaveBeenCalledTimes(1);
expect(mockRes.write).toHaveBeenCalledWith(
`event: endpoint\ndata: /messages?sessionId=${expectedSessionId}#section1\n\n`
);
});

it('should correctly append sessionId to an endpoint with query parameters and a hash fragment', async () => {
const mockRes = createMockResponse();
const endpoint = '/messages?key=value#section2';
const transport = new SSEServerTransport(endpoint, mockRes);
const expectedSessionId = transport.sessionId;

await transport.start();

expect(mockRes.writeHead).toHaveBeenCalledWith(200, expect.any(Object));
expect(mockRes.write).toHaveBeenCalledTimes(1);
expect(mockRes.write).toHaveBeenCalledWith(
`event: endpoint\ndata: /messages?key=value&sessionId=${expectedSessionId}#section2\n\n`
);
});

it('should correctly handle the root path endpoint "/"', async () => {
const mockRes = createMockResponse();
const endpoint = '/';
const transport = new SSEServerTransport(endpoint, mockRes);
const expectedSessionId = transport.sessionId;

await transport.start();

expect(mockRes.writeHead).toHaveBeenCalledWith(200, expect.any(Object));
expect(mockRes.write).toHaveBeenCalledTimes(1);
expect(mockRes.write).toHaveBeenCalledWith(
`event: endpoint\ndata: /?sessionId=${expectedSessionId}\n\n`
);
});

it('should correctly handle an empty string endpoint ""', async () => {
const mockRes = createMockResponse();
const endpoint = '';
const transport = new SSEServerTransport(endpoint, mockRes);
const expectedSessionId = transport.sessionId;

await transport.start();

expect(mockRes.writeHead).toHaveBeenCalledWith(200, expect.any(Object));
expect(mockRes.write).toHaveBeenCalledTimes(1);
expect(mockRes.write).toHaveBeenCalledWith(
`event: endpoint\ndata: /?sessionId=${expectedSessionId}\n\n`
);
});
});
});
12 changes: 11 additions & 1 deletion src/server/sse.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { Transport } from "../shared/transport.js";
import { JSONRPCMessage, JSONRPCMessageSchema } from "../types.js";
import getRawBody from "raw-body";
import contentType from "content-type";
import { URL } from 'url';

const MAXIMUM_MESSAGE_SIZE = "4mb";

Expand Down Expand Up @@ -49,8 +50,17 @@ export class SSEServerTransport implements Transport {
});

// Send the endpoint event
// Use a dummy base URL because this._endpoint is relative.
// This allows using URL/URLSearchParams for robust parameter handling.
const dummyBase = 'http://localhost'; // Any valid base works
const endpointUrl = new URL(this._endpoint, dummyBase);
endpointUrl.searchParams.set('sessionId', this._sessionId);

// Reconstruct the relative URL string (pathname + search + hash)
const relativeUrlWithSession = endpointUrl.pathname + endpointUrl.search + endpointUrl.hash;

this.res.write(
`event: endpoint\ndata: ${encodeURI(this._endpoint)}?sessionId=${this._sessionId}\n\n`,
`event: endpoint\ndata: ${relativeUrlWithSession}\n\n`,
);

this._sseResponse = this.res;
Expand Down