Skip to content

Commit 4ecc955

Browse files
Merge pull request modelcontextprotocol#134 from modelcontextprotocol/justin/sse-auth
Use `eventsource` package, to permit custom headers for SSE
2 parents 9908919 + 61a42d2 commit 4ecc955

File tree

5 files changed

+318
-17
lines changed

5 files changed

+318
-17
lines changed

Diff for: package-lock.json

+20-8
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Diff for: package.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
},
4848
"dependencies": {
4949
"content-type": "^1.0.5",
50+
"eventsource": "^3.0.2",
5051
"raw-body": "^3.0.0",
5152
"zod": "^3.23.8",
5253
"zod-to-json-schema": "^3.24.1"
@@ -61,7 +62,6 @@
6162
"@types/node": "^22.0.2",
6263
"@types/ws": "^8.5.12",
6364
"eslint": "^9.8.0",
64-
"eventsource": "^2.0.2",
6565
"express": "^4.19.2",
6666
"jest": "^29.7.0",
6767
"ts-jest": "^29.2.4",

Diff for: src/cli.ts

-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
1-
import EventSource from "eventsource";
21
import WebSocket from "ws";
32

4-
// eslint-disable-next-line @typescript-eslint/no-explicit-any
5-
(global as any).EventSource = EventSource;
63
// eslint-disable-next-line @typescript-eslint/no-explicit-any
74
(global as any).WebSocket = WebSocket;
85

Diff for: src/client/sse.test.ts

+287
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,287 @@
1+
import { createServer, type IncomingMessage, type Server } from "http";
2+
import { AddressInfo } from "net";
3+
import { JSONRPCMessage } from "../types.js";
4+
import { SSEClientTransport } from "./sse.js";
5+
6+
describe("SSEClientTransport", () => {
7+
let server: Server;
8+
let transport: SSEClientTransport;
9+
let baseUrl: URL;
10+
let lastServerRequest: IncomingMessage;
11+
let sendServerMessage: ((message: string) => void) | null = null;
12+
13+
beforeEach((done) => {
14+
// Reset state
15+
lastServerRequest = null as unknown as IncomingMessage;
16+
sendServerMessage = null;
17+
18+
// Create a test server that will receive the EventSource connection
19+
server = createServer((req, res) => {
20+
lastServerRequest = req;
21+
22+
// Send SSE headers
23+
res.writeHead(200, {
24+
"Content-Type": "text/event-stream",
25+
"Cache-Control": "no-cache",
26+
Connection: "keep-alive",
27+
});
28+
29+
// Send the endpoint event
30+
res.write("event: endpoint\n");
31+
res.write(`data: ${baseUrl.href}\n\n`);
32+
33+
// Store reference to send function for tests
34+
sendServerMessage = (message: string) => {
35+
res.write(`data: ${message}\n\n`);
36+
};
37+
38+
// Handle request body for POST endpoints
39+
if (req.method === "POST") {
40+
let body = "";
41+
req.on("data", (chunk) => {
42+
body += chunk;
43+
});
44+
req.on("end", () => {
45+
(req as IncomingMessage & { body: string }).body = body;
46+
res.end();
47+
});
48+
}
49+
});
50+
51+
// Start server on random port
52+
server.listen(0, "127.0.0.1", () => {
53+
const addr = server.address() as AddressInfo;
54+
baseUrl = new URL(`http://127.0.0.1:${addr.port}`);
55+
done();
56+
});
57+
});
58+
59+
afterEach(async () => {
60+
await transport.close();
61+
await server.close();
62+
});
63+
64+
describe("connection handling", () => {
65+
it("establishes SSE connection and receives endpoint", async () => {
66+
transport = new SSEClientTransport(baseUrl);
67+
await transport.start();
68+
69+
expect(lastServerRequest.headers.accept).toBe("text/event-stream");
70+
expect(lastServerRequest.method).toBe("GET");
71+
});
72+
73+
it("rejects if server returns non-200 status", async () => {
74+
// Create a server that returns 403
75+
server.close();
76+
await new Promise((resolve) => server.on("close", resolve));
77+
78+
server = createServer((req, res) => {
79+
res.writeHead(403);
80+
res.end();
81+
});
82+
83+
await new Promise<void>((resolve) => {
84+
server.listen(0, "127.0.0.1", () => {
85+
const addr = server.address() as AddressInfo;
86+
baseUrl = new URL(`http://127.0.0.1:${addr.port}`);
87+
resolve();
88+
});
89+
});
90+
91+
transport = new SSEClientTransport(baseUrl);
92+
await expect(transport.start()).rejects.toThrow();
93+
});
94+
95+
it("closes EventSource connection on close()", async () => {
96+
transport = new SSEClientTransport(baseUrl);
97+
await transport.start();
98+
99+
const closePromise = new Promise((resolve) => {
100+
lastServerRequest.on("close", resolve);
101+
});
102+
103+
await transport.close();
104+
await closePromise;
105+
});
106+
});
107+
108+
describe("message handling", () => {
109+
it("receives and parses JSON-RPC messages", async () => {
110+
const receivedMessages: JSONRPCMessage[] = [];
111+
transport = new SSEClientTransport(baseUrl);
112+
transport.onmessage = (msg) => receivedMessages.push(msg);
113+
114+
await transport.start();
115+
116+
const testMessage: JSONRPCMessage = {
117+
jsonrpc: "2.0",
118+
id: "test-1",
119+
method: "test",
120+
params: { foo: "bar" },
121+
};
122+
123+
sendServerMessage!(JSON.stringify(testMessage));
124+
125+
// Wait for message processing
126+
await new Promise((resolve) => setTimeout(resolve, 50));
127+
128+
expect(receivedMessages).toHaveLength(1);
129+
expect(receivedMessages[0]).toEqual(testMessage);
130+
});
131+
132+
it("handles malformed JSON messages", async () => {
133+
const errors: Error[] = [];
134+
transport = new SSEClientTransport(baseUrl);
135+
transport.onerror = (err) => errors.push(err);
136+
137+
await transport.start();
138+
139+
sendServerMessage!("invalid json");
140+
141+
// Wait for message processing
142+
await new Promise((resolve) => setTimeout(resolve, 50));
143+
144+
expect(errors).toHaveLength(1);
145+
expect(errors[0].message).toMatch(/JSON/);
146+
});
147+
148+
it("handles messages via POST requests", async () => {
149+
transport = new SSEClientTransport(baseUrl);
150+
await transport.start();
151+
152+
const testMessage: JSONRPCMessage = {
153+
jsonrpc: "2.0",
154+
id: "test-1",
155+
method: "test",
156+
params: { foo: "bar" },
157+
};
158+
159+
await transport.send(testMessage);
160+
161+
// Wait for request processing
162+
await new Promise((resolve) => setTimeout(resolve, 50));
163+
164+
expect(lastServerRequest.method).toBe("POST");
165+
expect(lastServerRequest.headers["content-type"]).toBe(
166+
"application/json",
167+
);
168+
expect(
169+
JSON.parse(
170+
(lastServerRequest as IncomingMessage & { body: string }).body,
171+
),
172+
).toEqual(testMessage);
173+
});
174+
175+
it("handles POST request failures", async () => {
176+
// Create a server that returns 500 for POST
177+
server.close();
178+
await new Promise((resolve) => server.on("close", resolve));
179+
180+
server = createServer((req, res) => {
181+
if (req.method === "GET") {
182+
res.writeHead(200, {
183+
"Content-Type": "text/event-stream",
184+
"Cache-Control": "no-cache",
185+
Connection: "keep-alive",
186+
});
187+
res.write("event: endpoint\n");
188+
res.write(`data: ${baseUrl.href}\n\n`);
189+
} else {
190+
res.writeHead(500);
191+
res.end("Internal error");
192+
}
193+
});
194+
195+
await new Promise<void>((resolve) => {
196+
server.listen(0, "127.0.0.1", () => {
197+
const addr = server.address() as AddressInfo;
198+
baseUrl = new URL(`http://127.0.0.1:${addr.port}`);
199+
resolve();
200+
});
201+
});
202+
203+
transport = new SSEClientTransport(baseUrl);
204+
await transport.start();
205+
206+
const testMessage: JSONRPCMessage = {
207+
jsonrpc: "2.0",
208+
id: "test-1",
209+
method: "test",
210+
params: {},
211+
};
212+
213+
await expect(transport.send(testMessage)).rejects.toThrow(/500/);
214+
});
215+
});
216+
217+
describe("header handling", () => {
218+
it("uses custom fetch implementation from EventSourceInit to add auth headers", async () => {
219+
const authToken = "Bearer test-token";
220+
221+
// Create a fetch wrapper that adds auth header
222+
const fetchWithAuth = (url: string | URL, init?: RequestInit) => {
223+
const headers = new Headers(init?.headers);
224+
headers.set("Authorization", authToken);
225+
return fetch(url.toString(), { ...init, headers });
226+
};
227+
228+
transport = new SSEClientTransport(baseUrl, {
229+
eventSourceInit: {
230+
fetch: fetchWithAuth,
231+
},
232+
});
233+
234+
await transport.start();
235+
236+
// Verify the auth header was received by the server
237+
expect(lastServerRequest.headers.authorization).toBe(authToken);
238+
});
239+
240+
it("passes custom headers to fetch requests", async () => {
241+
const customHeaders = {
242+
Authorization: "Bearer test-token",
243+
"X-Custom-Header": "custom-value",
244+
};
245+
246+
transport = new SSEClientTransport(baseUrl, {
247+
requestInit: {
248+
headers: customHeaders,
249+
},
250+
});
251+
252+
await transport.start();
253+
254+
// Mock fetch for the message sending test
255+
global.fetch = jest.fn().mockResolvedValue({
256+
ok: true,
257+
});
258+
259+
const message: JSONRPCMessage = {
260+
jsonrpc: "2.0",
261+
id: "1",
262+
method: "test",
263+
params: {},
264+
};
265+
266+
await transport.send(message);
267+
268+
// Verify fetch was called with correct headers
269+
expect(global.fetch).toHaveBeenCalledWith(
270+
expect.any(URL),
271+
expect.objectContaining({
272+
headers: expect.any(Headers),
273+
}),
274+
);
275+
276+
const calledHeaders = (global.fetch as jest.Mock).mock.calls[0][1]
277+
.headers;
278+
expect(calledHeaders.get("Authorization")).toBe(
279+
customHeaders.Authorization,
280+
);
281+
expect(calledHeaders.get("X-Custom-Header")).toBe(
282+
customHeaders["X-Custom-Header"],
283+
);
284+
expect(calledHeaders.get("content-type")).toBe("application/json");
285+
});
286+
});
287+
});

0 commit comments

Comments
 (0)