diff --git a/client.go b/client.go index 672e217..d2dcf39 100644 --- a/client.go +++ b/client.go @@ -17,12 +17,26 @@ type Client struct { initialized bool } +type ClientOptions func(*Client) + +func WithNotificationHandler( + method string, handler func(notification *transport.BaseJSONRPCNotification) error, +) ClientOptions { + return func(c *Client) { + c.protocol.SetNotificationHandler(method, handler) + } +} + // NewClient creates a new MCP client with the specified transport -func NewClient(transport transport.Transport) *Client { - return &Client{ +func NewClient(transport transport.Transport, options ...ClientOptions) *Client { + client := &Client{ transport: transport, protocol: protocol.NewProtocol(nil), } + for _, option := range options { + option(client) + } + return client } // Initialize connects to the server and retrieves its capabilities @@ -248,6 +262,19 @@ func (c *Client) ReadResource(ctx context.Context, uri string) (*ResourceRespons return resourceResponse.Response, nil } +func (c *Client) SetLoggingLevel(ctx context.Context, level Level) error { + if !c.initialized { + return errors.New("client not initialized") + } + params := map[string]Level{ + "level": level, + } + if _, err := c.protocol.Request(ctx, "logging/setLevel", params, nil); err != nil { + return errors.Wrap(err, "failed to set logging level") + } + return nil +} + // Ping sends a ping request to the server to check connectivity func (c *Client) Ping(ctx context.Context) error { if !c.initialized { diff --git a/examples/log_example/stdio/main.go b/examples/log_example/stdio/main.go new file mode 100644 index 0000000..4e23ec4 --- /dev/null +++ b/examples/log_example/stdio/main.go @@ -0,0 +1,76 @@ +package main + +import ( + "context" + "encoding/json" + "fmt" + "log" + "os/exec" + "time" + + mcp "github.com/metoro-io/mcp-golang" + "github.com/metoro-io/mcp-golang/transport" + "github.com/metoro-io/mcp-golang/transport/stdio" +) + +func main() { + // Start the server process + cmd := exec.Command("go", "run", "./server/main.go") + stdin, err := cmd.StdinPipe() + if err != nil { + log.Fatalf("Failed to get stdin pipe: %v", err) + } + stdout, err := cmd.StdoutPipe() + if err != nil { + log.Fatalf("Failed to get stdout pipe: %v", err) + } + + if err := cmd.Start(); err != nil { + log.Fatalf("Failed to start server: %v", err) + } + defer cmd.Process.Kill() + + clientTransport := stdio.NewStdioServerTransportWithIO(stdout, stdin) + client := mcp.NewClient(clientTransport, mcp.WithNotificationHandler("notifications/message", func(notification *transport.BaseJSONRPCNotification) error { + var params struct { + Level string `json:"level" yaml:"level" mapstructure:"level"` + Logger string `json:"logger" yaml:"logger" mapstructure:"logger"` + Data interface{} `json:"data" yaml:"data" mapstructure:"data"` + } + if err := json.Unmarshal(notification.Params, ¶ms); err != nil { + log.Println("failed to unmarshal log_example params:", err.Error()) + return fmt.Errorf("failed to unmarshal log_example params: %w", err) + } + log.Printf("[%s] Notification: %s", params.Level, params.Data) + return nil + })) + + if _, err := client.Initialize(context.Background()); err != nil { + log.Fatalf("Failed to initialize client: %v", err) + } + + for _, level := range []mcp.Level{ + mcp.LevelDebug, + mcp.LevelInfo, + mcp.LevelNotice, + mcp.LevelWarning, + mcp.LevelError, + mcp.LevelCritical, + mcp.LevelAlert, + mcp.LevelEmergency, + } { + if err := client.SetLoggingLevel(context.Background(), level); err != nil { + log.Fatalf("Failed to set logging level: %v", err) + } + args := map[string]interface{}{ + "name": "World", + } + _, err := client.CallTool(context.Background(), "log", args) + if err != nil { + log.Printf("Failed to call log tool: %v", err) + } + // wait all notifications arrive + time.Sleep(3 * time.Second) + log.Println("----------------------------------") + } +} diff --git a/examples/log_example/stdio/server/main.go b/examples/log_example/stdio/server/main.go new file mode 100644 index 0000000..97a32eb --- /dev/null +++ b/examples/log_example/stdio/server/main.go @@ -0,0 +1,66 @@ +package main + +import ( + "log" + "os" + + mcp "github.com/metoro-io/mcp-golang" + "github.com/metoro-io/mcp-golang/transport/stdio" +) + +func init() { + logFile, _ := os.OpenFile("./server.log", os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666) + log.SetOutput(logFile) +} + +// HelloArgs ... +type HelloArgs struct { + Name string `json:"name" jsonschema:"required,description=The name to say hello to"` +} + +func main() { + // Create a transport for the server + serverTransport := stdio.NewStdioServerTransport() + + // Create a new server with the transport + server := mcp.NewServer(serverTransport, mcp.WithLoggingCapability()) + + if err := server.RegisterTool("log", "get some log", func(k HelloArgs) (*mcp.ToolResponse, error) { + if err := server.SendLogMessageNotification(mcp.LevelDebug, "server", "debug"); err != nil { + log.Panic(err) + } + if err := server.SendLogMessageNotification(mcp.LevelInfo, "server", "info"); err != nil { + log.Panic(err) + } + if err := server.SendLogMessageNotification(mcp.LevelNotice, "server", "notice"); err != nil { + log.Panic(err) + } + if err := server.SendLogMessageNotification(mcp.LevelWarning, "server", "warning"); err != nil { + log.Panic(err) + } + if err := server.SendLogMessageNotification(mcp.LevelError, "server", "error"); err != nil { + log.Panic(err) + } + if err := server.SendLogMessageNotification(mcp.LevelCritical, "server", "critical"); err != nil { + log.Panic(err) + } + if err := server.SendLogMessageNotification(mcp.LevelAlert, "server", "alert"); err != nil { + log.Panic(err) + } + if err := server.SendLogMessageNotification(mcp.LevelEmergency, "server", "emergency"); err != nil { + log.Panic(err) + } + return &mcp.ToolResponse{}, nil + }); err != nil { + log.Panic(err) + } + + // Start the server + if err := server.Serve(); err != nil { + log.Printf("failed to serve: %v\n", err) + panic(err) + } + log.Println("server running") + // Keep the server running + select {} +} diff --git a/go.mod b/go.mod index 9e12af5..0a41953 100644 --- a/go.mod +++ b/go.mod @@ -32,6 +32,8 @@ require ( github.com/tidwall/pretty v1.2.1 // indirect github.com/ugorji/go/codec v1.2.7 // indirect github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect + go.uber.org/multierr v1.10.0 // indirect + go.uber.org/zap v1.27.0 // indirect golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 // indirect golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 // indirect golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069 // indirect diff --git a/go.sum b/go.sum index 309c382..1d3a58a 100644 --- a/go.sum +++ b/go.sum @@ -76,6 +76,8 @@ github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0 github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY= github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= +go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ= +go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= diff --git a/internal/protocol/protocol.go b/internal/protocol/protocol.go index 3172ac4..270ed3b 100644 --- a/internal/protocol/protocol.go +++ b/internal/protocol/protocol.go @@ -69,6 +69,7 @@ import ( "context" "encoding/json" "fmt" + "log" "sync" "time" @@ -158,6 +159,7 @@ func NewProtocol(options *ProtocolOptions) *Protocol { // Set up default handlers p.SetNotificationHandler("notifications/cancelled", p.handleCancelledNotification) + p.SetNotificationHandler("notifications/message", p.messageNotificationHandler) p.SetNotificationHandler("$/progress", p.handleProgressNotification) return p @@ -297,6 +299,20 @@ func (p *Protocol) handleRequest(ctx context.Context, request *transport.BaseJSO }() } +func (p *Protocol) messageNotificationHandler(notification *transport.BaseJSONRPCNotification) error { + var params struct { + Level string `json:"level" yaml:"level" mapstructure:"level"` + Logger string `json:"logger" yaml:"logger" mapstructure:"logger"` + Data interface{} `json:"data" yaml:"data" mapstructure:"data"` + } + if err := json.Unmarshal(notification.Params, ¶ms); err != nil { + log.Println("failed to unmarshal log_example params:", err.Error()) + return fmt.Errorf("failed to unmarshal log_example params: %w", err) + } + log.Printf("[%s] %s\n", params.Level, params.Data.(string)) + return nil +} + func (p *Protocol) handleProgressNotification(notification *transport.BaseJSONRPCNotification) error { var params struct { Progress int64 `json:"progress"` diff --git a/server.go b/server.go index 84fb08b..28760e1 100644 --- a/server.go +++ b/server.go @@ -8,6 +8,7 @@ import ( "reflect" "sort" "strings" + "sync" "github.com/invopop/jsonschema" "github.com/metoro-io/mcp-golang/internal/datastructures" @@ -100,6 +101,7 @@ func (c promptResponseSent) MarshalJSON() ([]byte, error) { } type Server struct { + mu sync.RWMutex isRunning bool transport transport.Transport protocol *protocol.Protocol @@ -110,6 +112,8 @@ type Server struct { serverInstructions *string serverName string serverVersion string + capabilities ServerCapabilities + loggingLevel Level } type prompt struct { @@ -142,6 +146,14 @@ func WithProtocol(protocol *protocol.Protocol) ServerOptions { } } +func WithLoggingCapability() ServerOptions { + return func(s *Server) { + s.capabilities.Logging = map[string]interface{}{ + "enable": true, + } + } +} + // Beware: As of 2024-12-13, it looks like Claude does not support pagination yet func WithPaginationLimit(limit int) ServerOptions { return func(s *Server) { @@ -175,6 +187,21 @@ func NewServer(transport transport.Transport, options ...ServerOptions) *Server return server } +func (s *Server) SendLogMessageNotification(level Level, logger string, data interface{}) error { + if !s.isRunning { + return nil + } + s.mu.RLock() + defer s.mu.RUnlock() + if level < s.loggingLevel { + return nil + } + if err := s.protocol.Notification("notifications/message", newLoggingMessageParams(level, logger, data)); err != nil { + return err + } + return nil +} + // RegisterTool registers a new tool with the server func (s *Server) RegisterTool(name string, description string, handler any) error { err := validateToolHandler(handler) @@ -554,6 +581,7 @@ func (s *Server) Serve() error { pr.SetRequestHandler("prompts/get", s.handlePromptCalls) pr.SetRequestHandler("resources/list", s.handleListResources) pr.SetRequestHandler("resources/read", s.handleResourceCalls) + pr.SetRequestHandler("logging/setLevel", s.handleSetLoggingLevel) err := pr.Connect(s.transport) if err != nil { return err @@ -651,6 +679,28 @@ func (s *Server) handleListTools(ctx context.Context, request *transport.BaseJSO }, nil } +// handleSetLoggingLevel sets the logging level for server +func (s *Server) handleSetLoggingLevel(ctx context.Context, request *transport.BaseJSONRPCRequest, _ protocol.RequestHandlerExtra) (transport.JsonRpcBody, error) { + type setLoggingLevelParams struct { + Level Level `json:"level"` + } + var params setLoggingLevelParams + if request.Params == nil { + params = setLoggingLevelParams{} + } else { + if err := json.Unmarshal(request.Params, ¶ms); err != nil { + return nil, errors.Wrap(err, "Invalid params") + } + } + + s.mu.Lock() + defer s.mu.Unlock() + if params.Level != LevelNil { + s.loggingLevel = params.Level + } + return struct{}{}, nil +} + func (s *Server) handleToolCalls(ctx context.Context, req *transport.BaseJSONRPCRequest, _ protocol.RequestHandlerExtra) (transport.JsonRpcBody, error) { params := baseCallToolRequestParams{} // Instantiate a struct of the type of the arguments @@ -691,6 +741,7 @@ func (s *Server) generateCapabilities() ServerCapabilities { ListChanged: &t, } }(), + Logging: s.capabilities.Logging, } } func (s *Server) handleListPrompts(ctx context.Context, request *transport.BaseJSONRPCRequest, extra protocol.RequestHandlerExtra) (transport.JsonRpcBody, error) { diff --git a/transport/types.go b/transport/types.go index d2627f8..fbdfd91 100644 --- a/transport/types.go +++ b/transport/types.go @@ -98,9 +98,10 @@ type BaseJSONRPCNotification struct { // Requires a Jsonrpc and Method func (m *BaseJSONRPCNotification) UnmarshalJSON(data []byte) error { required := struct { - Jsonrpc *string `json:"jsonrpc" yaml:"jsonrpc" mapstructure:"jsonrpc"` - Method *string `json:"method" yaml:"method" mapstructure:"method"` - Id *int64 `json:"id" yaml:"id" mapstructure:"id"` + Jsonrpc *string `json:"jsonrpc" yaml:"jsonrpc" mapstructure:"jsonrpc"` + Method *string `json:"method" yaml:"method" mapstructure:"method"` + Id *int64 `json:"id" yaml:"id" mapstructure:"id"` + Params *json.RawMessage `json:"params" yaml:"params" mapstructure:"params"` }{} err := json.Unmarshal(data, &required) if err != nil { @@ -115,8 +116,12 @@ func (m *BaseJSONRPCNotification) UnmarshalJSON(data []byte) error { if required.Id != nil { return errors.New("field id in BaseJSONRPCNotification: not allowed") } + if required.Params == nil { + required.Params = new(json.RawMessage) + } m.Jsonrpc = *required.Jsonrpc m.Method = *required.Method + m.Params = *required.Params return nil } diff --git a/utilities_api.go b/utilities_api.go new file mode 100644 index 0000000..b27d6fe --- /dev/null +++ b/utilities_api.go @@ -0,0 +1,55 @@ +package mcp_golang + +type Level int + +// level2str maps the level integer to the level string +var level2str = map[Level]string{ + LevelNil: "Nil", + LevelDebug: "Debug", + LevelInfo: "Info", + LevelNotice: "Notice", + LevelWarning: "Warning", + LevelError: "Error", + LevelCritical: "Critical", + LevelAlert: "Alert", + LevelEmergency: "Emergency", +} + +// str2Level maps the level string to the level integer +var str2Level = map[string]Level{ + "Nil": LevelNil, + "Debug": LevelDebug, + "Info": LevelInfo, + "Notice": LevelNotice, + "Warning": LevelWarning, + "Error": LevelError, + "Critical": LevelCritical, + "Alert": LevelAlert, + "Emergency": LevelEmergency, +} + +const ( + LevelNil Level = iota + LevelDebug + LevelInfo + LevelNotice + LevelWarning + LevelError + LevelCritical + LevelAlert + LevelEmergency +) + +type LoggingMessageParams struct { + Level string `json:"level" yaml:"level" mapstructure:"level"` + Logger string `json:"logger" yaml:"logger" mapstructure:"logger"` + Data interface{} `json:"data" yaml:"data" mapstructure:"data"` +} + +func newLoggingMessageParams(level Level, logger string, data interface{}) LoggingMessageParams { + return LoggingMessageParams{ + Level: level2str[level], + Logger: logger, + Data: data, + } +}