@@ -2,9 +2,11 @@ package server
2
2
3
3
import (
4
4
"bytes"
5
+ "context"
5
6
"encoding/json"
6
7
"fmt"
7
8
"net/http"
9
+ "net/http/httptest"
8
10
"strings"
9
11
"sync"
10
12
"testing"
@@ -229,4 +231,178 @@ func TestSSEServer(t *testing.T) {
229
231
t .Fatal ("Timeout waiting for sessions to complete" )
230
232
}
231
233
})
234
+
235
+ t .Run ("Can be used as http.Handler" , func (t * testing.T ) {
236
+ mcpServer := NewMCPServer ("test" , "1.0.0" )
237
+ sseServer := NewSSEServer (mcpServer , "http://localhost:8080" )
238
+
239
+ ts := httptest .NewServer (sseServer )
240
+ defer ts .Close ()
241
+
242
+ // Test 404 for unknown path first (simpler case)
243
+ resp , err := http .Get (fmt .Sprintf ("%s/unknown" , ts .URL ))
244
+ if err != nil {
245
+ t .Fatalf ("Failed to make request: %v" , err )
246
+ }
247
+ defer resp .Body .Close ()
248
+ if resp .StatusCode != http .StatusNotFound {
249
+ t .Errorf ("Expected status 404, got %d" , resp .StatusCode )
250
+ }
251
+
252
+ // Test SSE endpoint with proper cleanup
253
+ ctx , cancel := context .WithCancel (context .Background ())
254
+ defer cancel ()
255
+
256
+ req , err := http .NewRequestWithContext (ctx , "GET" , fmt .Sprintf ("%s/sse" , ts .URL ), nil )
257
+ if err != nil {
258
+ t .Fatalf ("Failed to create request: %v" , err )
259
+ }
260
+
261
+ resp , err = http .DefaultClient .Do (req )
262
+ if err != nil {
263
+ t .Fatalf ("Failed to connect to SSE endpoint: %v" , err )
264
+ }
265
+ defer resp .Body .Close ()
266
+
267
+ if resp .StatusCode != http .StatusOK {
268
+ t .Errorf ("Expected status 200, got %d" , resp .StatusCode )
269
+ }
270
+
271
+ // Read initial message in goroutine
272
+ done := make (chan struct {})
273
+ go func () {
274
+ defer close (done )
275
+ buf := make ([]byte , 1024 )
276
+ _ , err := resp .Body .Read (buf )
277
+ if err != nil && err .Error () != "context canceled" {
278
+ t .Errorf ("Failed to read from SSE stream: %v" , err )
279
+ }
280
+ }()
281
+
282
+ // Wait briefly for initial response then cancel
283
+ time .Sleep (100 * time .Millisecond )
284
+ cancel ()
285
+ <- done
286
+ })
287
+
288
+ t .Run ("Works with middleware" , func (t * testing.T ) {
289
+ mcpServer := NewMCPServer ("test" , "1.0.0" )
290
+ sseServer := NewSSEServer (mcpServer , "http://localhost:8080" )
291
+
292
+ middleware := func (next http.Handler ) http.Handler {
293
+ return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
294
+ w .Header ().Set ("X-Test" , "middleware" )
295
+ next .ServeHTTP (w , r )
296
+ })
297
+ }
298
+
299
+ ts := httptest .NewServer (middleware (sseServer ))
300
+ defer ts .Close ()
301
+
302
+ ctx , cancel := context .WithCancel (context .Background ())
303
+ defer cancel ()
304
+
305
+ req , err := http .NewRequestWithContext (ctx , "GET" , fmt .Sprintf ("%s/sse" , ts .URL ), nil )
306
+ if err != nil {
307
+ t .Fatalf ("Failed to create request: %v" , err )
308
+ }
309
+
310
+ resp , err := http .DefaultClient .Do (req )
311
+ if err != nil {
312
+ t .Fatalf ("Failed to connect to SSE endpoint: %v" , err )
313
+ }
314
+ defer resp .Body .Close ()
315
+
316
+ if resp .Header .Get ("X-Test" ) != "middleware" {
317
+ t .Error ("Middleware header not found" )
318
+ }
319
+
320
+ // Read initial message in goroutine
321
+ done := make (chan struct {})
322
+ go func () {
323
+ defer close (done )
324
+ buf := make ([]byte , 1024 )
325
+ _ , err := resp .Body .Read (buf )
326
+ if err != nil && err .Error () != "context canceled" {
327
+ t .Errorf ("Failed to read from SSE stream: %v" , err )
328
+ }
329
+ }()
330
+
331
+ // Wait briefly then cancel
332
+ time .Sleep (100 * time .Millisecond )
333
+ cancel ()
334
+ <- done
335
+ })
336
+
337
+ t .Run ("Works with custom mux" , func (t * testing.T ) {
338
+ mcpServer := NewMCPServer ("test" , "1.0.0" )
339
+ sseServer := NewSSEServer (mcpServer , "" )
340
+
341
+ mux := http .NewServeMux ()
342
+ mux .Handle ("/mcp/" , http .StripPrefix ("/mcp" , sseServer ))
343
+
344
+ ts := httptest .NewServer (mux )
345
+ defer ts .Close ()
346
+
347
+ sseServer .baseURL = ts .URL + "/mcp"
348
+
349
+ ctx , cancel := context .WithCancel (context .Background ())
350
+ defer cancel ()
351
+
352
+ req , err := http .NewRequestWithContext (ctx , "GET" , fmt .Sprintf ("%s/mcp/sse" , ts .URL ), nil )
353
+ if err != nil {
354
+ t .Fatalf ("Failed to create request: %v" , err )
355
+ }
356
+
357
+ resp , err := http .DefaultClient .Do (req )
358
+ if err != nil {
359
+ t .Fatalf ("Failed to connect to SSE endpoint: %v" , err )
360
+ }
361
+ defer resp .Body .Close ()
362
+
363
+ if resp .StatusCode != http .StatusOK {
364
+ t .Errorf ("Expected status 200, got %d" , resp .StatusCode )
365
+ }
366
+
367
+ // Read the endpoint event
368
+ buf := make ([]byte , 1024 )
369
+ n , err := resp .Body .Read (buf )
370
+ if err != nil {
371
+ t .Fatalf ("Failed to read SSE response: %v" , err )
372
+ }
373
+
374
+ endpointEvent := string (buf [:n ])
375
+ messageURL := strings .TrimSpace (
376
+ strings .Split (strings .Split (endpointEvent , "data: " )[1 ], "\n " )[0 ],
377
+ )
378
+
379
+ // The messageURL should already be correct since we set the baseURL correctly
380
+ // Test message endpoint
381
+ initRequest := map [string ]interface {}{
382
+ "jsonrpc" : "2.0" ,
383
+ "id" : 1 ,
384
+ "method" : "initialize" ,
385
+ "params" : map [string ]interface {}{
386
+ "protocolVersion" : "2024-11-05" ,
387
+ "clientInfo" : map [string ]interface {}{
388
+ "name" : "test-client" ,
389
+ "version" : "1.0.0" ,
390
+ },
391
+ },
392
+ }
393
+ requestBody , _ := json .Marshal (initRequest )
394
+
395
+ resp , err = http .Post (messageURL , "application/json" , bytes .NewBuffer (requestBody ))
396
+ if err != nil {
397
+ t .Fatalf ("Failed to send message: %v" , err )
398
+ }
399
+ defer resp .Body .Close ()
400
+
401
+ if resp .StatusCode != http .StatusAccepted {
402
+ t .Errorf ("Expected status 202, got %d" , resp .StatusCode )
403
+ }
404
+
405
+ // Clean up SSE connection
406
+ cancel ()
407
+ })
232
408
}
0 commit comments