Skip to content

Commit c0ebe90

Browse files
committed
Test transparent token refresh in SSEClientTransport
1 parent 79d2db3 commit c0ebe90

File tree

1 file changed

+254
-2
lines changed

1 file changed

+254
-2
lines changed

Diff for: src/client/sse.test.ts

+254-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import { createServer, type IncomingMessage, type Server } from "http";
22
import { AddressInfo } from "net";
33
import { JSONRPCMessage } from "../types.js";
44
import { SSEClientTransport } from "./sse.js";
5-
import { auth, OAuthClientProvider } from "./auth.js";
5+
import { OAuthClientProvider, OAuthTokens } from "./auth.js";
66

77
describe("SSEClientTransport", () => {
88
let server: Server;
@@ -301,7 +301,7 @@ describe("SSEClientTransport", () => {
301301
mockAuthProvider = {
302302
get redirectUrl() { return "http://localhost/callback"; },
303303
get clientMetadata() { return { redirect_uris: ["http://localhost/callback"] }; },
304-
clientInformation: jest.fn(() => ({ client_id: "test-client-id" })),
304+
clientInformation: jest.fn(() => ({ client_id: "test-client-id", client_secret: "test-client-secret" })),
305305
tokens: jest.fn(),
306306
saveTokens: jest.fn(),
307307
redirectToAuthorization: jest.fn(),
@@ -466,5 +466,257 @@ describe("SSEClientTransport", () => {
466466
expect(lastServerRequest.headers.authorization).toBe("Bearer test-token");
467467
expect(lastServerRequest.headers["x-custom-header"]).toBe("custom-value");
468468
});
469+
470+
it("refreshes expired token during SSE connection", async () => {
471+
// Mock tokens() to return expired token until saveTokens is called
472+
let currentTokens: OAuthTokens = {
473+
access_token: "expired-token",
474+
token_type: "Bearer",
475+
refresh_token: "refresh-token"
476+
};
477+
mockAuthProvider.tokens.mockImplementation(() => currentTokens);
478+
mockAuthProvider.saveTokens.mockImplementation((tokens) => {
479+
currentTokens = tokens;
480+
});
481+
482+
// Create server that returns 401 for expired token, then accepts new token
483+
await server.close();
484+
485+
let connectionAttempts = 0;
486+
server = createServer((req, res) => {
487+
lastServerRequest = req;
488+
489+
if (req.url === "/token" && req.method === "POST") {
490+
// Handle token refresh request
491+
let body = "";
492+
req.on("data", chunk => { body += chunk; });
493+
req.on("end", () => {
494+
const params = new URLSearchParams(body);
495+
if (params.get("grant_type") === "refresh_token" &&
496+
params.get("refresh_token") === "refresh-token" &&
497+
params.get("client_id") === "test-client-id" &&
498+
params.get("client_secret") === "test-client-secret") {
499+
res.writeHead(200, { "Content-Type": "application/json" });
500+
res.end(JSON.stringify({
501+
access_token: "new-token",
502+
token_type: "Bearer",
503+
refresh_token: "new-refresh-token"
504+
}));
505+
} else {
506+
res.writeHead(400).end();
507+
}
508+
});
509+
return;
510+
}
511+
512+
if (req.url !== "/") {
513+
res.writeHead(404).end();
514+
return;
515+
}
516+
517+
const auth = req.headers.authorization;
518+
if (auth === "Bearer expired-token") {
519+
res.writeHead(401).end();
520+
return;
521+
}
522+
523+
if (auth === "Bearer new-token") {
524+
res.writeHead(200, {
525+
"Content-Type": "text/event-stream",
526+
"Cache-Control": "no-cache",
527+
Connection: "keep-alive",
528+
});
529+
res.write("event: endpoint\n");
530+
res.write(`data: ${baseUrl.href}\n\n`);
531+
connectionAttempts++;
532+
return;
533+
}
534+
535+
res.writeHead(401).end();
536+
});
537+
538+
await new Promise<void>(resolve => {
539+
server.listen(0, "127.0.0.1", () => {
540+
const addr = server.address() as AddressInfo;
541+
baseUrl = new URL(`http://127.0.0.1:${addr.port}`);
542+
resolve();
543+
});
544+
});
545+
546+
transport = new SSEClientTransport(baseUrl, {
547+
authProvider: mockAuthProvider,
548+
});
549+
550+
await transport.start();
551+
552+
expect(mockAuthProvider.saveTokens).toHaveBeenCalledWith({
553+
access_token: "new-token",
554+
token_type: "Bearer",
555+
refresh_token: "new-refresh-token"
556+
});
557+
expect(connectionAttempts).toBe(1);
558+
expect(lastServerRequest.headers.authorization).toBe("Bearer new-token");
559+
});
560+
561+
it("refreshes expired token during POST request", async () => {
562+
// Mock tokens() to return expired token until saveTokens is called
563+
let currentTokens: OAuthTokens = {
564+
access_token: "expired-token",
565+
token_type: "Bearer",
566+
refresh_token: "refresh-token"
567+
};
568+
mockAuthProvider.tokens.mockImplementation(() => currentTokens);
569+
mockAuthProvider.saveTokens.mockImplementation((tokens) => {
570+
currentTokens = tokens;
571+
});
572+
573+
// Create server that accepts SSE but returns 401 on POST with expired token
574+
await server.close();
575+
576+
let postAttempts = 0;
577+
server = createServer((req, res) => {
578+
lastServerRequest = req;
579+
580+
if (req.url === "/token" && req.method === "POST") {
581+
// Handle token refresh request
582+
let body = "";
583+
req.on("data", chunk => { body += chunk; });
584+
req.on("end", () => {
585+
const params = new URLSearchParams(body);
586+
if (params.get("grant_type") === "refresh_token" &&
587+
params.get("refresh_token") === "refresh-token" &&
588+
params.get("client_id") === "test-client-id" &&
589+
params.get("client_secret") === "test-client-secret") {
590+
res.writeHead(200, { "Content-Type": "application/json" });
591+
res.end(JSON.stringify({
592+
access_token: "new-token",
593+
token_type: "Bearer",
594+
refresh_token: "new-refresh-token"
595+
}));
596+
} else {
597+
res.writeHead(400).end();
598+
}
599+
});
600+
return;
601+
}
602+
603+
switch (req.method) {
604+
case "GET":
605+
if (req.url !== "/") {
606+
res.writeHead(404).end();
607+
return;
608+
}
609+
610+
res.writeHead(200, {
611+
"Content-Type": "text/event-stream",
612+
"Cache-Control": "no-cache",
613+
Connection: "keep-alive",
614+
});
615+
res.write("event: endpoint\n");
616+
res.write(`data: ${baseUrl.href}\n\n`);
617+
break;
618+
619+
case "POST": {
620+
if (req.url !== "/") {
621+
res.writeHead(404).end();
622+
return;
623+
}
624+
625+
const auth = req.headers.authorization;
626+
if (auth === "Bearer expired-token") {
627+
res.writeHead(401).end();
628+
return;
629+
}
630+
631+
if (auth === "Bearer new-token") {
632+
res.writeHead(200).end();
633+
postAttempts++;
634+
return;
635+
}
636+
637+
res.writeHead(401).end();
638+
break;
639+
}
640+
}
641+
});
642+
643+
await new Promise<void>(resolve => {
644+
server.listen(0, "127.0.0.1", () => {
645+
const addr = server.address() as AddressInfo;
646+
baseUrl = new URL(`http://127.0.0.1:${addr.port}`);
647+
resolve();
648+
});
649+
});
650+
651+
transport = new SSEClientTransport(baseUrl, {
652+
authProvider: mockAuthProvider,
653+
});
654+
655+
await transport.start();
656+
657+
const message: JSONRPCMessage = {
658+
jsonrpc: "2.0",
659+
id: "1",
660+
method: "test",
661+
params: {},
662+
};
663+
664+
await transport.send(message);
665+
666+
expect(mockAuthProvider.saveTokens).toHaveBeenCalledWith({
667+
access_token: "new-token",
668+
token_type: "Bearer",
669+
refresh_token: "new-refresh-token"
670+
});
671+
expect(postAttempts).toBe(1);
672+
expect(lastServerRequest.headers.authorization).toBe("Bearer new-token");
673+
});
674+
675+
it("redirects to authorization if refresh token flow fails", async () => {
676+
// Mock tokens() to return expired token until saveTokens is called
677+
let currentTokens: OAuthTokens = {
678+
access_token: "expired-token",
679+
token_type: "Bearer",
680+
refresh_token: "refresh-token"
681+
};
682+
mockAuthProvider.tokens.mockImplementation(() => currentTokens);
683+
mockAuthProvider.saveTokens.mockImplementation((tokens) => {
684+
currentTokens = tokens;
685+
});
686+
687+
// Create server that returns 401 for all tokens
688+
await server.close();
689+
690+
server = createServer((req, res) => {
691+
lastServerRequest = req;
692+
693+
if (req.url === "/token" && req.method === "POST") {
694+
// Handle token refresh request - always fail
695+
res.writeHead(400).end();
696+
return;
697+
}
698+
699+
if (req.url !== "/") {
700+
res.writeHead(404).end();
701+
return;
702+
}
703+
res.writeHead(401).end();
704+
});
705+
706+
await new Promise<void>(resolve => {
707+
server.listen(0, "127.0.0.1", () => {
708+
const addr = server.address() as AddressInfo;
709+
baseUrl = new URL(`http://127.0.0.1:${addr.port}`);
710+
resolve();
711+
});
712+
});
713+
714+
transport = new SSEClientTransport(baseUrl, {
715+
authProvider: mockAuthProvider,
716+
});
717+
718+
await expect(transport.start()).rejects.toThrow("Unauthorized");
719+
expect(mockAuthProvider.redirectToAuthorization).toHaveBeenCalled();
720+
});
469721
});
470722
});

0 commit comments

Comments
 (0)