@@ -2,7 +2,7 @@ import { createServer, type IncomingMessage, type Server } from "http";
2
2
import { AddressInfo } from "net" ;
3
3
import { JSONRPCMessage } from "../types.js" ;
4
4
import { SSEClientTransport } from "./sse.js" ;
5
- import { auth , OAuthClientProvider } from "./auth.js" ;
5
+ import { OAuthClientProvider , OAuthTokens } from "./auth.js" ;
6
6
7
7
describe ( "SSEClientTransport" , ( ) => {
8
8
let server : Server ;
@@ -301,7 +301,7 @@ describe("SSEClientTransport", () => {
301
301
mockAuthProvider = {
302
302
get redirectUrl ( ) { return "http://localhost/callback" ; } ,
303
303
get clientMetadata ( ) { return { redirect_uris : [ "http://localhost/callback" ] } ; } ,
304
- clientInformation : jest . fn ( ( ) => ( { client_id : "test-client-id" } ) ) ,
304
+ clientInformation : jest . fn ( ( ) => ( { client_id : "test-client-id" , client_secret : "test-client-secret" } ) ) ,
305
305
tokens : jest . fn ( ) ,
306
306
saveTokens : jest . fn ( ) ,
307
307
redirectToAuthorization : jest . fn ( ) ,
@@ -466,5 +466,257 @@ describe("SSEClientTransport", () => {
466
466
expect ( lastServerRequest . headers . authorization ) . toBe ( "Bearer test-token" ) ;
467
467
expect ( lastServerRequest . headers [ "x-custom-header" ] ) . toBe ( "custom-value" ) ;
468
468
} ) ;
469
+
470
+ it ( "refreshes expired token during SSE connection" , async ( ) => {
471
+ // Mock tokens() to return expired token until saveTokens is called
472
+ let currentTokens : OAuthTokens = {
473
+ access_token : "expired-token" ,
474
+ token_type : "Bearer" ,
475
+ refresh_token : "refresh-token"
476
+ } ;
477
+ mockAuthProvider . tokens . mockImplementation ( ( ) => currentTokens ) ;
478
+ mockAuthProvider . saveTokens . mockImplementation ( ( tokens ) => {
479
+ currentTokens = tokens ;
480
+ } ) ;
481
+
482
+ // Create server that returns 401 for expired token, then accepts new token
483
+ await server . close ( ) ;
484
+
485
+ let connectionAttempts = 0 ;
486
+ server = createServer ( ( req , res ) => {
487
+ lastServerRequest = req ;
488
+
489
+ if ( req . url === "/token" && req . method === "POST" ) {
490
+ // Handle token refresh request
491
+ let body = "" ;
492
+ req . on ( "data" , chunk => { body += chunk ; } ) ;
493
+ req . on ( "end" , ( ) => {
494
+ const params = new URLSearchParams ( body ) ;
495
+ if ( params . get ( "grant_type" ) === "refresh_token" &&
496
+ params . get ( "refresh_token" ) === "refresh-token" &&
497
+ params . get ( "client_id" ) === "test-client-id" &&
498
+ params . get ( "client_secret" ) === "test-client-secret" ) {
499
+ res . writeHead ( 200 , { "Content-Type" : "application/json" } ) ;
500
+ res . end ( JSON . stringify ( {
501
+ access_token : "new-token" ,
502
+ token_type : "Bearer" ,
503
+ refresh_token : "new-refresh-token"
504
+ } ) ) ;
505
+ } else {
506
+ res . writeHead ( 400 ) . end ( ) ;
507
+ }
508
+ } ) ;
509
+ return ;
510
+ }
511
+
512
+ if ( req . url !== "/" ) {
513
+ res . writeHead ( 404 ) . end ( ) ;
514
+ return ;
515
+ }
516
+
517
+ const auth = req . headers . authorization ;
518
+ if ( auth === "Bearer expired-token" ) {
519
+ res . writeHead ( 401 ) . end ( ) ;
520
+ return ;
521
+ }
522
+
523
+ if ( auth === "Bearer new-token" ) {
524
+ res . writeHead ( 200 , {
525
+ "Content-Type" : "text/event-stream" ,
526
+ "Cache-Control" : "no-cache" ,
527
+ Connection : "keep-alive" ,
528
+ } ) ;
529
+ res . write ( "event: endpoint\n" ) ;
530
+ res . write ( `data: ${ baseUrl . href } \n\n` ) ;
531
+ connectionAttempts ++ ;
532
+ return ;
533
+ }
534
+
535
+ res . writeHead ( 401 ) . end ( ) ;
536
+ } ) ;
537
+
538
+ await new Promise < void > ( resolve => {
539
+ server . listen ( 0 , "127.0.0.1" , ( ) => {
540
+ const addr = server . address ( ) as AddressInfo ;
541
+ baseUrl = new URL ( `http://127.0.0.1:${ addr . port } ` ) ;
542
+ resolve ( ) ;
543
+ } ) ;
544
+ } ) ;
545
+
546
+ transport = new SSEClientTransport ( baseUrl , {
547
+ authProvider : mockAuthProvider ,
548
+ } ) ;
549
+
550
+ await transport . start ( ) ;
551
+
552
+ expect ( mockAuthProvider . saveTokens ) . toHaveBeenCalledWith ( {
553
+ access_token : "new-token" ,
554
+ token_type : "Bearer" ,
555
+ refresh_token : "new-refresh-token"
556
+ } ) ;
557
+ expect ( connectionAttempts ) . toBe ( 1 ) ;
558
+ expect ( lastServerRequest . headers . authorization ) . toBe ( "Bearer new-token" ) ;
559
+ } ) ;
560
+
561
+ it ( "refreshes expired token during POST request" , async ( ) => {
562
+ // Mock tokens() to return expired token until saveTokens is called
563
+ let currentTokens : OAuthTokens = {
564
+ access_token : "expired-token" ,
565
+ token_type : "Bearer" ,
566
+ refresh_token : "refresh-token"
567
+ } ;
568
+ mockAuthProvider . tokens . mockImplementation ( ( ) => currentTokens ) ;
569
+ mockAuthProvider . saveTokens . mockImplementation ( ( tokens ) => {
570
+ currentTokens = tokens ;
571
+ } ) ;
572
+
573
+ // Create server that accepts SSE but returns 401 on POST with expired token
574
+ await server . close ( ) ;
575
+
576
+ let postAttempts = 0 ;
577
+ server = createServer ( ( req , res ) => {
578
+ lastServerRequest = req ;
579
+
580
+ if ( req . url === "/token" && req . method === "POST" ) {
581
+ // Handle token refresh request
582
+ let body = "" ;
583
+ req . on ( "data" , chunk => { body += chunk ; } ) ;
584
+ req . on ( "end" , ( ) => {
585
+ const params = new URLSearchParams ( body ) ;
586
+ if ( params . get ( "grant_type" ) === "refresh_token" &&
587
+ params . get ( "refresh_token" ) === "refresh-token" &&
588
+ params . get ( "client_id" ) === "test-client-id" &&
589
+ params . get ( "client_secret" ) === "test-client-secret" ) {
590
+ res . writeHead ( 200 , { "Content-Type" : "application/json" } ) ;
591
+ res . end ( JSON . stringify ( {
592
+ access_token : "new-token" ,
593
+ token_type : "Bearer" ,
594
+ refresh_token : "new-refresh-token"
595
+ } ) ) ;
596
+ } else {
597
+ res . writeHead ( 400 ) . end ( ) ;
598
+ }
599
+ } ) ;
600
+ return ;
601
+ }
602
+
603
+ switch ( req . method ) {
604
+ case "GET" :
605
+ if ( req . url !== "/" ) {
606
+ res . writeHead ( 404 ) . end ( ) ;
607
+ return ;
608
+ }
609
+
610
+ res . writeHead ( 200 , {
611
+ "Content-Type" : "text/event-stream" ,
612
+ "Cache-Control" : "no-cache" ,
613
+ Connection : "keep-alive" ,
614
+ } ) ;
615
+ res . write ( "event: endpoint\n" ) ;
616
+ res . write ( `data: ${ baseUrl . href } \n\n` ) ;
617
+ break ;
618
+
619
+ case "POST" : {
620
+ if ( req . url !== "/" ) {
621
+ res . writeHead ( 404 ) . end ( ) ;
622
+ return ;
623
+ }
624
+
625
+ const auth = req . headers . authorization ;
626
+ if ( auth === "Bearer expired-token" ) {
627
+ res . writeHead ( 401 ) . end ( ) ;
628
+ return ;
629
+ }
630
+
631
+ if ( auth === "Bearer new-token" ) {
632
+ res . writeHead ( 200 ) . end ( ) ;
633
+ postAttempts ++ ;
634
+ return ;
635
+ }
636
+
637
+ res . writeHead ( 401 ) . end ( ) ;
638
+ break ;
639
+ }
640
+ }
641
+ } ) ;
642
+
643
+ await new Promise < void > ( resolve => {
644
+ server . listen ( 0 , "127.0.0.1" , ( ) => {
645
+ const addr = server . address ( ) as AddressInfo ;
646
+ baseUrl = new URL ( `http://127.0.0.1:${ addr . port } ` ) ;
647
+ resolve ( ) ;
648
+ } ) ;
649
+ } ) ;
650
+
651
+ transport = new SSEClientTransport ( baseUrl , {
652
+ authProvider : mockAuthProvider ,
653
+ } ) ;
654
+
655
+ await transport . start ( ) ;
656
+
657
+ const message : JSONRPCMessage = {
658
+ jsonrpc : "2.0" ,
659
+ id : "1" ,
660
+ method : "test" ,
661
+ params : { } ,
662
+ } ;
663
+
664
+ await transport . send ( message ) ;
665
+
666
+ expect ( mockAuthProvider . saveTokens ) . toHaveBeenCalledWith ( {
667
+ access_token : "new-token" ,
668
+ token_type : "Bearer" ,
669
+ refresh_token : "new-refresh-token"
670
+ } ) ;
671
+ expect ( postAttempts ) . toBe ( 1 ) ;
672
+ expect ( lastServerRequest . headers . authorization ) . toBe ( "Bearer new-token" ) ;
673
+ } ) ;
674
+
675
+ it ( "redirects to authorization if refresh token flow fails" , async ( ) => {
676
+ // Mock tokens() to return expired token until saveTokens is called
677
+ let currentTokens : OAuthTokens = {
678
+ access_token : "expired-token" ,
679
+ token_type : "Bearer" ,
680
+ refresh_token : "refresh-token"
681
+ } ;
682
+ mockAuthProvider . tokens . mockImplementation ( ( ) => currentTokens ) ;
683
+ mockAuthProvider . saveTokens . mockImplementation ( ( tokens ) => {
684
+ currentTokens = tokens ;
685
+ } ) ;
686
+
687
+ // Create server that returns 401 for all tokens
688
+ await server . close ( ) ;
689
+
690
+ server = createServer ( ( req , res ) => {
691
+ lastServerRequest = req ;
692
+
693
+ if ( req . url === "/token" && req . method === "POST" ) {
694
+ // Handle token refresh request - always fail
695
+ res . writeHead ( 400 ) . end ( ) ;
696
+ return ;
697
+ }
698
+
699
+ if ( req . url !== "/" ) {
700
+ res . writeHead ( 404 ) . end ( ) ;
701
+ return ;
702
+ }
703
+ res . writeHead ( 401 ) . end ( ) ;
704
+ } ) ;
705
+
706
+ await new Promise < void > ( resolve => {
707
+ server . listen ( 0 , "127.0.0.1" , ( ) => {
708
+ const addr = server . address ( ) as AddressInfo ;
709
+ baseUrl = new URL ( `http://127.0.0.1:${ addr . port } ` ) ;
710
+ resolve ( ) ;
711
+ } ) ;
712
+ } ) ;
713
+
714
+ transport = new SSEClientTransport ( baseUrl , {
715
+ authProvider : mockAuthProvider ,
716
+ } ) ;
717
+
718
+ await expect ( transport . start ( ) ) . rejects . toThrow ( "Unauthorized" ) ;
719
+ expect ( mockAuthProvider . redirectToAuthorization ) . toHaveBeenCalled ( ) ;
720
+ } ) ;
469
721
} ) ;
470
722
} ) ;
0 commit comments