Skip to content

Commit 5b9de6a

Browse files
committed
StreamableHTTPClientTransport cleanup / fixes
* always send headers specified in requestInit option * avoid doubled onerror call * use for-await to iterate SSE stream * remove outdated comments * simplify requestId tracking * throw error when response Content-Type is out of spec
1 parent 1aad768 commit 5b9de6a

File tree

2 files changed

+111
-64
lines changed

2 files changed

+111
-64
lines changed

src/client/streamableHttp.test.ts

Lines changed: 68 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ describe("StreamableHTTPClientTransport", () => {
8080
(global.fetch as jest.Mock).mockResolvedValueOnce({
8181
ok: true,
8282
status: 200,
83-
headers: new Headers({ "mcp-session-id": "test-session-id" }),
83+
headers: new Headers({ "content-type": "text/event-stream", "mcp-session-id": "test-session-id" }),
8484
});
8585

8686
await transport.send(message);
@@ -237,7 +237,7 @@ describe("StreamableHTTPClientTransport", () => {
237237

238238
(global.fetch as jest.Mock)
239239
.mockResolvedValueOnce({
240-
ok: true,
240+
ok: true,
241241
status: 200,
242242
headers: new Headers({ "content-type": "text/event-stream" }),
243243
body: makeStream("request1")
@@ -263,13 +263,13 @@ describe("StreamableHTTPClientTransport", () => {
263263

264264
// Both streams should have delivered their messages
265265
expect(messageSpy).toHaveBeenCalledTimes(2);
266-
266+
267267
// Verify received messages without assuming specific order
268268
expect(messageSpy.mock.calls.some(call => {
269269
const msg = call[0];
270270
return msg.id === "request1" && msg.result?.id === "request1";
271271
})).toBe(true);
272-
272+
273273
expect(messageSpy.mock.calls.some(call => {
274274
const msg = call[0];
275275
return msg.id === "request2" && msg.result?.id === "request2";
@@ -313,4 +313,67 @@ describe("StreamableHTTPClientTransport", () => {
313313
const lastCall = calls[calls.length - 1];
314314
expect(lastCall[1].headers.get("last-event-id")).toBe("event-123");
315315
});
316-
});
316+
317+
it("should throw error when invalid content-type is received", async () => {
318+
const message: JSONRPCMessage = {
319+
jsonrpc: "2.0",
320+
method: "test",
321+
params: {},
322+
id: "test-id"
323+
};
324+
325+
const stream = new ReadableStream({
326+
start(controller) {
327+
controller.enqueue('invalid text response');
328+
controller.close();
329+
}
330+
});
331+
332+
const errorSpy = jest.fn();
333+
transport.onerror = errorSpy;
334+
335+
(global.fetch as jest.Mock).mockResolvedValueOnce({
336+
ok: true,
337+
status: 200,
338+
headers: new Headers({ "content-type": "text/plain" }),
339+
body: stream
340+
});
341+
342+
await transport.start();
343+
await expect(transport.send(message)).rejects.toThrow("Unexpected content type: text/plain");
344+
expect(errorSpy).toHaveBeenCalled();
345+
});
346+
347+
348+
it("should always send specified custom headers", async () => {
349+
const requestInit = {
350+
headers: {
351+
'X-Custom-Header': 'CustomValue'
352+
}
353+
};
354+
transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), {
355+
requestInit: requestInit
356+
});
357+
358+
let actualReqInit: RequestInit = {};
359+
360+
((global.fetch as jest.Mock)).mockImplementation(
361+
async (_url, reqInit) => {
362+
actualReqInit = reqInit;
363+
return new Response(null, {status: 200, headers: {'content-type': 'text/event-stream'}});
364+
}
365+
);
366+
367+
await transport.start();
368+
369+
await transport.openSseStream();
370+
expect((actualReqInit.headers as Headers).get('x-custom-header')).toBe('CustomValue');
371+
372+
requestInit.headers['X-Custom-Header'] = 'SecondCustomValue';
373+
374+
await transport.send({jsonrpc: '2.0', method: 'test', params: {}} as JSONRPCMessage);
375+
expect((actualReqInit.headers as Headers).get('x-custom-header')).toBe('SecondCustomValue');
376+
377+
expect(global.fetch).toHaveBeenCalledTimes(2);
378+
})
379+
});

src/client/streamableHttp.ts

Lines changed: 43 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import { Transport } from "../shared/transport.js";
22
import { JSONRPCMessage, JSONRPCMessageSchema } from "../types.js";
33
import { auth, AuthResult, OAuthClientProvider, UnauthorizedError } from "./auth.js";
4-
import { EventSourceParserStream } from 'eventsource-parser/stream';
4+
import { EventSourceParserStream } from "eventsource-parser/stream";
5+
56
export class StreamableHTTPError extends Error {
67
constructor(
78
public readonly code: number | undefined,
@@ -17,16 +18,16 @@ export class StreamableHTTPError extends Error {
1718
export type StreamableHTTPClientTransportOptions = {
1819
/**
1920
* An OAuth client provider to use for authentication.
20-
*
21+
*
2122
* When an `authProvider` is specified and the connection is started:
2223
* 1. The connection is attempted with any existing access token from the `authProvider`.
2324
* 2. If the access token has expired, the `authProvider` is used to refresh the token.
2425
* 3. If token refresh fails or no access token exists, and auth is required, `OAuthClientProvider.redirectToAuthorization` is called, and an `UnauthorizedError` will be thrown from `connect`/`start`.
25-
*
26+
*
2627
* After the user has finished authorizing via their user agent, and is redirected back to the MCP client application, call `StreamableHTTPClientTransport.finishAuth` with the authorization code before retrying the connection.
27-
*
28+
*
2829
* If an `authProvider` is not provided, and auth is required, an `UnauthorizedError` will be thrown.
29-
*
30+
*
3031
* `UnauthorizedError` might also be thrown when sending any message over the transport, indicating that the session has expired, and needs to be re-authed and reconnected.
3132
*/
3233
authProvider?: OAuthClientProvider;
@@ -70,7 +71,7 @@ export class StreamableHTTPClientTransport implements Transport {
7071

7172
let result: AuthResult;
7273
try {
73-
result = await auth(this._authProvider, { serverUrl: this._url });
74+
result = await auth(this._authProvider, {serverUrl: this._url});
7475
} catch (error) {
7576
this.onerror?.(error as Error);
7677
throw error;
@@ -83,7 +84,7 @@ export class StreamableHTTPClientTransport implements Transport {
8384
return await this._startOrAuthStandaloneSSE();
8485
}
8586

86-
private async _commonHeaders(): Promise<HeadersInit> {
87+
private async _commonHeaders(): Promise<Headers> {
8788
const headers: HeadersInit = {};
8889
if (this._authProvider) {
8990
const tokens = await this._authProvider.tokens();
@@ -96,24 +97,25 @@ export class StreamableHTTPClientTransport implements Transport {
9697
headers["mcp-session-id"] = this._sessionId;
9798
}
9899

99-
return headers;
100+
return new Headers(
101+
{...headers, ...this._requestInit?.headers}
102+
);
100103
}
101104

102105
private async _startOrAuthStandaloneSSE(): Promise<void> {
103106
try {
104107
// Try to open an initial SSE stream with GET to listen for server messages
105108
// This is optional according to the spec - server may not support it
106-
const commonHeaders = await this._commonHeaders();
107-
const headers = new Headers(commonHeaders);
108-
headers.set('Accept', 'text/event-stream');
109+
const headers = await this._commonHeaders();
110+
headers.set("Accept", "text/event-stream");
109111

110112
// Include Last-Event-ID header for resumable streams
111113
if (this._lastEventId) {
112-
headers.set('last-event-id', this._lastEventId);
114+
headers.set("last-event-id", this._lastEventId);
113115
}
114116

115117
const response = await fetch(this._url, {
116-
method: 'GET',
118+
method: "GET",
117119
headers,
118120
signal: this._abortController?.signal,
119121
});
@@ -124,12 +126,10 @@ export class StreamableHTTPClientTransport implements Transport {
124126
return await this._authThenStart();
125127
}
126128

127-
const error = new StreamableHTTPError(
129+
throw new StreamableHTTPError(
128130
response.status,
129131
`Failed to open SSE stream: ${response.statusText}`,
130132
);
131-
this.onerror?.(error);
132-
throw error;
133133
}
134134

135135
// Successful connection, handle the SSE stream as a standalone listener
@@ -144,42 +144,29 @@ export class StreamableHTTPClientTransport implements Transport {
144144
if (!stream) {
145145
return;
146146
}
147-
// Create a pipeline: binary stream -> text decoder -> SSE parser
148-
const eventStream = stream
149-
.pipeThrough(new TextDecoderStream())
150-
.pipeThrough(new EventSourceParserStream());
151147

152-
const reader = eventStream.getReader();
153148
const processStream = async () => {
154-
try {
155-
while (true) {
156-
const { done, value: event } = await reader.read();
157-
if (done) {
158-
break;
159-
}
160-
161-
// Update last event ID if provided
162-
if (event.id) {
163-
this._lastEventId = event.id;
164-
}
165-
166-
// Handle message events (default event type is undefined per docs)
167-
// or explicit 'message' event type
168-
if (!event.event || event.event === 'message') {
169-
try {
170-
const message = JSONRPCMessageSchema.parse(JSON.parse(event.data));
171-
this.onmessage?.(message);
172-
} catch (error) {
173-
this.onerror?.(error as Error);
174-
}
149+
// Create a pipeline: binary stream -> text decoder -> SSE parser
150+
const eventStream = stream
151+
.pipeThrough(new TextDecoderStream())
152+
.pipeThrough(new EventSourceParserStream());
153+
154+
for await (const event of eventStream) {
155+
if (event.id) {
156+
this._lastEventId = event.id;
157+
}
158+
if (!event.event || event.event === "message") {
159+
try {
160+
const message = JSONRPCMessageSchema.parse(JSON.parse(event.data));
161+
this.onmessage?.(message);
162+
} catch (error) {
163+
this.onerror?.(error as Error);
175164
}
176165
}
177-
} catch (error) {
178-
this.onerror?.(error as Error);
179166
}
180167
};
181168

182-
processStream();
169+
processStream().catch(err => this.onerror?.(err));
183170
}
184171

185172
async start() {
@@ -200,7 +187,7 @@ export class StreamableHTTPClientTransport implements Transport {
200187
throw new UnauthorizedError("No auth provider");
201188
}
202189

203-
const result = await auth(this._authProvider, { serverUrl: this._url, authorizationCode });
190+
const result = await auth(this._authProvider, {serverUrl: this._url, authorizationCode});
204191
if (result !== "AUTHORIZED") {
205192
throw new UnauthorizedError("Failed to authorize");
206193
}
@@ -215,8 +202,7 @@ export class StreamableHTTPClientTransport implements Transport {
215202

216203
async send(message: JSONRPCMessage | JSONRPCMessage[]): Promise<void> {
217204
try {
218-
const commonHeaders = await this._commonHeaders();
219-
const headers = new Headers({ ...commonHeaders, ...this._requestInit?.headers });
205+
const headers = await this._commonHeaders();
220206
headers.set("content-type", "application/json");
221207
headers.set("accept", "application/json, text/event-stream");
222208

@@ -238,7 +224,7 @@ export class StreamableHTTPClientTransport implements Transport {
238224

239225
if (!response.ok) {
240226
if (response.status === 401 && this._authProvider) {
241-
const result = await auth(this._authProvider, { serverUrl: this._url });
227+
const result = await auth(this._authProvider, {serverUrl: this._url});
242228
if (result !== "AUTHORIZED") {
243229
throw new UnauthorizedError();
244230
}
@@ -261,20 +247,13 @@ export class StreamableHTTPClientTransport implements Transport {
261247
// Get original message(s) for detecting request IDs
262248
const messages = Array.isArray(message) ? message : [message];
263249

264-
// Extract IDs from request messages for tracking responses
265-
const requestIds = messages.filter(msg => 'method' in msg && 'id' in msg)
266-
.map(msg => 'id' in msg ? msg.id : undefined)
267-
.filter(id => id !== undefined);
268-
269-
// If we have request IDs and an SSE response, create a unique stream ID
270-
const hasRequests = requestIds.length > 0;
250+
const hasRequests = messages.filter(msg => "method" in msg && "id" in msg && msg.id !== undefined).length > 0;
271251

272252
// Check the response type
273253
const contentType = response.headers.get("content-type");
274254

275255
if (hasRequests) {
276256
if (contentType?.includes("text/event-stream")) {
277-
// For streaming responses, create a unique stream ID based on request IDs
278257
this._handleSseStream(response.body);
279258
} else if (contentType?.includes("application/json")) {
280259
// For non-streaming servers, we might get direct JSON responses
@@ -286,6 +265,11 @@ export class StreamableHTTPClientTransport implements Transport {
286265
for (const msg of responseMessages) {
287266
this.onmessage?.(msg);
288267
}
268+
} else {
269+
throw new StreamableHTTPError(
270+
-1,
271+
`Unexpected content type: ${contentType}`,
272+
);
289273
}
290274
}
291275
} catch (error) {
@@ -296,7 +280,7 @@ export class StreamableHTTPClientTransport implements Transport {
296280

297281
/**
298282
* Opens SSE stream to receive messages from the server.
299-
*
283+
*
300284
* This allows the server to push messages to the client without requiring the client
301285
* to first send a request via HTTP POST. Some servers may not support this feature.
302286
* If authentication is required but fails, this method will throw an UnauthorizedError.
@@ -309,4 +293,4 @@ export class StreamableHTTPClientTransport implements Transport {
309293
}
310294
await this._startOrAuthStandaloneSSE();
311295
}
312-
}
296+
}

0 commit comments

Comments
 (0)