Skip to content

Commit 77352db

Browse files
authored
Merge pull request #300 from prompteus-ai/streamable-http-client-fixes
StreamableHTTPClientTransport cleanup / fixes
2 parents 632b836 + a76004c commit 77352db

File tree

2 files changed

+114
-64
lines changed

2 files changed

+114
-64
lines changed

src/client/streamableHttp.test.ts

Lines changed: 71 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ describe("StreamableHTTPClientTransport", () => {
8080
(global.fetch as jest.Mock).mockResolvedValueOnce({
8181
ok: true,
8282
status: 200,
83-
headers: new Headers({ "mcp-session-id": "test-session-id" }),
83+
headers: new Headers({ "content-type": "text/event-stream", "mcp-session-id": "test-session-id" }),
8484
});
8585

8686
await transport.send(message);
@@ -164,7 +164,7 @@ describe("StreamableHTTPClientTransport", () => {
164164
// We expect the 405 error to be caught and handled gracefully
165165
// This should not throw an error that breaks the transport
166166
await transport.start();
167-
await expect(transport.openSseStream()).rejects.toThrow('Failed to open SSE stream: Method Not Allowed');
167+
await expect(transport.openSseStream()).rejects.toThrow("Failed to open SSE stream: Method Not Allowed");
168168

169169
// Check that GET was attempted
170170
expect(global.fetch).toHaveBeenCalledWith(
@@ -192,7 +192,7 @@ describe("StreamableHTTPClientTransport", () => {
192192
const stream = new ReadableStream({
193193
start(controller) {
194194
// Send a server notification via SSE
195-
const event = 'event: message\ndata: {"jsonrpc": "2.0", "method": "serverNotification", "params": {}}\n\n';
195+
const event = "event: message\ndata: {\"jsonrpc\": \"2.0\", \"method\": \"serverNotification\", \"params\": {}}\n\n";
196196
controller.enqueue(encoder.encode(event));
197197
}
198198
});
@@ -237,7 +237,7 @@ describe("StreamableHTTPClientTransport", () => {
237237

238238
(global.fetch as jest.Mock)
239239
.mockResolvedValueOnce({
240-
ok: true,
240+
ok: true,
241241
status: 200,
242242
headers: new Headers({ "content-type": "text/event-stream" }),
243243
body: makeStream("request1")
@@ -263,13 +263,13 @@ describe("StreamableHTTPClientTransport", () => {
263263

264264
// Both streams should have delivered their messages
265265
expect(messageSpy).toHaveBeenCalledTimes(2);
266-
266+
267267
// Verify received messages without assuming specific order
268268
expect(messageSpy.mock.calls.some(call => {
269269
const msg = call[0];
270270
return msg.id === "request1" && msg.result?.id === "request1";
271271
})).toBe(true);
272-
272+
273273
expect(messageSpy.mock.calls.some(call => {
274274
const msg = call[0];
275275
return msg.id === "request2" && msg.result?.id === "request2";
@@ -281,7 +281,7 @@ describe("StreamableHTTPClientTransport", () => {
281281
const encoder = new TextEncoder();
282282
const stream = new ReadableStream({
283283
start(controller) {
284-
const event = 'id: event-123\nevent: message\ndata: {"jsonrpc": "2.0", "method": "serverNotification", "params": {}}\n\n';
284+
const event = "id: event-123\nevent: message\ndata: {\"jsonrpc\": \"2.0\", \"method\": \"serverNotification\", \"params\": {}}\n\n";
285285
controller.enqueue(encoder.encode(event));
286286
controller.close();
287287
}
@@ -313,4 +313,67 @@ describe("StreamableHTTPClientTransport", () => {
313313
const lastCall = calls[calls.length - 1];
314314
expect(lastCall[1].headers.get("last-event-id")).toBe("event-123");
315315
});
316-
});
316+
317+
it("should throw error when invalid content-type is received", async () => {
318+
const message: JSONRPCMessage = {
319+
jsonrpc: "2.0",
320+
method: "test",
321+
params: {},
322+
id: "test-id"
323+
};
324+
325+
const stream = new ReadableStream({
326+
start(controller) {
327+
controller.enqueue("invalid text response");
328+
controller.close();
329+
}
330+
});
331+
332+
const errorSpy = jest.fn();
333+
transport.onerror = errorSpy;
334+
335+
(global.fetch as jest.Mock).mockResolvedValueOnce({
336+
ok: true,
337+
status: 200,
338+
headers: new Headers({ "content-type": "text/plain" }),
339+
body: stream
340+
});
341+
342+
await transport.start();
343+
await expect(transport.send(message)).rejects.toThrow("Unexpected content type: text/plain");
344+
expect(errorSpy).toHaveBeenCalled();
345+
});
346+
347+
348+
it("should always send specified custom headers", async () => {
349+
const requestInit = {
350+
headers: {
351+
"X-Custom-Header": "CustomValue"
352+
}
353+
};
354+
transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), {
355+
requestInit: requestInit
356+
});
357+
358+
let actualReqInit: RequestInit = {};
359+
360+
((global.fetch as jest.Mock)).mockImplementation(
361+
async (_url, reqInit) => {
362+
actualReqInit = reqInit;
363+
return new Response(null, { status: 200, headers: { "content-type": "text/event-stream" } });
364+
}
365+
);
366+
367+
await transport.start();
368+
369+
await transport.openSseStream();
370+
expect((actualReqInit.headers as Headers).get("x-custom-header")).toBe("CustomValue");
371+
372+
requestInit.headers["X-Custom-Header"] = "SecondCustomValue";
373+
374+
await transport.send({ jsonrpc: "2.0", method: "test", params: {} } as JSONRPCMessage);
375+
expect((actualReqInit.headers as Headers).get("x-custom-header")).toBe("SecondCustomValue");
376+
377+
expect(global.fetch).toHaveBeenCalledTimes(2);
378+
});
379+
});

src/client/streamableHttp.ts

Lines changed: 43 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import { Transport } from "../shared/transport.js";
22
import { JSONRPCMessage, JSONRPCMessageSchema } from "../types.js";
33
import { auth, AuthResult, OAuthClientProvider, UnauthorizedError } from "./auth.js";
4-
import { EventSourceParserStream } from 'eventsource-parser/stream';
4+
import { EventSourceParserStream } from "eventsource-parser/stream";
5+
56
export class StreamableHTTPError extends Error {
67
constructor(
78
public readonly code: number | undefined,
@@ -17,16 +18,16 @@ export class StreamableHTTPError extends Error {
1718
export type StreamableHTTPClientTransportOptions = {
1819
/**
1920
* An OAuth client provider to use for authentication.
20-
*
21+
*
2122
* When an `authProvider` is specified and the connection is started:
2223
* 1. The connection is attempted with any existing access token from the `authProvider`.
2324
* 2. If the access token has expired, the `authProvider` is used to refresh the token.
2425
* 3. If token refresh fails or no access token exists, and auth is required, `OAuthClientProvider.redirectToAuthorization` is called, and an `UnauthorizedError` will be thrown from `connect`/`start`.
25-
*
26+
*
2627
* After the user has finished authorizing via their user agent, and is redirected back to the MCP client application, call `StreamableHTTPClientTransport.finishAuth` with the authorization code before retrying the connection.
27-
*
28+
*
2829
* If an `authProvider` is not provided, and auth is required, an `UnauthorizedError` will be thrown.
29-
*
30+
*
3031
* `UnauthorizedError` might also be thrown when sending any message over the transport, indicating that the session has expired, and needs to be re-authed and reconnected.
3132
*/
3233
authProvider?: OAuthClientProvider;
@@ -83,7 +84,7 @@ export class StreamableHTTPClientTransport implements Transport {
8384
return await this._startOrAuthStandaloneSSE();
8485
}
8586

86-
private async _commonHeaders(): Promise<HeadersInit> {
87+
private async _commonHeaders(): Promise<Headers> {
8788
const headers: HeadersInit = {};
8889
if (this._authProvider) {
8990
const tokens = await this._authProvider.tokens();
@@ -96,24 +97,25 @@ export class StreamableHTTPClientTransport implements Transport {
9697
headers["mcp-session-id"] = this._sessionId;
9798
}
9899

99-
return headers;
100+
return new Headers(
101+
{ ...headers, ...this._requestInit?.headers }
102+
);
100103
}
101104

102105
private async _startOrAuthStandaloneSSE(): Promise<void> {
103106
try {
104107
// Try to open an initial SSE stream with GET to listen for server messages
105108
// This is optional according to the spec - server may not support it
106-
const commonHeaders = await this._commonHeaders();
107-
const headers = new Headers(commonHeaders);
108-
headers.set('Accept', 'text/event-stream');
109+
const headers = await this._commonHeaders();
110+
headers.set("Accept", "text/event-stream");
109111

110112
// Include Last-Event-ID header for resumable streams
111113
if (this._lastEventId) {
112-
headers.set('last-event-id', this._lastEventId);
114+
headers.set("last-event-id", this._lastEventId);
113115
}
114116

115117
const response = await fetch(this._url, {
116-
method: 'GET',
118+
method: "GET",
117119
headers,
118120
signal: this._abortController?.signal,
119121
});
@@ -124,12 +126,10 @@ export class StreamableHTTPClientTransport implements Transport {
124126
return await this._authThenStart();
125127
}
126128

127-
const error = new StreamableHTTPError(
129+
throw new StreamableHTTPError(
128130
response.status,
129131
`Failed to open SSE stream: ${response.statusText}`,
130132
);
131-
this.onerror?.(error);
132-
throw error;
133133
}
134134

135135
// Successful connection, handle the SSE stream as a standalone listener
@@ -144,42 +144,32 @@ export class StreamableHTTPClientTransport implements Transport {
144144
if (!stream) {
145145
return;
146146
}
147-
// Create a pipeline: binary stream -> text decoder -> SSE parser
148-
const eventStream = stream
149-
.pipeThrough(new TextDecoderStream())
150-
.pipeThrough(new EventSourceParserStream());
151147

152-
const reader = eventStream.getReader();
153148
const processStream = async () => {
154-
try {
155-
while (true) {
156-
const { done, value: event } = await reader.read();
157-
if (done) {
158-
break;
159-
}
160-
161-
// Update last event ID if provided
162-
if (event.id) {
163-
this._lastEventId = event.id;
164-
}
165-
166-
// Handle message events (default event type is undefined per docs)
167-
// or explicit 'message' event type
168-
if (!event.event || event.event === 'message') {
169-
try {
170-
const message = JSONRPCMessageSchema.parse(JSON.parse(event.data));
171-
this.onmessage?.(message);
172-
} catch (error) {
173-
this.onerror?.(error as Error);
174-
}
149+
// Create a pipeline: binary stream -> text decoder -> SSE parser
150+
const eventStream = stream
151+
.pipeThrough(new TextDecoderStream())
152+
.pipeThrough(new EventSourceParserStream());
153+
154+
for await (const event of eventStream) {
155+
// Update last event ID if provided
156+
if (event.id) {
157+
this._lastEventId = event.id;
158+
}
159+
// Handle message events (default event type is undefined per docs)
160+
// or explicit 'message' event type
161+
if (!event.event || event.event === "message") {
162+
try {
163+
const message = JSONRPCMessageSchema.parse(JSON.parse(event.data));
164+
this.onmessage?.(message);
165+
} catch (error) {
166+
this.onerror?.(error as Error);
175167
}
176168
}
177-
} catch (error) {
178-
this.onerror?.(error as Error);
179169
}
180170
};
181171

182-
processStream();
172+
processStream().catch(err => this.onerror?.(err));
183173
}
184174

185175
async start() {
@@ -215,8 +205,7 @@ export class StreamableHTTPClientTransport implements Transport {
215205

216206
async send(message: JSONRPCMessage | JSONRPCMessage[]): Promise<void> {
217207
try {
218-
const commonHeaders = await this._commonHeaders();
219-
const headers = new Headers({ ...commonHeaders, ...this._requestInit?.headers });
208+
const headers = await this._commonHeaders();
220209
headers.set("content-type", "application/json");
221210
headers.set("accept", "application/json, text/event-stream");
222211

@@ -261,20 +250,13 @@ export class StreamableHTTPClientTransport implements Transport {
261250
// Get original message(s) for detecting request IDs
262251
const messages = Array.isArray(message) ? message : [message];
263252

264-
// Extract IDs from request messages for tracking responses
265-
const requestIds = messages.filter(msg => 'method' in msg && 'id' in msg)
266-
.map(msg => 'id' in msg ? msg.id : undefined)
267-
.filter(id => id !== undefined);
268-
269-
// If we have request IDs and an SSE response, create a unique stream ID
270-
const hasRequests = requestIds.length > 0;
253+
const hasRequests = messages.filter(msg => "method" in msg && "id" in msg && msg.id !== undefined).length > 0;
271254

272255
// Check the response type
273256
const contentType = response.headers.get("content-type");
274257

275258
if (hasRequests) {
276259
if (contentType?.includes("text/event-stream")) {
277-
// For streaming responses, create a unique stream ID based on request IDs
278260
this._handleSseStream(response.body);
279261
} else if (contentType?.includes("application/json")) {
280262
// For non-streaming servers, we might get direct JSON responses
@@ -286,6 +268,11 @@ export class StreamableHTTPClientTransport implements Transport {
286268
for (const msg of responseMessages) {
287269
this.onmessage?.(msg);
288270
}
271+
} else {
272+
throw new StreamableHTTPError(
273+
-1,
274+
`Unexpected content type: ${contentType}`,
275+
);
289276
}
290277
}
291278
} catch (error) {
@@ -296,7 +283,7 @@ export class StreamableHTTPClientTransport implements Transport {
296283

297284
/**
298285
* Opens SSE stream to receive messages from the server.
299-
*
286+
*
300287
* This allows the server to push messages to the client without requiring the client
301288
* to first send a request via HTTP POST. Some servers may not support this feature.
302289
* If authentication is required but fails, this method will throw an UnauthorizedError.
@@ -309,4 +296,4 @@ export class StreamableHTTPClientTransport implements Transport {
309296
}
310297
await this._startOrAuthStandaloneSSE();
311298
}
312-
}
299+
}

0 commit comments

Comments
 (0)