Skip to content

Commit d00c215

Browse files
committed
feat: add custom sse router
1 parent a0e968a commit d00c215

File tree

2 files changed

+163
-12
lines changed

2 files changed

+163
-12
lines changed

examples/custom_sse_pattern/main.go

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"log"
7+
"net/http"
8+
9+
"github.com/mark3labs/mcp-go/mcp"
10+
"github.com/mark3labs/mcp-go/server"
11+
)
12+
13+
// Custom context function for SSE connections
14+
func customContextFunc(ctx context.Context, r *http.Request) context.Context {
15+
params := server.GetRouteParams(ctx)
16+
log.Printf("SSE Connection Established - Route Parameters: %+v", params)
17+
log.Printf("Request Path: %s", r.URL.Path)
18+
return ctx
19+
}
20+
21+
// Message handler for simulating message sending
22+
func messageHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
23+
// Get channel parameter from context
24+
channel := server.GetRouteParam(ctx, "channel")
25+
log.Printf("Processing Message - Channel Parameter: %s", channel)
26+
27+
if channel == "" {
28+
return mcp.NewToolResultText("Failed to get channel parameter"), nil
29+
}
30+
31+
message := fmt.Sprintf("Message sent to channel: %s", channel)
32+
return mcp.NewToolResultText(message), nil
33+
}
34+
35+
func main() {
36+
// Create MCP Server
37+
mcpServer := server.NewMCPServer("test-server", "1.0.0")
38+
39+
// Register test tool
40+
mcpServer.AddTool(mcp.NewTool("send_message"), messageHandler)
41+
42+
// Create SSE Server with custom route pattern
43+
sseServer := server.NewSSEServer(mcpServer,
44+
server.WithBaseURL("http://localhost:8080"),
45+
server.WithSSEPattern("/:channel/sse"),
46+
server.WithSSEContextFunc(customContextFunc),
47+
)
48+
49+
// Start server
50+
log.Printf("Server started on port :8080")
51+
log.Printf("Test URL: http://localhost:8080/test/sse")
52+
log.Printf("Test URL: http://localhost:8080/news/sse")
53+
54+
if err := sseServer.Start(":8080"); err != nil {
55+
log.Fatalf("Server error: %v", err)
56+
}
57+
}

server/sse.go

+106-12
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,36 @@ type sseSession struct {
2424
sessionID string
2525
notificationChannel chan mcp.JSONRPCNotification
2626
initialized atomic.Bool
27+
routeParams RouteParams // Store route parameters in session
2728
}
2829

2930
// SSEContextFunc is a function that takes an existing context and the current
3031
// request and returns a potentially modified context based on the request
3132
// content. This can be used to inject context values from headers, for example.
3233
type SSEContextFunc func(ctx context.Context, r *http.Request) context.Context
3334

35+
// RouteParamsKey is the key type for storing route parameters in context
36+
type RouteParamsKey struct{}
37+
38+
// RouteParams stores path parameters
39+
type RouteParams map[string]string
40+
41+
// GetRouteParam retrieves a route parameter from context
42+
func GetRouteParam(ctx context.Context, key string) string {
43+
if params, ok := ctx.Value(RouteParamsKey{}).(RouteParams); ok {
44+
return params[key]
45+
}
46+
return ""
47+
}
48+
49+
// GetRouteParams retrieves all route parameters from context
50+
func GetRouteParams(ctx context.Context) RouteParams {
51+
if params, ok := ctx.Value(RouteParamsKey{}).(RouteParams); ok {
52+
return params
53+
}
54+
return RouteParams{}
55+
}
56+
3457
func (s *sseSession) SessionID() string {
3558
return s.sessionID
3659
}
@@ -58,6 +81,7 @@ type SSEServer struct {
5881
messageEndpoint string
5982
useFullURLForMessageEndpoint bool
6083
sseEndpoint string
84+
ssePattern string
6185
sessions sync.Map
6286
srv *http.Server
6387
contextFunc SSEContextFunc
@@ -123,14 +147,21 @@ func WithSSEEndpoint(endpoint string) SSEOption {
123147
}
124148
}
125149

150+
// WithSSEPattern sets the SSE endpoint pattern with route parameters
151+
func WithSSEPattern(pattern string) SSEOption {
152+
return func(s *SSEServer) {
153+
s.ssePattern = pattern
154+
}
155+
}
156+
126157
// WithHTTPServer sets the HTTP server instance
127158
func WithHTTPServer(srv *http.Server) SSEOption {
128159
return func(s *SSEServer) {
129160
s.srv = srv
130161
}
131162
}
132163

133-
// WithContextFunc sets a function that will be called to customise the context
164+
// WithSSEContextFunc sets a function that will be called to customise the context
134165
// to the server using the incoming request.
135166
func WithSSEContextFunc(fn SSEContextFunc) SSEOption {
136167
return func(s *SSEServer) {
@@ -222,12 +253,21 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
222253
eventQueue: make(chan string, 100), // Buffer for events
223254
sessionID: sessionID,
224255
notificationChannel: make(chan mcp.JSONRPCNotification, 100),
256+
routeParams: GetRouteParams(r.Context()), // Store route parameters from context
225257
}
226258

227259
s.sessions.Store(sessionID, session)
228260
defer s.sessions.Delete(sessionID)
229261

230-
if err := s.server.RegisterSession(r.Context(), session); err != nil {
262+
// Create base context with session
263+
ctx := s.server.WithContext(r.Context(), session)
264+
265+
// Apply custom context function if set
266+
if s.contextFunc != nil {
267+
ctx = s.contextFunc(ctx, r)
268+
}
269+
270+
if err := s.server.RegisterSession(ctx, session); err != nil {
231271
http.Error(w, fmt.Sprintf("Session registration failed: %v", err), http.StatusInternalServerError)
232272
return
233273
}
@@ -249,7 +289,7 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
249289
}
250290
case <-session.done:
251291
return
252-
case <-r.Context().Done():
292+
case <-ctx.Done():
253293
return
254294
}
255295
}
@@ -266,7 +306,7 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
266306
// Write the event to the response
267307
fmt.Fprint(w, event)
268308
flusher.Flush()
269-
case <-r.Context().Done():
309+
case <-ctx.Done():
270310
close(session.done)
271311
return
272312
}
@@ -304,8 +344,15 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) {
304344
}
305345
session := sessionI.(*sseSession)
306346

307-
// Set the client context before handling the message
347+
// Create base context with session
308348
ctx := s.server.WithContext(r.Context(), session)
349+
350+
// Add stored route parameters to context
351+
if len(session.routeParams) > 0 {
352+
ctx = context.WithValue(ctx, RouteParamsKey{}, session.routeParams)
353+
}
354+
355+
// Apply custom context function if set
309356
if s.contextFunc != nil {
310357
ctx = s.contextFunc(ctx, r)
311358
}
@@ -317,7 +364,7 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) {
317364
return
318365
}
319366

320-
// Process message through MCPServer
367+
// Process message through MCPServer with the context containing route parameters
321368
response := s.server.HandleMessage(ctx, rawMessage)
322369

323370
// Only send response if there is one (not for notifications)
@@ -384,6 +431,7 @@ func (s *SSEServer) SendEventToSession(
384431
return fmt.Errorf("event queue full")
385432
}
386433
}
434+
387435
func (s *SSEServer) GetUrlPath(input string) (string, error) {
388436
parse, err := url.Parse(input)
389437
if err != nil {
@@ -395,6 +443,7 @@ func (s *SSEServer) GetUrlPath(input string) (string, error) {
395443
func (s *SSEServer) CompleteSseEndpoint() string {
396444
return s.baseURL + s.basePath + s.sseEndpoint
397445
}
446+
398447
func (s *SSEServer) CompleteSsePath() string {
399448
path, err := s.GetUrlPath(s.CompleteSseEndpoint())
400449
if err != nil {
@@ -406,6 +455,7 @@ func (s *SSEServer) CompleteSsePath() string {
406455
func (s *SSEServer) CompleteMessageEndpoint() string {
407456
return s.baseURL + s.basePath + s.messageEndpoint
408457
}
458+
409459
func (s *SSEServer) CompleteMessagePath() string {
410460
path, err := s.GetUrlPath(s.CompleteMessageEndpoint())
411461
if err != nil {
@@ -417,17 +467,61 @@ func (s *SSEServer) CompleteMessagePath() string {
417467
// ServeHTTP implements the http.Handler interface.
418468
func (s *SSEServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
419469
path := r.URL.Path
420-
// Use exact path matching rather than Contains
421-
ssePath := s.CompleteSsePath()
422-
if ssePath != "" && path == ssePath {
423-
s.handleSSE(w, r)
424-
return
425-
}
426470
messagePath := s.CompleteMessagePath()
471+
472+
// Handle message endpoint
427473
if messagePath != "" && path == messagePath {
428474
s.handleMessage(w, r)
429475
return
430476
}
431477

478+
// Handle SSE endpoint with route parameters
479+
if s.ssePattern != "" {
480+
// Try pattern matching if pattern is set
481+
fullPattern := s.basePath + s.ssePattern
482+
matches, params := matchPath(fullPattern, path)
483+
if matches {
484+
// Create new context with route parameters
485+
ctx := context.WithValue(r.Context(), RouteParamsKey{}, params)
486+
s.handleSSE(w, r.WithContext(ctx))
487+
return
488+
}
489+
// If pattern is set but doesn't match, return 404
490+
http.NotFound(w, r)
491+
return
492+
}
493+
494+
// If no pattern is set, use the default SSE endpoint
495+
ssePath := s.CompleteSsePath()
496+
if ssePath != "" && path == ssePath {
497+
s.handleSSE(w, r)
498+
return
499+
}
500+
432501
http.NotFound(w, r)
433502
}
503+
504+
// matchPath checks if the given path matches the pattern and extracts parameters
505+
// pattern format: /user/:id/profile/:type
506+
func matchPath(pattern, path string) (bool, RouteParams) {
507+
patternParts := strings.Split(strings.Trim(pattern, "/"), "/")
508+
pathParts := strings.Split(strings.Trim(path, "/"), "/")
509+
510+
if len(patternParts) != len(pathParts) {
511+
return false, nil
512+
}
513+
514+
params := make(RouteParams)
515+
for i, part := range patternParts {
516+
if strings.HasPrefix(part, ":") {
517+
// This is a parameter
518+
paramName := strings.TrimPrefix(part, ":")
519+
params[paramName] = pathParts[i]
520+
} else if part != pathParts[i] {
521+
// Static part doesn't match
522+
return false, nil
523+
}
524+
}
525+
526+
return true, params
527+
}

0 commit comments

Comments
 (0)