@@ -11,20 +11,24 @@ import (
11
11
"github.com/mark3labs/mcp-go/mcp"
12
12
)
13
13
14
+ // OnRegisterSessionHookFunc is a hook that will be called when a new session is registered.
15
+ type OnRegisterSessionHookFunc func(ctx context.Context, session ClientSession)
16
+
17
+
14
18
// BeforeAnyHookFunc is a function that is called after the request is
15
19
// parsed but before the method is called.
16
- type BeforeAnyHookFunc func(id any, method mcp.MCPMethod, message any)
20
+ type BeforeAnyHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, message any)
17
21
18
22
// OnSuccessHookFunc is a hook that will be called after the request
19
23
// successfully generates a result, but before the result is sent to the client.
20
- type OnSuccessHookFunc func(id any, method mcp.MCPMethod, message any, result any)
24
+ type OnSuccessHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, message any, result any)
21
25
22
26
// OnErrorHookFunc is a hook that will be called when an error occurs,
23
27
// either during the request parsing or the method execution.
24
28
//
25
29
// Example usage:
26
30
// ```
27
- // hooks.AddOnError(func(id any, method mcp.MCPMethod, message any, err error) {
31
+ // hooks.AddOnError(func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) {
28
32
// // Check for specific error types using errors.Is
29
33
// if errors.Is(err, ErrUnsupported) {
30
34
// // Handle capability not supported errors
@@ -51,14 +55,15 @@ type OnSuccessHookFunc func(id any, method mcp.MCPMethod, message any, result an
51
55
// log.Printf("Tool not found: %v", err)
52
56
// }
53
57
// })
54
- type OnErrorHookFunc func(id any, method mcp.MCPMethod, message any, err error)
58
+ type OnErrorHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error)
55
59
56
60
{{range .}}
57
- type OnBefore{{.HookName}}Func func(id any, message *mcp.{{.ParamType}})
58
- type OnAfter{{.HookName}}Func func(id any, message *mcp.{{.ParamType}}, result *mcp.{{.ResultType}})
61
+ type OnBefore{{.HookName}}Func func(ctx context.Context, id any, message *mcp.{{.ParamType}})
62
+ type OnAfter{{.HookName}}Func func(ctx context.Context, id any, message *mcp.{{.ParamType}}, result *mcp.{{.ResultType}})
59
63
{{end}}
60
64
61
65
type Hooks struct {
66
+ OnRegisterSession []OnRegisterSessionHookFunc
62
67
OnBeforeAny []BeforeAnyHookFunc
63
68
OnSuccess []OnSuccessHookFunc
64
69
OnError []OnErrorHookFunc
@@ -87,7 +92,7 @@ func (c *Hooks) AddOnSuccess(hook OnSuccessHookFunc) {
87
92
//
88
93
// // Register hook to capture and inspect errors
89
94
// hooks := &Hooks{}
90
- // hooks.AddOnError(func(id any, method mcp.MCPMethod, message any, err error) {
95
+ // hooks.AddOnError(func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) {
91
96
// // For capability-related errors
92
97
// if errors.Is(err, ErrUnsupported) {
93
98
// // Handle capability not supported
@@ -124,21 +129,21 @@ func (c *Hooks) AddOnError(hook OnErrorHookFunc) {
124
129
c.OnError = append(c.OnError, hook)
125
130
}
126
131
127
- func (c *Hooks) beforeAny(id any, method mcp.MCPMethod, message any) {
132
+ func (c *Hooks) beforeAny(ctx context.Context, id any, method mcp.MCPMethod, message any) {
128
133
if c == nil {
129
134
return
130
135
}
131
136
for _, hook := range c.OnBeforeAny {
132
- hook(id, method, message)
137
+ hook(ctx, id, method, message)
133
138
}
134
139
}
135
140
136
- func (c *Hooks) onSuccess(id any, method mcp.MCPMethod, message any, result any) {
141
+ func (c *Hooks) onSuccess(ctx context.Context, id any, method mcp.MCPMethod, message any, result any) {
137
142
if c == nil {
138
143
return
139
144
}
140
145
for _, hook := range c.OnSuccess {
141
- hook(id, method, message, result)
146
+ hook(ctx, id, method, message, result)
142
147
}
143
148
}
144
149
@@ -156,15 +161,28 @@ func (c *Hooks) onSuccess(id any, method mcp.MCPMethod, message any, result any)
156
161
// - ErrResourceNotFound: When a resource is not found
157
162
// - ErrPromptNotFound: When a prompt is not found
158
163
// - ErrToolNotFound: When a tool is not found
159
- func (c *Hooks) onError(id any, method mcp.MCPMethod, message any, err error) {
164
+ func (c *Hooks) onError(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) {
160
165
if c == nil {
161
166
return
162
167
}
163
168
for _, hook := range c.OnError {
164
- hook(id, method, message, err)
169
+ hook(ctx, id, method, message, err)
165
170
}
166
171
}
167
172
173
+ func (c *Hooks) AddOnRegisterSession(hook OnRegisterSessionHookFunc) {
174
+ c.OnRegisterSession = append(c.OnRegisterSession, hook)
175
+ }
176
+
177
+ func (c *Hooks) RegisterSession(ctx context.Context, session ClientSession) {
178
+ if c == nil {
179
+ return
180
+ }
181
+ for _, hook := range c.OnRegisterSession {
182
+ hook(ctx, session)
183
+ }
184
+ }
185
+
168
186
{{- range .}}
169
187
func (c *Hooks) AddBefore{{.HookName}}(hook OnBefore{{.HookName}}Func) {
170
188
c.OnBefore{{.HookName}} = append(c.OnBefore{{.HookName}}, hook)
@@ -174,23 +192,23 @@ func (c *Hooks) AddAfter{{.HookName}}(hook OnAfter{{.HookName}}Func) {
174
192
c.OnAfter{{.HookName}} = append(c.OnAfter{{.HookName}}, hook)
175
193
}
176
194
177
- func (c *Hooks) before{{.HookName}}(id any, message *mcp.{{.ParamType}}) {
178
- c.beforeAny(id, mcp.{{.MethodName}}, message)
195
+ func (c *Hooks) before{{.HookName}}(ctx context.Context, id any, message *mcp.{{.ParamType}}) {
196
+ c.beforeAny(ctx, id, mcp.{{.MethodName}}, message)
179
197
if c == nil {
180
198
return
181
199
}
182
200
for _, hook := range c.OnBefore{{.HookName}} {
183
- hook(id, message)
201
+ hook(ctx, id, message)
184
202
}
185
203
}
186
204
187
- func (c *Hooks) after{{.HookName}}(id any, message *mcp.{{.ParamType}}, result *mcp.{{.ResultType}}) {
188
- c.onSuccess(id, mcp.{{.MethodName}}, message, result)
205
+ func (c *Hooks) after{{.HookName}}(ctx context.Context, id any, message *mcp.{{.ParamType}}, result *mcp.{{.ResultType}}) {
206
+ c.onSuccess(ctx, id, mcp.{{.MethodName}}, message, result)
189
207
if c == nil {
190
208
return
191
209
}
192
210
for _, hook := range c.OnAfter{{.HookName}} {
193
- hook(id, message, result)
211
+ hook(ctx, id, message, result)
194
212
}
195
213
}
196
214
{{- end -}}
0 commit comments