Skip to content

Use eventsource package, to permit custom headers for SSE #134

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 20 additions & 8 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
},
"dependencies": {
"content-type": "^1.0.5",
"eventsource": "^3.0.2",
"raw-body": "^3.0.0",
"zod": "^3.23.8",
"zod-to-json-schema": "^3.24.1"
Expand All @@ -61,7 +62,6 @@
"@types/node": "^22.0.2",
"@types/ws": "^8.5.12",
"eslint": "^9.8.0",
"eventsource": "^2.0.2",
"express": "^4.19.2",
"jest": "^29.7.0",
"ts-jest": "^29.2.4",
Expand Down
3 changes: 0 additions & 3 deletions src/cli.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
import EventSource from "eventsource";
import WebSocket from "ws";

// eslint-disable-next-line @typescript-eslint/no-explicit-any
(global as any).EventSource = EventSource;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(global as any).WebSocket = WebSocket;

Expand Down
287 changes: 287 additions & 0 deletions src/client/sse.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,287 @@
import { createServer, type IncomingMessage, type Server } from "http";
import { AddressInfo } from "net";
import { JSONRPCMessage } from "../types.js";
import { SSEClientTransport } from "./sse.js";

describe("SSEClientTransport", () => {
let server: Server;
let transport: SSEClientTransport;
let baseUrl: URL;
let lastServerRequest: IncomingMessage;
let sendServerMessage: ((message: string) => void) | null = null;

beforeEach((done) => {
// Reset state
lastServerRequest = null as unknown as IncomingMessage;
sendServerMessage = null;

// Create a test server that will receive the EventSource connection
server = createServer((req, res) => {
lastServerRequest = req;

// Send SSE headers
res.writeHead(200, {
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
Connection: "keep-alive",
});

// Send the endpoint event
res.write("event: endpoint\n");
res.write(`data: ${baseUrl.href}\n\n`);

// Store reference to send function for tests
sendServerMessage = (message: string) => {
res.write(`data: ${message}\n\n`);
};

// Handle request body for POST endpoints
if (req.method === "POST") {
let body = "";
req.on("data", (chunk) => {
body += chunk;
});
req.on("end", () => {
(req as IncomingMessage & { body: string }).body = body;
res.end();
});
}
});

// Start server on random port
server.listen(0, "127.0.0.1", () => {
const addr = server.address() as AddressInfo;
baseUrl = new URL(`http://127.0.0.1:${addr.port}`);
done();
});
});

afterEach(async () => {
await transport.close();
await server.close();
});

describe("connection handling", () => {
it("establishes SSE connection and receives endpoint", async () => {
transport = new SSEClientTransport(baseUrl);
await transport.start();

expect(lastServerRequest.headers.accept).toBe("text/event-stream");
expect(lastServerRequest.method).toBe("GET");
});

it("rejects if server returns non-200 status", async () => {
// Create a server that returns 403
server.close();
await new Promise((resolve) => server.on("close", resolve));

server = createServer((req, res) => {
res.writeHead(403);
res.end();
});

await new Promise<void>((resolve) => {
server.listen(0, "127.0.0.1", () => {
const addr = server.address() as AddressInfo;
baseUrl = new URL(`http://127.0.0.1:${addr.port}`);
resolve();
});
});

transport = new SSEClientTransport(baseUrl);
await expect(transport.start()).rejects.toThrow();
});

it("closes EventSource connection on close()", async () => {
transport = new SSEClientTransport(baseUrl);
await transport.start();

const closePromise = new Promise((resolve) => {
lastServerRequest.on("close", resolve);
});

await transport.close();
await closePromise;
});
});

describe("message handling", () => {
it("receives and parses JSON-RPC messages", async () => {
const receivedMessages: JSONRPCMessage[] = [];
transport = new SSEClientTransport(baseUrl);
transport.onmessage = (msg) => receivedMessages.push(msg);

await transport.start();

const testMessage: JSONRPCMessage = {
jsonrpc: "2.0",
id: "test-1",
method: "test",
params: { foo: "bar" },
};

sendServerMessage!(JSON.stringify(testMessage));

// Wait for message processing
await new Promise((resolve) => setTimeout(resolve, 50));

expect(receivedMessages).toHaveLength(1);
expect(receivedMessages[0]).toEqual(testMessage);
});

it("handles malformed JSON messages", async () => {
const errors: Error[] = [];
transport = new SSEClientTransport(baseUrl);
transport.onerror = (err) => errors.push(err);

await transport.start();

sendServerMessage!("invalid json");

// Wait for message processing
await new Promise((resolve) => setTimeout(resolve, 50));

expect(errors).toHaveLength(1);
expect(errors[0].message).toMatch(/JSON/);
});

it("handles messages via POST requests", async () => {
transport = new SSEClientTransport(baseUrl);
await transport.start();

const testMessage: JSONRPCMessage = {
jsonrpc: "2.0",
id: "test-1",
method: "test",
params: { foo: "bar" },
};

await transport.send(testMessage);

// Wait for request processing
await new Promise((resolve) => setTimeout(resolve, 50));

expect(lastServerRequest.method).toBe("POST");
expect(lastServerRequest.headers["content-type"]).toBe(
"application/json",
);
expect(
JSON.parse(
(lastServerRequest as IncomingMessage & { body: string }).body,
),
).toEqual(testMessage);
});

it("handles POST request failures", async () => {
// Create a server that returns 500 for POST
server.close();
await new Promise((resolve) => server.on("close", resolve));

server = createServer((req, res) => {
if (req.method === "GET") {
res.writeHead(200, {
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
Connection: "keep-alive",
});
res.write("event: endpoint\n");
res.write(`data: ${baseUrl.href}\n\n`);
} else {
res.writeHead(500);
res.end("Internal error");
}
});

await new Promise<void>((resolve) => {
server.listen(0, "127.0.0.1", () => {
const addr = server.address() as AddressInfo;
baseUrl = new URL(`http://127.0.0.1:${addr.port}`);
resolve();
});
});

transport = new SSEClientTransport(baseUrl);
await transport.start();

const testMessage: JSONRPCMessage = {
jsonrpc: "2.0",
id: "test-1",
method: "test",
params: {},
};

await expect(transport.send(testMessage)).rejects.toThrow(/500/);
});
});

describe("header handling", () => {
it("uses custom fetch implementation from EventSourceInit to add auth headers", async () => {
const authToken = "Bearer test-token";

// Create a fetch wrapper that adds auth header
const fetchWithAuth = (url: string | URL, init?: RequestInit) => {
const headers = new Headers(init?.headers);
headers.set("Authorization", authToken);
return fetch(url.toString(), { ...init, headers });
};

transport = new SSEClientTransport(baseUrl, {
eventSourceInit: {
fetch: fetchWithAuth,
},
});

await transport.start();

// Verify the auth header was received by the server
expect(lastServerRequest.headers.authorization).toBe(authToken);
});

it("passes custom headers to fetch requests", async () => {
const customHeaders = {
Authorization: "Bearer test-token",
"X-Custom-Header": "custom-value",
};

transport = new SSEClientTransport(baseUrl, {
requestInit: {
headers: customHeaders,
},
});

await transport.start();

// Mock fetch for the message sending test
global.fetch = jest.fn().mockResolvedValue({
ok: true,
});

const message: JSONRPCMessage = {
jsonrpc: "2.0",
id: "1",
method: "test",
params: {},
};

await transport.send(message);

// Verify fetch was called with correct headers
expect(global.fetch).toHaveBeenCalledWith(
expect.any(URL),
expect.objectContaining({
headers: expect.any(Headers),
}),
);

const calledHeaders = (global.fetch as jest.Mock).mock.calls[0][1]
.headers;
expect(calledHeaders.get("Authorization")).toBe(
customHeaders.Authorization,
);
expect(calledHeaders.get("X-Custom-Header")).toBe(
customHeaders["X-Custom-Header"],
);
expect(calledHeaders.get("content-type")).toBe("application/json");
});
});
});
Loading
Loading