Skip to content

Commit 258aee9

Browse files
authored
Option mode (#39)
* Option mode * fix test
1 parent 0b8eba6 commit 258aee9

File tree

4 files changed

+69
-37
lines changed

4 files changed

+69
-37
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
.aider*
22
.env
3+
.idea

examples/everything/main.go

+2-11
Original file line numberDiff line numberDiff line change
@@ -321,10 +321,6 @@ func handleSendNotification(
321321
}, nil
322322
}
323323

324-
func ServeSSE(mcpServer *server.MCPServer, addr string) *server.SSEServer {
325-
return server.NewSSEServer(mcpServer, fmt.Sprintf("http://%s", addr))
326-
}
327-
328324
func handleLongRunningOperationTool(
329325
ctx context.Context,
330326
request mcp.CallToolRequest,
@@ -418,19 +414,14 @@ func handleNotification(
418414
func main() {
419415
var transport string
420416
flag.StringVar(&transport, "t", "stdio", "Transport type (stdio or sse)")
421-
flag.StringVar(
422-
&transport,
423-
"transport",
424-
"stdio",
425-
"Transport type (stdio or sse)",
426-
)
417+
flag.StringVar(&transport, "transport", "stdio", "Transport type (stdio or sse)")
427418
flag.Parse()
428419

429420
mcpServer := NewMCPServer()
430421

431422
// Only check for "sse" since stdio is the default
432423
if transport == "sse" {
433-
sseServer := ServeSSE(mcpServer, "localhost:8080")
424+
sseServer := server.NewSSEServer(mcpServer)
434425
log.Printf("SSE server listening on :8080")
435426
if err := sseServer.Start(":8080"); err != nil {
436427
log.Fatalf("Server error: %v", err)

server/sse.go

+60-20
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,6 @@ import (
1212
"github.com/mark3labs/mcp-go/mcp"
1313
)
1414

15-
// SSEServer implements a Server-Sent Events (SSE) based MCP server.
16-
// It provides real-time communication capabilities over HTTP using the SSE protocol.
17-
type SSEServer struct {
18-
server *MCPServer
19-
baseURL string
20-
sessions sync.Map
21-
srv *http.Server
22-
}
23-
2415
// sseSession represents an active SSE connection.
2516
type sseSession struct {
2617
writer http.ResponseWriter
@@ -29,19 +20,67 @@ type sseSession struct {
2920
eventQueue chan string // Channel for queuing events
3021
}
3122

32-
// NewSSEServer creates a new SSE server instance with the given MCP server and base URL.
33-
func NewSSEServer(server *MCPServer, baseURL string) *SSEServer {
34-
return &SSEServer{
35-
server: server,
36-
baseURL: baseURL,
23+
// SSEServer implements a Server-Sent Events (SSE) based MCP server.
24+
// It provides real-time communication capabilities over HTTP using the SSE protocol.
25+
type SSEServer struct {
26+
server *MCPServer
27+
baseURL string
28+
messageEndpoint string
29+
sseEndpoint string
30+
sessions sync.Map
31+
srv *http.Server
32+
}
33+
34+
// Option defines a function type for configuring SSEServer
35+
type Option func(*SSEServer)
36+
37+
// WithBaseURL sets the base URL for the SSE server
38+
func WithBaseURL(baseURL string) Option {
39+
return func(s *SSEServer) {
40+
s.baseURL = baseURL
41+
}
42+
}
43+
44+
// WithMessageEndpoint sets the message endpoint path
45+
func WithMessageEndpoint(endpoint string) Option {
46+
return func(s *SSEServer) {
47+
s.messageEndpoint = endpoint
3748
}
3849
}
3950

51+
// WithSSEEndpoint sets the SSE endpoint path
52+
func WithSSEEndpoint(endpoint string) Option {
53+
return func(s *SSEServer) {
54+
s.sseEndpoint = endpoint
55+
}
56+
}
57+
58+
// WithHTTPServer sets the HTTP server instance
59+
func WithHTTPServer(srv *http.Server) Option {
60+
return func(s *SSEServer) {
61+
s.srv = srv
62+
}
63+
}
64+
65+
// NewSSEServer creates a new SSE server instance with the given MCP server and options.
66+
func NewSSEServer(server *MCPServer, opts ...Option) *SSEServer {
67+
s := &SSEServer{
68+
server: server,
69+
sseEndpoint: "/sse",
70+
messageEndpoint: "/message",
71+
}
72+
73+
// Apply all options
74+
for _, opt := range opts {
75+
opt(s)
76+
}
77+
78+
return s
79+
}
80+
4081
// NewTestServer creates a test server for testing purposes
4182
func NewTestServer(server *MCPServer) *httptest.Server {
42-
sseServer := &SSEServer{
43-
server: server,
44-
}
83+
sseServer := NewSSEServer(server)
4584

4685
testServer := httptest.NewServer(sseServer)
4786
sseServer.baseURL = testServer.URL
@@ -132,8 +171,9 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
132171
}()
133172

134173
messageEndpoint := fmt.Sprintf(
135-
"%s/message?sessionId=%s",
174+
"%s%s?sessionId=%s",
136175
s.baseURL,
176+
s.messageEndpoint,
137177
sessionID,
138178
)
139179

@@ -260,9 +300,9 @@ func (s *SSEServer) SendEventToSession(
260300
// ServeHTTP implements the http.Handler interface.
261301
func (s *SSEServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
262302
switch r.URL.Path {
263-
case "/sse":
303+
case s.sseEndpoint:
264304
s.handleSSE(w, r)
265-
case "/message":
305+
case s.messageEndpoint:
266306
s.handleMessage(w, r)
267307
default:
268308
http.NotFound(w, r)

server/sse_test.go

+6-6
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ import (
1616
func TestSSEServer(t *testing.T) {
1717
t.Run("Can instantiate", func(t *testing.T) {
1818
mcpServer := NewMCPServer("test", "1.0.0")
19-
sseServer := NewSSEServer(mcpServer, "http://localhost:8080")
19+
sseServer := NewSSEServer(mcpServer, WithBaseURL("http://localhost:8080"))
2020

2121
if sseServer == nil {
2222
t.Error("SSEServer should not be nil")
@@ -234,7 +234,7 @@ func TestSSEServer(t *testing.T) {
234234

235235
t.Run("Can be used as http.Handler", func(t *testing.T) {
236236
mcpServer := NewMCPServer("test", "1.0.0")
237-
sseServer := NewSSEServer(mcpServer, "http://localhost:8080")
237+
sseServer := NewSSEServer(mcpServer, WithBaseURL("http://localhost:8080"))
238238

239239
ts := httptest.NewServer(sseServer)
240240
defer ts.Close()
@@ -263,7 +263,7 @@ func TestSSEServer(t *testing.T) {
263263
t.Fatalf("Failed to connect to SSE endpoint: %v", err)
264264
}
265265
defer resp.Body.Close()
266-
266+
267267
if resp.StatusCode != http.StatusOK {
268268
t.Errorf("Expected status 200, got %d", resp.StatusCode)
269269
}
@@ -287,7 +287,7 @@ func TestSSEServer(t *testing.T) {
287287

288288
t.Run("Works with middleware", func(t *testing.T) {
289289
mcpServer := NewMCPServer("test", "1.0.0")
290-
sseServer := NewSSEServer(mcpServer, "http://localhost:8080")
290+
sseServer := NewSSEServer(mcpServer, WithBaseURL("http://localhost:8080"))
291291

292292
middleware := func(next http.Handler) http.Handler {
293293
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -336,7 +336,7 @@ func TestSSEServer(t *testing.T) {
336336

337337
t.Run("Works with custom mux", func(t *testing.T) {
338338
mcpServer := NewMCPServer("test", "1.0.0")
339-
sseServer := NewSSEServer(mcpServer, "")
339+
sseServer := NewSSEServer(mcpServer)
340340

341341
mux := http.NewServeMux()
342342
mux.Handle("/mcp/", http.StripPrefix("/mcp", sseServer))
@@ -397,7 +397,7 @@ func TestSSEServer(t *testing.T) {
397397
t.Fatalf("Failed to send message: %v", err)
398398
}
399399
defer resp.Body.Close()
400-
400+
401401
if resp.StatusCode != http.StatusAccepted {
402402
t.Errorf("Expected status 202, got %d", resp.StatusCode)
403403
}

0 commit comments

Comments
 (0)