Skip to content

Commit 41e0d26

Browse files
authoredFeb 17, 2025··
Merge pull request #144 from modelcontextprotocol/justin/client-auth
Client implementation of MCP auth
2 parents 423b62b + b1f4b65 commit 41e0d26

10 files changed

+1502
-37
lines changed
 

Diff for: ‎jest.config.js

+4
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ export default {
77
...defaultEsmPreset,
88
moduleNameMapper: {
99
"^(\\.{1,2}/.*)\\.js$": "$1",
10+
"^pkce-challenge$": "<rootDir>/src/__mocks__/pkce-challenge.ts"
1011
},
12+
transformIgnorePatterns: [
13+
"/node_modules/(?!eventsource)/"
14+
],
1115
testPathIgnorePatterns: ["/node_modules/", "/dist/"],
1216
};

Diff for: ‎package-lock.json

+12-2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Diff for: ‎package.json

+2-1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
"dependencies": {
4949
"content-type": "^1.0.5",
5050
"eventsource": "^3.0.2",
51+
"pkce-challenge": "^4.1.0",
5152
"raw-body": "^3.0.0",
5253
"zod": "^3.23.8",
5354
"zod-to-json-schema": "^3.24.1"
@@ -73,4 +74,4 @@
7374
"resolutions": {
7475
"strip-ansi": "6.0.1"
7576
}
76-
}
77+
}

Diff for: ‎src/__mocks__/pkce-challenge.ts

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
export default function pkceChallenge() {
2+
return {
3+
code_verifier: "test_verifier",
4+
code_challenge: "test_challenge",
5+
};
6+
}

Diff for: ‎src/client/auth.test.ts

+420
Large diffs are not rendered by default.

Diff for: ‎src/client/auth.ts

+477
Large diffs are not rendered by default.

Diff for: ‎src/client/sse.test.ts

+458-23
Large diffs are not rendered by default.

Diff for: ‎src/client/sse.ts

+121-9
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import { EventSource, type ErrorEvent, type EventSourceInit } from "eventsource";
22
import { Transport } from "../shared/transport.js";
33
import { JSONRPCMessage, JSONRPCMessageSchema } from "../types.js";
4+
import { auth, AuthResult, OAuthClientProvider, UnauthorizedError } from "./auth.js";
45

56
export class SseError extends Error {
67
constructor(
@@ -12,6 +13,42 @@ export class SseError extends Error {
1213
}
1314
}
1415

16+
/**
17+
* Configuration options for the `SSEClientTransport`.
18+
*/
19+
export type SSEClientTransportOptions = {
20+
/**
21+
* An OAuth client provider to use for authentication.
22+
*
23+
* When an `authProvider` is specified and the SSE connection is started:
24+
* 1. The connection is attempted with any existing access token from the `authProvider`.
25+
* 2. If the access token has expired, the `authProvider` is used to refresh the token.
26+
* 3. If token refresh fails or no access token exists, and auth is required, `OAuthClientProvider.redirectToAuthorization` is called, and an `UnauthorizedError` will be thrown from `connect`/`start`.
27+
*
28+
* After the user has finished authorizing via their user agent, and is redirected back to the MCP client application, call `SSEClientTransport.finishAuth` with the authorization code before retrying the connection.
29+
*
30+
* If an `authProvider` is not provided, and auth is required, an `UnauthorizedError` will be thrown.
31+
*
32+
* `UnauthorizedError` might also be thrown when sending any message over the SSE transport, indicating that the session has expired, and needs to be re-authed and reconnected.
33+
*/
34+
authProvider?: OAuthClientProvider;
35+
36+
/**
37+
* Customizes the initial SSE request to the server (the request that begins the stream).
38+
*
39+
* NOTE: Setting this property will prevent an `Authorization` header from
40+
* being automatically attached to the SSE request, if an `authProvider` is
41+
* also given. This can be worked around by setting the `Authorization` header
42+
* manually.
43+
*/
44+
eventSourceInit?: EventSourceInit;
45+
46+
/**
47+
* Customizes recurring POST requests to the server.
48+
*/
49+
requestInit?: RequestInit;
50+
};
51+
1552
/**
1653
* Client transport for SSE: this will connect to a server using Server-Sent Events for receiving
1754
* messages and make separate POST requests for sending messages.
@@ -23,35 +60,76 @@ export class SSEClientTransport implements Transport {
2360
private _url: URL;
2461
private _eventSourceInit?: EventSourceInit;
2562
private _requestInit?: RequestInit;
63+
private _authProvider?: OAuthClientProvider;
2664

2765
onclose?: () => void;
2866
onerror?: (error: Error) => void;
2967
onmessage?: (message: JSONRPCMessage) => void;
3068

3169
constructor(
3270
url: URL,
33-
opts?: { eventSourceInit?: EventSourceInit; requestInit?: RequestInit },
71+
opts?: SSEClientTransportOptions,
3472
) {
3573
this._url = url;
3674
this._eventSourceInit = opts?.eventSourceInit;
3775
this._requestInit = opts?.requestInit;
76+
this._authProvider = opts?.authProvider;
3877
}
3978

40-
start(): Promise<void> {
41-
if (this._eventSource) {
42-
throw new Error(
43-
"SSEClientTransport already started! If using Client class, note that connect() calls start() automatically.",
44-
);
79+
private async _authThenStart(): Promise<void> {
80+
if (!this._authProvider) {
81+
throw new UnauthorizedError("No auth provider");
82+
}
83+
84+
let result: AuthResult;
85+
try {
86+
result = await auth(this._authProvider, { serverUrl: this._url });
87+
} catch (error) {
88+
this.onerror?.(error as Error);
89+
throw error;
90+
}
91+
92+
if (result !== "AUTHORIZED") {
93+
throw new UnauthorizedError();
4594
}
4695

96+
return await this._startOrAuth();
97+
}
98+
99+
private async _commonHeaders(): Promise<HeadersInit> {
100+
const headers: HeadersInit = {};
101+
if (this._authProvider) {
102+
const tokens = await this._authProvider.tokens();
103+
if (tokens) {
104+
headers["Authorization"] = `Bearer ${tokens.access_token}`;
105+
}
106+
}
107+
108+
return headers;
109+
}
110+
111+
private _startOrAuth(): Promise<void> {
47112
return new Promise((resolve, reject) => {
48113
this._eventSource = new EventSource(
49114
this._url.href,
50-
this._eventSourceInit,
115+
this._eventSourceInit ?? {
116+
fetch: (url, init) => this._commonHeaders().then((headers) => fetch(url, {
117+
...init,
118+
headers: {
119+
...headers,
120+
Accept: "text/event-stream"
121+
}
122+
})),
123+
},
51124
);
52125
this._abortController = new AbortController();
53126

54127
this._eventSource.onerror = (event) => {
128+
if (event.code === 401 && this._authProvider) {
129+
this._authThenStart().then(resolve, reject);
130+
return;
131+
}
132+
55133
const error = new SseError(event.code, event.message, event);
56134
reject(error);
57135
this.onerror?.(error);
@@ -97,6 +175,30 @@ export class SSEClientTransport implements Transport {
97175
});
98176
}
99177

178+
async start() {
179+
if (this._eventSource) {
180+
throw new Error(
181+
"SSEClientTransport already started! If using Client class, note that connect() calls start() automatically.",
182+
);
183+
}
184+
185+
return await this._startOrAuth();
186+
}
187+
188+
/**
189+
* Call this method after the user has finished authorizing via their user agent and is redirected back to the MCP client application. This will exchange the authorization code for an access token, enabling the next connection attempt to successfully auth.
190+
*/
191+
async finishAuth(authorizationCode: string): Promise<void> {
192+
if (!this._authProvider) {
193+
throw new UnauthorizedError("No auth provider");
194+
}
195+
196+
const result = await auth(this._authProvider, { serverUrl: this._url, authorizationCode });
197+
if (result !== "AUTHORIZED") {
198+
throw new UnauthorizedError("Failed to authorize");
199+
}
200+
}
201+
100202
async close(): Promise<void> {
101203
this._abortController?.abort();
102204
this._eventSource?.close();
@@ -109,7 +211,8 @@ export class SSEClientTransport implements Transport {
109211
}
110212

111213
try {
112-
const headers = new Headers(this._requestInit?.headers);
214+
const commonHeaders = await this._commonHeaders();
215+
const headers = new Headers({ ...commonHeaders, ...this._requestInit?.headers });
113216
headers.set("content-type", "application/json");
114217
const init = {
115218
...this._requestInit,
@@ -120,8 +223,17 @@ export class SSEClientTransport implements Transport {
120223
};
121224

122225
const response = await fetch(this._endpoint, init);
123-
124226
if (!response.ok) {
227+
if (response.status === 401 && this._authProvider) {
228+
const result = await auth(this._authProvider, { serverUrl: this._url });
229+
if (result !== "AUTHORIZED") {
230+
throw new UnauthorizedError();
231+
}
232+
233+
// Purposely _not_ awaited, so we don't call onerror twice
234+
return this.send(message);
235+
}
236+
125237
const text = await response.text().catch(() => null);
126238
throw new Error(
127239
`Error POSTing to endpoint (HTTP ${response.status}): ${text}`,

Diff for: ‎tsconfig.cjs.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@
55
"moduleResolution": "node",
66
"outDir": "./dist/cjs"
77
},
8-
"exclude": ["**/*.test.ts"]
8+
"exclude": ["**/*.test.ts", "src/__mocks__/**/*"]
99
}

Diff for: ‎tsconfig.prod.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@
33
"compilerOptions": {
44
"outDir": "./dist/esm"
55
},
6-
"exclude": ["**/*.test.ts"]
6+
"exclude": ["**/*.test.ts", "src/__mocks__/**/*"]
77
}

0 commit comments

Comments
 (0)
Please sign in to comment.