Skip to content

Commit 3a71167

Browse files
committed
Expose SSE server as HttpHandler
1 parent 7e6fe09 commit 3a71167

File tree

2 files changed

+188
-0
lines changed

2 files changed

+188
-0
lines changed

server/sse.go

+12
Original file line numberDiff line numberDiff line change
@@ -244,3 +244,15 @@ func (s *SSEServer) SendEventToSession(
244244
return nil
245245
}
246246
}
247+
248+
// ServeHTTP implements the http.Handler interface.
249+
func (s *SSEServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
250+
switch r.URL.Path {
251+
case "/sse":
252+
s.handleSSE(w, r)
253+
case "/message":
254+
s.handleMessage(w, r)
255+
default:
256+
http.NotFound(w, r)
257+
}
258+
}

server/sse_test.go

+176
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@ package server
22

33
import (
44
"bytes"
5+
"context"
56
"encoding/json"
67
"fmt"
78
"net/http"
9+
"net/http/httptest"
810
"strings"
911
"sync"
1012
"testing"
@@ -229,4 +231,178 @@ func TestSSEServer(t *testing.T) {
229231
t.Fatal("Timeout waiting for sessions to complete")
230232
}
231233
})
234+
235+
t.Run("Can be used as http.Handler", func(t *testing.T) {
236+
mcpServer := NewMCPServer("test", "1.0.0")
237+
sseServer := NewSSEServer(mcpServer, "http://localhost:8080")
238+
239+
ts := httptest.NewServer(sseServer)
240+
defer ts.Close()
241+
242+
// Test 404 for unknown path first (simpler case)
243+
resp, err := http.Get(fmt.Sprintf("%s/unknown", ts.URL))
244+
if err != nil {
245+
t.Fatalf("Failed to make request: %v", err)
246+
}
247+
defer resp.Body.Close()
248+
if resp.StatusCode != http.StatusNotFound {
249+
t.Errorf("Expected status 404, got %d", resp.StatusCode)
250+
}
251+
252+
// Test SSE endpoint with proper cleanup
253+
ctx, cancel := context.WithCancel(context.Background())
254+
defer cancel()
255+
256+
req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/sse", ts.URL), nil)
257+
if err != nil {
258+
t.Fatalf("Failed to create request: %v", err)
259+
}
260+
261+
resp, err = http.DefaultClient.Do(req)
262+
if err != nil {
263+
t.Fatalf("Failed to connect to SSE endpoint: %v", err)
264+
}
265+
defer resp.Body.Close()
266+
267+
if resp.StatusCode != http.StatusOK {
268+
t.Errorf("Expected status 200, got %d", resp.StatusCode)
269+
}
270+
271+
// Read initial message in goroutine
272+
done := make(chan struct{})
273+
go func() {
274+
defer close(done)
275+
buf := make([]byte, 1024)
276+
_, err := resp.Body.Read(buf)
277+
if err != nil && err.Error() != "context canceled" {
278+
t.Errorf("Failed to read from SSE stream: %v", err)
279+
}
280+
}()
281+
282+
// Wait briefly for initial response then cancel
283+
time.Sleep(100 * time.Millisecond)
284+
cancel()
285+
<-done
286+
})
287+
288+
t.Run("Works with middleware", func(t *testing.T) {
289+
mcpServer := NewMCPServer("test", "1.0.0")
290+
sseServer := NewSSEServer(mcpServer, "http://localhost:8080")
291+
292+
middleware := func(next http.Handler) http.Handler {
293+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
294+
w.Header().Set("X-Test", "middleware")
295+
next.ServeHTTP(w, r)
296+
})
297+
}
298+
299+
ts := httptest.NewServer(middleware(sseServer))
300+
defer ts.Close()
301+
302+
ctx, cancel := context.WithCancel(context.Background())
303+
defer cancel()
304+
305+
req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/sse", ts.URL), nil)
306+
if err != nil {
307+
t.Fatalf("Failed to create request: %v", err)
308+
}
309+
310+
resp, err := http.DefaultClient.Do(req)
311+
if err != nil {
312+
t.Fatalf("Failed to connect to SSE endpoint: %v", err)
313+
}
314+
defer resp.Body.Close()
315+
316+
if resp.Header.Get("X-Test") != "middleware" {
317+
t.Error("Middleware header not found")
318+
}
319+
320+
// Read initial message in goroutine
321+
done := make(chan struct{})
322+
go func() {
323+
defer close(done)
324+
buf := make([]byte, 1024)
325+
_, err := resp.Body.Read(buf)
326+
if err != nil && err.Error() != "context canceled" {
327+
t.Errorf("Failed to read from SSE stream: %v", err)
328+
}
329+
}()
330+
331+
// Wait briefly then cancel
332+
time.Sleep(100 * time.Millisecond)
333+
cancel()
334+
<-done
335+
})
336+
337+
t.Run("Works with custom mux", func(t *testing.T) {
338+
mcpServer := NewMCPServer("test", "1.0.0")
339+
sseServer := NewSSEServer(mcpServer, "")
340+
341+
mux := http.NewServeMux()
342+
mux.Handle("/mcp/", http.StripPrefix("/mcp", sseServer))
343+
344+
ts := httptest.NewServer(mux)
345+
defer ts.Close()
346+
347+
sseServer.baseURL = ts.URL + "/mcp"
348+
349+
ctx, cancel := context.WithCancel(context.Background())
350+
defer cancel()
351+
352+
req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/mcp/sse", ts.URL), nil)
353+
if err != nil {
354+
t.Fatalf("Failed to create request: %v", err)
355+
}
356+
357+
resp, err := http.DefaultClient.Do(req)
358+
if err != nil {
359+
t.Fatalf("Failed to connect to SSE endpoint: %v", err)
360+
}
361+
defer resp.Body.Close()
362+
363+
if resp.StatusCode != http.StatusOK {
364+
t.Errorf("Expected status 200, got %d", resp.StatusCode)
365+
}
366+
367+
// Read the endpoint event
368+
buf := make([]byte, 1024)
369+
n, err := resp.Body.Read(buf)
370+
if err != nil {
371+
t.Fatalf("Failed to read SSE response: %v", err)
372+
}
373+
374+
endpointEvent := string(buf[:n])
375+
messageURL := strings.TrimSpace(
376+
strings.Split(strings.Split(endpointEvent, "data: ")[1], "\n")[0],
377+
)
378+
379+
// The messageURL should already be correct since we set the baseURL correctly
380+
// Test message endpoint
381+
initRequest := map[string]interface{}{
382+
"jsonrpc": "2.0",
383+
"id": 1,
384+
"method": "initialize",
385+
"params": map[string]interface{}{
386+
"protocolVersion": "2024-11-05",
387+
"clientInfo": map[string]interface{}{
388+
"name": "test-client",
389+
"version": "1.0.0",
390+
},
391+
},
392+
}
393+
requestBody, _ := json.Marshal(initRequest)
394+
395+
resp, err = http.Post(messageURL, "application/json", bytes.NewBuffer(requestBody))
396+
if err != nil {
397+
t.Fatalf("Failed to send message: %v", err)
398+
}
399+
defer resp.Body.Close()
400+
401+
if resp.StatusCode != http.StatusAccepted {
402+
t.Errorf("Expected status 202, got %d", resp.StatusCode)
403+
}
404+
405+
// Clean up SSE connection
406+
cancel()
407+
})
232408
}

0 commit comments

Comments
 (0)