@@ -24,13 +24,36 @@ type sseSession struct {
24
24
sessionID string
25
25
notificationChannel chan mcp.JSONRPCNotification
26
26
initialized atomic.Bool
27
+ routeParams RouteParams // Store route parameters in session
27
28
}
28
29
29
30
// SSEContextFunc is a function that takes an existing context and the current
30
31
// request and returns a potentially modified context based on the request
31
32
// content. This can be used to inject context values from headers, for example.
32
33
type SSEContextFunc func (ctx context.Context , r * http.Request ) context.Context
33
34
35
+ // RouteParamsKey is the key type for storing route parameters in context
36
+ type RouteParamsKey struct {}
37
+
38
+ // RouteParams stores path parameters
39
+ type RouteParams map [string ]string
40
+
41
+ // GetRouteParam retrieves a route parameter from context
42
+ func GetRouteParam (ctx context.Context , key string ) string {
43
+ if params , ok := ctx .Value (RouteParamsKey {}).(RouteParams ); ok {
44
+ return params [key ]
45
+ }
46
+ return ""
47
+ }
48
+
49
+ // GetRouteParams retrieves all route parameters from context
50
+ func GetRouteParams (ctx context.Context ) RouteParams {
51
+ if params , ok := ctx .Value (RouteParamsKey {}).(RouteParams ); ok {
52
+ return params
53
+ }
54
+ return RouteParams {}
55
+ }
56
+
34
57
func (s * sseSession ) SessionID () string {
35
58
return s .sessionID
36
59
}
@@ -58,6 +81,7 @@ type SSEServer struct {
58
81
messageEndpoint string
59
82
useFullURLForMessageEndpoint bool
60
83
sseEndpoint string
84
+ ssePattern string
61
85
sessions sync.Map
62
86
srv * http.Server
63
87
contextFunc SSEContextFunc
@@ -123,14 +147,21 @@ func WithSSEEndpoint(endpoint string) SSEOption {
123
147
}
124
148
}
125
149
150
+ // WithSSEPattern sets the SSE endpoint pattern with route parameters
151
+ func WithSSEPattern (pattern string ) SSEOption {
152
+ return func (s * SSEServer ) {
153
+ s .ssePattern = pattern
154
+ }
155
+ }
156
+
126
157
// WithHTTPServer sets the HTTP server instance
127
158
func WithHTTPServer (srv * http.Server ) SSEOption {
128
159
return func (s * SSEServer ) {
129
160
s .srv = srv
130
161
}
131
162
}
132
163
133
- // WithContextFunc sets a function that will be called to customise the context
164
+ // WithSSEContextFunc sets a function that will be called to customise the context
134
165
// to the server using the incoming request.
135
166
func WithSSEContextFunc (fn SSEContextFunc ) SSEOption {
136
167
return func (s * SSEServer ) {
@@ -222,12 +253,21 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
222
253
eventQueue : make (chan string , 100 ), // Buffer for events
223
254
sessionID : sessionID ,
224
255
notificationChannel : make (chan mcp.JSONRPCNotification , 100 ),
256
+ routeParams : GetRouteParams (r .Context ()), // Store route parameters from context
225
257
}
226
258
227
259
s .sessions .Store (sessionID , session )
228
260
defer s .sessions .Delete (sessionID )
229
261
230
- if err := s .server .RegisterSession (r .Context (), session ); err != nil {
262
+ // Create base context with session
263
+ ctx := s .server .WithContext (r .Context (), session )
264
+
265
+ // Apply custom context function if set
266
+ if s .contextFunc != nil {
267
+ ctx = s .contextFunc (ctx , r )
268
+ }
269
+
270
+ if err := s .server .RegisterSession (ctx , session ); err != nil {
231
271
http .Error (w , fmt .Sprintf ("Session registration failed: %v" , err ), http .StatusInternalServerError )
232
272
return
233
273
}
@@ -249,7 +289,7 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
249
289
}
250
290
case <- session .done :
251
291
return
252
- case <- r . Context () .Done ():
292
+ case <- ctx .Done ():
253
293
return
254
294
}
255
295
}
@@ -266,7 +306,7 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
266
306
// Write the event to the response
267
307
fmt .Fprint (w , event )
268
308
flusher .Flush ()
269
- case <- r . Context () .Done ():
309
+ case <- ctx .Done ():
270
310
close (session .done )
271
311
return
272
312
}
@@ -304,8 +344,15 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) {
304
344
}
305
345
session := sessionI .(* sseSession )
306
346
307
- // Set the client context before handling the message
347
+ // Create base context with session
308
348
ctx := s .server .WithContext (r .Context (), session )
349
+
350
+ // Add stored route parameters to context
351
+ if len (session .routeParams ) > 0 {
352
+ ctx = context .WithValue (ctx , RouteParamsKey {}, session .routeParams )
353
+ }
354
+
355
+ // Apply custom context function if set
309
356
if s .contextFunc != nil {
310
357
ctx = s .contextFunc (ctx , r )
311
358
}
@@ -317,7 +364,7 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) {
317
364
return
318
365
}
319
366
320
- // Process message through MCPServer
367
+ // Process message through MCPServer with the context containing route parameters
321
368
response := s .server .HandleMessage (ctx , rawMessage )
322
369
323
370
// Only send response if there is one (not for notifications)
@@ -384,6 +431,7 @@ func (s *SSEServer) SendEventToSession(
384
431
return fmt .Errorf ("event queue full" )
385
432
}
386
433
}
434
+
387
435
func (s * SSEServer ) GetUrlPath (input string ) (string , error ) {
388
436
parse , err := url .Parse (input )
389
437
if err != nil {
@@ -395,6 +443,7 @@ func (s *SSEServer) GetUrlPath(input string) (string, error) {
395
443
func (s * SSEServer ) CompleteSseEndpoint () string {
396
444
return s .baseURL + s .basePath + s .sseEndpoint
397
445
}
446
+
398
447
func (s * SSEServer ) CompleteSsePath () string {
399
448
path , err := s .GetUrlPath (s .CompleteSseEndpoint ())
400
449
if err != nil {
@@ -406,6 +455,7 @@ func (s *SSEServer) CompleteSsePath() string {
406
455
func (s * SSEServer ) CompleteMessageEndpoint () string {
407
456
return s .baseURL + s .basePath + s .messageEndpoint
408
457
}
458
+
409
459
func (s * SSEServer ) CompleteMessagePath () string {
410
460
path , err := s .GetUrlPath (s .CompleteMessageEndpoint ())
411
461
if err != nil {
@@ -417,17 +467,61 @@ func (s *SSEServer) CompleteMessagePath() string {
417
467
// ServeHTTP implements the http.Handler interface.
418
468
func (s * SSEServer ) ServeHTTP (w http.ResponseWriter , r * http.Request ) {
419
469
path := r .URL .Path
420
- // Use exact path matching rather than Contains
421
- ssePath := s .CompleteSsePath ()
422
- if ssePath != "" && path == ssePath {
423
- s .handleSSE (w , r )
424
- return
425
- }
426
470
messagePath := s .CompleteMessagePath ()
471
+
472
+ // Handle message endpoint
427
473
if messagePath != "" && path == messagePath {
428
474
s .handleMessage (w , r )
429
475
return
430
476
}
431
477
478
+ // Handle SSE endpoint with route parameters
479
+ if s .ssePattern != "" {
480
+ // Try pattern matching if pattern is set
481
+ fullPattern := s .basePath + s .ssePattern
482
+ matches , params := matchPath (fullPattern , path )
483
+ if matches {
484
+ // Create new context with route parameters
485
+ ctx := context .WithValue (r .Context (), RouteParamsKey {}, params )
486
+ s .handleSSE (w , r .WithContext (ctx ))
487
+ return
488
+ }
489
+ // If pattern is set but doesn't match, return 404
490
+ http .NotFound (w , r )
491
+ return
492
+ }
493
+
494
+ // If no pattern is set, use the default SSE endpoint
495
+ ssePath := s .CompleteSsePath ()
496
+ if ssePath != "" && path == ssePath {
497
+ s .handleSSE (w , r )
498
+ return
499
+ }
500
+
432
501
http .NotFound (w , r )
433
502
}
503
+
504
+ // matchPath checks if the given path matches the pattern and extracts parameters
505
+ // pattern format: /user/:id/profile/:type
506
+ func matchPath (pattern , path string ) (bool , RouteParams ) {
507
+ patternParts := strings .Split (strings .Trim (pattern , "/" ), "/" )
508
+ pathParts := strings .Split (strings .Trim (path , "/" ), "/" )
509
+
510
+ if len (patternParts ) != len (pathParts ) {
511
+ return false , nil
512
+ }
513
+
514
+ params := make (RouteParams )
515
+ for i , part := range patternParts {
516
+ if strings .HasPrefix (part , ":" ) {
517
+ // This is a parameter
518
+ paramName := strings .TrimPrefix (part , ":" )
519
+ params [paramName ] = pathParts [i ]
520
+ } else if part != pathParts [i ] {
521
+ // Static part doesn't match
522
+ return false , nil
523
+ }
524
+ }
525
+
526
+ return true , params
527
+ }
0 commit comments