Skip to content

Commit 5d2c4e4

Browse files
committed
feat: add context to hooks
feat: new hook: OnRegisterSession
1 parent 051cda5 commit 5d2c4e4

File tree

9 files changed

+197
-158
lines changed

9 files changed

+197
-158
lines changed

examples/everything/main.go

+7-7
Original file line numberDiff line numberDiff line change
@@ -32,25 +32,25 @@ func NewMCPServer() *server.MCPServer {
3232

3333
hooks := &server.Hooks{}
3434

35-
hooks.AddBeforeAny(func(id any, method mcp.MCPMethod, message any) {
35+
hooks.AddBeforeAny(func(ctx context.Context, id any, method mcp.MCPMethod, message any) {
3636
fmt.Printf("beforeAny: %s, %v, %v\n", method, id, message)
3737
})
38-
hooks.AddOnSuccess(func(id any, method mcp.MCPMethod, message any, result any) {
38+
hooks.AddOnSuccess(func(ctx context.Context, id any, method mcp.MCPMethod, message any, result any) {
3939
fmt.Printf("onSuccess: %s, %v, %v, %v\n", method, id, message, result)
4040
})
41-
hooks.AddOnError(func(id any, method mcp.MCPMethod, message any, err error) {
41+
hooks.AddOnError(func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) {
4242
fmt.Printf("onError: %s, %v, %v, %v\n", method, id, message, err)
4343
})
44-
hooks.AddBeforeInitialize(func(id any, message *mcp.InitializeRequest) {
44+
hooks.AddBeforeInitialize(func(ctx context.Context, id any, message *mcp.InitializeRequest) {
4545
fmt.Printf("beforeInitialize: %v, %v\n", id, message)
4646
})
47-
hooks.AddAfterInitialize(func(id any, message *mcp.InitializeRequest, result *mcp.InitializeResult) {
47+
hooks.AddAfterInitialize(func(ctx context.Context, id any, message *mcp.InitializeRequest, result *mcp.InitializeResult) {
4848
fmt.Printf("afterInitialize: %v, %v, %v\n", id, message, result)
4949
})
50-
hooks.AddAfterCallTool(func(id any, message *mcp.CallToolRequest, result *mcp.CallToolResult) {
50+
hooks.AddAfterCallTool(func(ctx context.Context, id any, message *mcp.CallToolRequest, result *mcp.CallToolResult) {
5151
fmt.Printf("afterCallTool: %v, %v, %v\n", id, message, result)
5252
})
53-
hooks.AddBeforeCallTool(func(id any, message *mcp.CallToolRequest) {
53+
hooks.AddBeforeCallTool(func(ctx context.Context, id any, message *mcp.CallToolRequest) {
5454
fmt.Printf("beforeCallTool: %v, %v\n", id, message)
5555
})
5656

server/hooks.go

+102-83
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

server/internal/gen/hooks.go.tmpl

+37-19
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,24 @@ import (
1111
"github.com/mark3labs/mcp-go/mcp"
1212
)
1313

14+
// OnRegisterSessionHookFunc is a hook that will be called when a new session is registered.
15+
type OnRegisterSessionHookFunc func(ctx context.Context, session ClientSession)
16+
17+
1418
// BeforeAnyHookFunc is a function that is called after the request is
1519
// parsed but before the method is called.
16-
type BeforeAnyHookFunc func(id any, method mcp.MCPMethod, message any)
20+
type BeforeAnyHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, message any)
1721

1822
// OnSuccessHookFunc is a hook that will be called after the request
1923
// successfully generates a result, but before the result is sent to the client.
20-
type OnSuccessHookFunc func(id any, method mcp.MCPMethod, message any, result any)
24+
type OnSuccessHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, message any, result any)
2125

2226
// OnErrorHookFunc is a hook that will be called when an error occurs,
2327
// either during the request parsing or the method execution.
2428
//
2529
// Example usage:
2630
// ```
27-
// hooks.AddOnError(func(id any, method mcp.MCPMethod, message any, err error) {
31+
// hooks.AddOnError(func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) {
2832
// // Check for specific error types using errors.Is
2933
// if errors.Is(err, ErrUnsupported) {
3034
// // Handle capability not supported errors
@@ -51,14 +55,15 @@ type OnSuccessHookFunc func(id any, method mcp.MCPMethod, message any, result an
5155
// log.Printf("Tool not found: %v", err)
5256
// }
5357
// })
54-
type OnErrorHookFunc func(id any, method mcp.MCPMethod, message any, err error)
58+
type OnErrorHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error)
5559

5660
{{range .}}
57-
type OnBefore{{.HookName}}Func func(id any, message *mcp.{{.ParamType}})
58-
type OnAfter{{.HookName}}Func func(id any, message *mcp.{{.ParamType}}, result *mcp.{{.ResultType}})
61+
type OnBefore{{.HookName}}Func func(ctx context.Context, id any, message *mcp.{{.ParamType}})
62+
type OnAfter{{.HookName}}Func func(ctx context.Context, id any, message *mcp.{{.ParamType}}, result *mcp.{{.ResultType}})
5963
{{end}}
6064

6165
type Hooks struct {
66+
OnRegisterSession []OnRegisterSessionHookFunc
6267
OnBeforeAny []BeforeAnyHookFunc
6368
OnSuccess []OnSuccessHookFunc
6469
OnError []OnErrorHookFunc
@@ -87,7 +92,7 @@ func (c *Hooks) AddOnSuccess(hook OnSuccessHookFunc) {
8792
//
8893
// // Register hook to capture and inspect errors
8994
// hooks := &Hooks{}
90-
// hooks.AddOnError(func(id any, method mcp.MCPMethod, message any, err error) {
95+
// hooks.AddOnError(func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) {
9196
// // For capability-related errors
9297
// if errors.Is(err, ErrUnsupported) {
9398
// // Handle capability not supported
@@ -124,21 +129,21 @@ func (c *Hooks) AddOnError(hook OnErrorHookFunc) {
124129
c.OnError = append(c.OnError, hook)
125130
}
126131

127-
func (c *Hooks) beforeAny(id any, method mcp.MCPMethod, message any) {
132+
func (c *Hooks) beforeAny(ctx context.Context, id any, method mcp.MCPMethod, message any) {
128133
if c == nil {
129134
return
130135
}
131136
for _, hook := range c.OnBeforeAny {
132-
hook(id, method, message)
137+
hook(ctx, id, method, message)
133138
}
134139
}
135140

136-
func (c *Hooks) onSuccess(id any, method mcp.MCPMethod, message any, result any) {
141+
func (c *Hooks) onSuccess(ctx context.Context, id any, method mcp.MCPMethod, message any, result any) {
137142
if c == nil {
138143
return
139144
}
140145
for _, hook := range c.OnSuccess {
141-
hook(id, method, message, result)
146+
hook(ctx, id, method, message, result)
142147
}
143148
}
144149

@@ -156,15 +161,28 @@ func (c *Hooks) onSuccess(id any, method mcp.MCPMethod, message any, result any)
156161
// - ErrResourceNotFound: When a resource is not found
157162
// - ErrPromptNotFound: When a prompt is not found
158163
// - ErrToolNotFound: When a tool is not found
159-
func (c *Hooks) onError(id any, method mcp.MCPMethod, message any, err error) {
164+
func (c *Hooks) onError(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) {
160165
if c == nil {
161166
return
162167
}
163168
for _, hook := range c.OnError {
164-
hook(id, method, message, err)
169+
hook(ctx, id, method, message, err)
165170
}
166171
}
167172

173+
func (c *Hooks) AddOnRegisterSession(hook OnRegisterSessionHookFunc) {
174+
c.OnRegisterSession = append(c.OnRegisterSession, hook)
175+
}
176+
177+
func (c *Hooks) RegisterSession(ctx context.Context, session ClientSession) {
178+
if c == nil {
179+
return
180+
}
181+
for _, hook := range c.OnRegisterSession {
182+
hook(ctx, session)
183+
}
184+
}
185+
168186
{{- range .}}
169187
func (c *Hooks) AddBefore{{.HookName}}(hook OnBefore{{.HookName}}Func) {
170188
c.OnBefore{{.HookName}} = append(c.OnBefore{{.HookName}}, hook)
@@ -174,23 +192,23 @@ func (c *Hooks) AddAfter{{.HookName}}(hook OnAfter{{.HookName}}Func) {
174192
c.OnAfter{{.HookName}} = append(c.OnAfter{{.HookName}}, hook)
175193
}
176194

177-
func (c *Hooks) before{{.HookName}}(id any, message *mcp.{{.ParamType}}) {
178-
c.beforeAny(id, mcp.{{.MethodName}}, message)
195+
func (c *Hooks) before{{.HookName}}(ctx context.Context, id any, message *mcp.{{.ParamType}}) {
196+
c.beforeAny(ctx, id, mcp.{{.MethodName}}, message)
179197
if c == nil {
180198
return
181199
}
182200
for _, hook := range c.OnBefore{{.HookName}} {
183-
hook(id, message)
201+
hook(ctx, id, message)
184202
}
185203
}
186204

187-
func (c *Hooks) after{{.HookName}}(id any, message *mcp.{{.ParamType}}, result *mcp.{{.ResultType}}) {
188-
c.onSuccess(id, mcp.{{.MethodName}}, message, result)
205+
func (c *Hooks) after{{.HookName}}(ctx context.Context, id any, message *mcp.{{.ParamType}}, result *mcp.{{.ResultType}}) {
206+
c.onSuccess(ctx, id, mcp.{{.MethodName}}, message, result)
189207
if c == nil {
190208
return
191209
}
192210
for _, hook := range c.OnAfter{{.HookName}} {
193-
hook(id, message, result)
211+
hook(ctx, id, message, result)
194212
}
195213
}
196214
{{- end -}}

server/internal/gen/request_handler.go.tmpl

+3-3
Original file line numberDiff line numberDiff line change
@@ -74,14 +74,14 @@ func (s *MCPServer) HandleMessage(
7474
err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method},
7575
}
7676
} else {
77-
s.hooks.before{{.HookName}}(baseMessage.ID, &request)
77+
s.hooks.before{{.HookName}}(ctx, baseMessage.ID, &request)
7878
result, err = s.{{.HandlerFunc}}(ctx, baseMessage.ID, request)
7979
}
8080
if err != nil {
81-
s.hooks.onError(baseMessage.ID, baseMessage.Method, &request, err)
81+
s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err)
8282
return err.ToJSONRPCError()
8383
}
84-
s.hooks.after{{.HookName}}(baseMessage.ID, &request, result)
84+
s.hooks.after{{.HookName}}(ctx, baseMessage.ID, &request, result)
8585
return createResponse(baseMessage.ID, *result)
8686
{{- end }}
8787
default:

server/request_handler.go

+27-27
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

server/server.go

+2
Original file line numberDiff line numberDiff line change
@@ -172,12 +172,14 @@ func (s *MCPServer) WithContext(
172172

173173
// RegisterSession saves session that should be notified in case if some server attributes changed.
174174
func (s *MCPServer) RegisterSession(
175+
ctx context.Context,
175176
session ClientSession,
176177
) error {
177178
sessionID := session.SessionID()
178179
if _, exists := s.sessions.LoadOrStore(sessionID, session); exists {
179180
return fmt.Errorf("session %s is already registered", sessionID)
180181
}
182+
s.hooks.RegisterSession(ctx, session)
181183
return nil
182184
}
183185

0 commit comments

Comments
 (0)