diff --git a/src/client/streamableHttp.test.ts b/src/client/streamableHttp.test.ts index f748a2be..0b85da17 100644 --- a/src/client/streamableHttp.test.ts +++ b/src/client/streamableHttp.test.ts @@ -1,6 +1,7 @@ import { StreamableHTTPClientTransport, StreamableHTTPReconnectionOptions } from "./streamableHttp.js"; -import { OAuthClientProvider, UnauthorizedError } from "./auth.js"; -import { JSONRPCMessage } from "../types.js"; +import { JSONRPCMessage, LATEST_PROTOCOL_VERSION } from "../types.js"; +import { OAuthClientProvider } from "./auth.js"; +import { OAuthTokens } from "src/shared/auth.js"; describe("StreamableHTTPClientTransport", () => { @@ -517,19 +518,117 @@ describe("StreamableHTTPClientTransport", () => { id: "test-id" }; + const clientInfo = { + "issuer": "http://localhost:1234", + "authorization_endpoint": "http://localhost:1234/authorize", + "token_endpoint": "http://localhost:1234/token", + "revocation_endpoint": "http://localhost:1234/revoke", + "scopes_supported": [ + 'wow', + ], + "grant_types_supported": [ + "authorization_code", + "refresh_token" + ], + "token_endpoint_auth_methods_supported": [ + "client_secret_basic", + "client_secret_post" + ], + "code_challenge_methods_supported": [ + "S256" + ], + "registration_endpoint": "http://localhost:1234/register", + "response_types_supported": [ + "code" + ], + "response_modes_supported": [ + "query", + "fragment" + ] + }; + (global.fetch as jest.Mock) - .mockResolvedValueOnce({ - ok: false, + .mockResolvedValueOnce(new Response("{}", { status: 401, - statusText: "Unauthorized", headers: new Headers() + })) + .mockImplementationOnce(async (url: URL | string, _init?: RequestInit) => { + expect((url as URL).pathname).toBe('/.well-known/oauth-authorization-server') + return new Response(JSON.stringify(clientInfo), { + status: 200, + headers: new Headers({ 'content-type': 'application/json' }) + }) }) - .mockResolvedValue({ - ok: false, - status: 404 - }); + .mockImplementationOnce(async (url: URL | string, _init?: RequestInit) => { + expect((url as URL).pathname).toBe('/.well-known/oauth-authorization-server') + return new Response(JSON.stringify(clientInfo), { + status: 200, + headers: new Headers({ 'content-type': 'application/json' }) + }) + }) + .mockImplementationOnce(async (url: URL | string, init?: RequestInit) => { + expect(String(url)).toBe(clientInfo.token_endpoint) + expect(init).toBeDefined() + expect(init!.body).toBeDefined() + expect(new URLSearchParams(init!.body! as string).get('code')).toBe('any code') + return new Response(JSON.stringify({ + "access_token": "anything", + "token_type": "Bearer", + "expires_at": new Date(Date.now() + 5000), + "scope": "anything", + "refresh_token": "something else" + }), { + status: 200, + headers: new Headers({ 'content-type': 'application/json' }) + }) + }) + .mockImplementationOnce(async (url: URL | string, init?: RequestInit) => { + expect((url as URL).pathname).toBe('/mcp') + expect(init).toBeDefined() + expect(init!.body).toBeDefined() + expect(init!.headers).toBeDefined() + expect(new Headers(init!.headers).get('authorization')).toBe(`Bearer anything`) + const body = JSON.parse(init!.body! as string) + return new Response(JSON.stringify({ + jsonrpc: '2.0', + id: body.id, + result: { + protocolVersion: LATEST_PROTOCOL_VERSION, + capabilities: {}, + serverInfo: { + name: "test", + version: "1.0", + }, + }, + }), { + status: 200, + headers: new Headers({ 'content-type': 'application/json' }) + }) + }) + - await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); - expect(mockAuthProvider.redirectToAuthorization.mock.calls).toHaveLength(1); + let tokens: OAuthTokens + mockAuthProvider.tokens = jest.fn(() => { + return tokens! + }) + mockAuthProvider.saveTokens = jest.fn((t: OAuthTokens) => { + tokens = t + }) + + mockAuthProvider.redirectToAuthorization = jest.fn(async (redirectUrl: URL) => { + expect(redirectUrl.searchParams.get('response_type')).toBe('code') + expect(redirectUrl.searchParams.get('code_challenge_method')).toBe('S256') + expect(redirectUrl.searchParams.get('code_challenge')).toBe('test_challenge') + expect(redirectUrl.searchParams.get('client_id')).toBe('test-client-id') + expect(redirectUrl.searchParams.get('redirect_uri')).toBe('http://localhost/callback') + expect(redirectUrl.pathname).toBe('/authorize') + + await transport.finishAuth('any code') + }) + + await transport.send(message) + expect(mockAuthProvider.redirectToAuthorization.mock.calls.length).toBe(1) + expect(mockAuthProvider.saveTokens.mock.calls.length).toBe(1) + expect(mockAuthProvider.tokens.mock.calls.length).toBe(3) }); }); diff --git a/src/client/streamableHttp.ts b/src/client/streamableHttp.ts index 3462b2ab..111ee705 100644 --- a/src/client/streamableHttp.ts +++ b/src/client/streamableHttp.ts @@ -401,10 +401,8 @@ export class StreamableHTTPClientTransport implements Transport { if (!response.ok) { if (response.status === 401 && this._authProvider) { - const result = await auth(this._authProvider, { serverUrl: this._url }); - if (result !== "AUTHORIZED") { - throw new UnauthorizedError(); - } + // Whether this is REDIRECT or AUTHORIZED, retry sending the message. + await auth(this._authProvider, { serverUrl: this._url }); // Purposely _not_ awaited, so we don't call onerror twice return this.send(message);