Skip to content

make context available in hooks, add OnRegisterSession hook #92

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions examples/everything/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,25 +32,25 @@ func NewMCPServer() *server.MCPServer {

hooks := &server.Hooks{}

hooks.AddBeforeAny(func(id any, method mcp.MCPMethod, message any) {
hooks.AddBeforeAny(func(ctx context.Context, id any, method mcp.MCPMethod, message any) {
fmt.Printf("beforeAny: %s, %v, %v\n", method, id, message)
})
hooks.AddOnSuccess(func(id any, method mcp.MCPMethod, message any, result any) {
hooks.AddOnSuccess(func(ctx context.Context, id any, method mcp.MCPMethod, message any, result any) {
fmt.Printf("onSuccess: %s, %v, %v, %v\n", method, id, message, result)
})
hooks.AddOnError(func(id any, method mcp.MCPMethod, message any, err error) {
hooks.AddOnError(func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) {
fmt.Printf("onError: %s, %v, %v, %v\n", method, id, message, err)
})
hooks.AddBeforeInitialize(func(id any, message *mcp.InitializeRequest) {
hooks.AddBeforeInitialize(func(ctx context.Context, id any, message *mcp.InitializeRequest) {
fmt.Printf("beforeInitialize: %v, %v\n", id, message)
})
hooks.AddAfterInitialize(func(id any, message *mcp.InitializeRequest, result *mcp.InitializeResult) {
hooks.AddAfterInitialize(func(ctx context.Context, id any, message *mcp.InitializeRequest, result *mcp.InitializeResult) {
fmt.Printf("afterInitialize: %v, %v, %v\n", id, message, result)
})
hooks.AddAfterCallTool(func(id any, message *mcp.CallToolRequest, result *mcp.CallToolResult) {
hooks.AddAfterCallTool(func(ctx context.Context, id any, message *mcp.CallToolRequest, result *mcp.CallToolResult) {
fmt.Printf("afterCallTool: %v, %v, %v\n", id, message, result)
})
hooks.AddBeforeCallTool(func(id any, message *mcp.CallToolRequest) {
hooks.AddBeforeCallTool(func(ctx context.Context, id any, message *mcp.CallToolRequest) {
fmt.Printf("beforeCallTool: %v, %v\n", id, message)
})

Expand Down
185 changes: 102 additions & 83 deletions server/hooks.go

Large diffs are not rendered by default.

56 changes: 37 additions & 19 deletions server/internal/gen/hooks.go.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,24 @@ import (
"github.com/mark3labs/mcp-go/mcp"
)

// OnRegisterSessionHookFunc is a hook that will be called when a new session is registered.
type OnRegisterSessionHookFunc func(ctx context.Context, session ClientSession)


// BeforeAnyHookFunc is a function that is called after the request is
// parsed but before the method is called.
type BeforeAnyHookFunc func(id any, method mcp.MCPMethod, message any)
type BeforeAnyHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, message any)

// OnSuccessHookFunc is a hook that will be called after the request
// successfully generates a result, but before the result is sent to the client.
type OnSuccessHookFunc func(id any, method mcp.MCPMethod, message any, result any)
type OnSuccessHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, message any, result any)

// OnErrorHookFunc is a hook that will be called when an error occurs,
// either during the request parsing or the method execution.
//
// Example usage:
// ```
// hooks.AddOnError(func(id any, method mcp.MCPMethod, message any, err error) {
// hooks.AddOnError(func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) {
// // Check for specific error types using errors.Is
// if errors.Is(err, ErrUnsupported) {
// // Handle capability not supported errors
Expand All @@ -51,14 +55,15 @@ type OnSuccessHookFunc func(id any, method mcp.MCPMethod, message any, result an
// log.Printf("Tool not found: %v", err)
// }
// })
type OnErrorHookFunc func(id any, method mcp.MCPMethod, message any, err error)
type OnErrorHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error)

{{range .}}
type OnBefore{{.HookName}}Func func(id any, message *mcp.{{.ParamType}})
type OnAfter{{.HookName}}Func func(id any, message *mcp.{{.ParamType}}, result *mcp.{{.ResultType}})
type OnBefore{{.HookName}}Func func(ctx context.Context, id any, message *mcp.{{.ParamType}})
type OnAfter{{.HookName}}Func func(ctx context.Context, id any, message *mcp.{{.ParamType}}, result *mcp.{{.ResultType}})
{{end}}

type Hooks struct {
OnRegisterSession []OnRegisterSessionHookFunc
OnBeforeAny []BeforeAnyHookFunc
OnSuccess []OnSuccessHookFunc
OnError []OnErrorHookFunc
Expand Down Expand Up @@ -87,7 +92,7 @@ func (c *Hooks) AddOnSuccess(hook OnSuccessHookFunc) {
//
// // Register hook to capture and inspect errors
// hooks := &Hooks{}
// hooks.AddOnError(func(id any, method mcp.MCPMethod, message any, err error) {
// hooks.AddOnError(func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) {
// // For capability-related errors
// if errors.Is(err, ErrUnsupported) {
// // Handle capability not supported
Expand Down Expand Up @@ -124,21 +129,21 @@ func (c *Hooks) AddOnError(hook OnErrorHookFunc) {
c.OnError = append(c.OnError, hook)
}

func (c *Hooks) beforeAny(id any, method mcp.MCPMethod, message any) {
func (c *Hooks) beforeAny(ctx context.Context, id any, method mcp.MCPMethod, message any) {
if c == nil {
return
}
for _, hook := range c.OnBeforeAny {
hook(id, method, message)
hook(ctx, id, method, message)
}
}

func (c *Hooks) onSuccess(id any, method mcp.MCPMethod, message any, result any) {
func (c *Hooks) onSuccess(ctx context.Context, id any, method mcp.MCPMethod, message any, result any) {
if c == nil {
return
}
for _, hook := range c.OnSuccess {
hook(id, method, message, result)
hook(ctx, id, method, message, result)
}
}

Expand All @@ -156,15 +161,28 @@ func (c *Hooks) onSuccess(id any, method mcp.MCPMethod, message any, result any)
// - ErrResourceNotFound: When a resource is not found
// - ErrPromptNotFound: When a prompt is not found
// - ErrToolNotFound: When a tool is not found
func (c *Hooks) onError(id any, method mcp.MCPMethod, message any, err error) {
func (c *Hooks) onError(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) {
if c == nil {
return
}
for _, hook := range c.OnError {
hook(id, method, message, err)
hook(ctx, id, method, message, err)
}
}

func (c *Hooks) AddOnRegisterSession(hook OnRegisterSessionHookFunc) {
c.OnRegisterSession = append(c.OnRegisterSession, hook)
}

func (c *Hooks) RegisterSession(ctx context.Context, session ClientSession) {
if c == nil {
return
}
for _, hook := range c.OnRegisterSession {
hook(ctx, session)
}
}

{{- range .}}
func (c *Hooks) AddBefore{{.HookName}}(hook OnBefore{{.HookName}}Func) {
c.OnBefore{{.HookName}} = append(c.OnBefore{{.HookName}}, hook)
Expand All @@ -174,23 +192,23 @@ func (c *Hooks) AddAfter{{.HookName}}(hook OnAfter{{.HookName}}Func) {
c.OnAfter{{.HookName}} = append(c.OnAfter{{.HookName}}, hook)
}

func (c *Hooks) before{{.HookName}}(id any, message *mcp.{{.ParamType}}) {
c.beforeAny(id, mcp.{{.MethodName}}, message)
func (c *Hooks) before{{.HookName}}(ctx context.Context, id any, message *mcp.{{.ParamType}}) {
c.beforeAny(ctx, id, mcp.{{.MethodName}}, message)
if c == nil {
return
}
for _, hook := range c.OnBefore{{.HookName}} {
hook(id, message)
hook(ctx, id, message)
}
}

func (c *Hooks) after{{.HookName}}(id any, message *mcp.{{.ParamType}}, result *mcp.{{.ResultType}}) {
c.onSuccess(id, mcp.{{.MethodName}}, message, result)
func (c *Hooks) after{{.HookName}}(ctx context.Context, id any, message *mcp.{{.ParamType}}, result *mcp.{{.ResultType}}) {
c.onSuccess(ctx, id, mcp.{{.MethodName}}, message, result)
if c == nil {
return
}
for _, hook := range c.OnAfter{{.HookName}} {
hook(id, message, result)
hook(ctx, id, message, result)
}
}
{{- end -}}
6 changes: 3 additions & 3 deletions server/internal/gen/request_handler.go.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,14 @@ func (s *MCPServer) HandleMessage(
err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method},
}
} else {
s.hooks.before{{.HookName}}(baseMessage.ID, &request)
s.hooks.before{{.HookName}}(ctx, baseMessage.ID, &request)
result, err = s.{{.HandlerFunc}}(ctx, baseMessage.ID, request)
}
if err != nil {
s.hooks.onError(baseMessage.ID, baseMessage.Method, &request, err)
s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err)
return err.ToJSONRPCError()
}
s.hooks.after{{.HookName}}(baseMessage.ID, &request, result)
s.hooks.after{{.HookName}}(ctx, baseMessage.ID, &request, result)
return createResponse(baseMessage.ID, *result)
{{- end }}
default:
Expand Down
54 changes: 27 additions & 27 deletions server/request_handler.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,12 +172,14 @@ func (s *MCPServer) WithContext(

// RegisterSession saves session that should be notified in case if some server attributes changed.
func (s *MCPServer) RegisterSession(
ctx context.Context,
session ClientSession,
) error {
sessionID := session.SessionID()
if _, exists := s.sessions.LoadOrStore(sessionID, session); exists {
return fmt.Errorf("session %s is already registered", sessionID)
}
s.hooks.RegisterSession(ctx, session)
return nil
}

Expand Down
Loading