Skip to content

Feature/proto logging implementation #95

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,26 @@ type Client struct {
initialized bool
}

type ClientOptions func(*Client)

func WithNotificationHandler(
method string, handler func(notification *transport.BaseJSONRPCNotification) error,
) ClientOptions {
return func(c *Client) {
c.protocol.SetNotificationHandler(method, handler)
}
}

// NewClient creates a new MCP client with the specified transport
func NewClient(transport transport.Transport) *Client {
return &Client{
func NewClient(transport transport.Transport, options ...ClientOptions) *Client {
client := &Client{
transport: transport,
protocol: protocol.NewProtocol(nil),
}
for _, option := range options {
option(client)
}
return client
}

// Initialize connects to the server and retrieves its capabilities
Expand Down Expand Up @@ -248,6 +262,19 @@ func (c *Client) ReadResource(ctx context.Context, uri string) (*ResourceRespons
return resourceResponse.Response, nil
}

func (c *Client) SetLoggingLevel(ctx context.Context, level Level) error {
if !c.initialized {
return errors.New("client not initialized")
}
params := map[string]Level{
"level": level,
}
if _, err := c.protocol.Request(ctx, "logging/setLevel", params, nil); err != nil {
return errors.Wrap(err, "failed to set logging level")
}
return nil
}

// Ping sends a ping request to the server to check connectivity
func (c *Client) Ping(ctx context.Context) error {
if !c.initialized {
Expand Down
76 changes: 76 additions & 0 deletions examples/log_example/stdio/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package main

import (
"context"
"encoding/json"
"fmt"
"log"
"os/exec"
"time"

mcp "github.com/metoro-io/mcp-golang"
"github.com/metoro-io/mcp-golang/transport"
"github.com/metoro-io/mcp-golang/transport/stdio"
)

func main() {
// Start the server process
cmd := exec.Command("go", "run", "./server/main.go")
stdin, err := cmd.StdinPipe()
if err != nil {
log.Fatalf("Failed to get stdin pipe: %v", err)
}
stdout, err := cmd.StdoutPipe()
if err != nil {
log.Fatalf("Failed to get stdout pipe: %v", err)
}

if err := cmd.Start(); err != nil {
log.Fatalf("Failed to start server: %v", err)
}
defer cmd.Process.Kill()

clientTransport := stdio.NewStdioServerTransportWithIO(stdout, stdin)
client := mcp.NewClient(clientTransport, mcp.WithNotificationHandler("notifications/message", func(notification *transport.BaseJSONRPCNotification) error {
var params struct {
Level string `json:"level" yaml:"level" mapstructure:"level"`
Logger string `json:"logger" yaml:"logger" mapstructure:"logger"`
Data interface{} `json:"data" yaml:"data" mapstructure:"data"`
}
if err := json.Unmarshal(notification.Params, &params); err != nil {
log.Println("failed to unmarshal log_example params:", err.Error())
return fmt.Errorf("failed to unmarshal log_example params: %w", err)
}
log.Printf("[%s] Notification: %s", params.Level, params.Data)
return nil
}))

if _, err := client.Initialize(context.Background()); err != nil {
log.Fatalf("Failed to initialize client: %v", err)
}

for _, level := range []mcp.Level{
mcp.LevelDebug,
mcp.LevelInfo,
mcp.LevelNotice,
mcp.LevelWarning,
mcp.LevelError,
mcp.LevelCritical,
mcp.LevelAlert,
mcp.LevelEmergency,
} {
if err := client.SetLoggingLevel(context.Background(), level); err != nil {
log.Fatalf("Failed to set logging level: %v", err)
}
args := map[string]interface{}{
"name": "World",
}
_, err := client.CallTool(context.Background(), "log", args)
if err != nil {
log.Printf("Failed to call log tool: %v", err)
}
// wait all notifications arrive
time.Sleep(3 * time.Second)
log.Println("----------------------------------")
}
}
66 changes: 66 additions & 0 deletions examples/log_example/stdio/server/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package main

import (
"log"
"os"

mcp "github.com/metoro-io/mcp-golang"
"github.com/metoro-io/mcp-golang/transport/stdio"
)

func init() {
logFile, _ := os.OpenFile("./server.log", os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666)
log.SetOutput(logFile)
}

// HelloArgs ...
type HelloArgs struct {
Name string `json:"name" jsonschema:"required,description=The name to say hello to"`
}

func main() {
// Create a transport for the server
serverTransport := stdio.NewStdioServerTransport()

// Create a new server with the transport
server := mcp.NewServer(serverTransport, mcp.WithLoggingCapability())

if err := server.RegisterTool("log", "get some log", func(k HelloArgs) (*mcp.ToolResponse, error) {
if err := server.SendLogMessageNotification(mcp.LevelDebug, "server", "debug"); err != nil {
log.Panic(err)
}
if err := server.SendLogMessageNotification(mcp.LevelInfo, "server", "info"); err != nil {
log.Panic(err)
}
if err := server.SendLogMessageNotification(mcp.LevelNotice, "server", "notice"); err != nil {
log.Panic(err)
}
if err := server.SendLogMessageNotification(mcp.LevelWarning, "server", "warning"); err != nil {
log.Panic(err)
}
if err := server.SendLogMessageNotification(mcp.LevelError, "server", "error"); err != nil {
log.Panic(err)
}
if err := server.SendLogMessageNotification(mcp.LevelCritical, "server", "critical"); err != nil {
log.Panic(err)
}
if err := server.SendLogMessageNotification(mcp.LevelAlert, "server", "alert"); err != nil {
log.Panic(err)
}
if err := server.SendLogMessageNotification(mcp.LevelEmergency, "server", "emergency"); err != nil {
log.Panic(err)
}
return &mcp.ToolResponse{}, nil
}); err != nil {
log.Panic(err)
}

// Start the server
if err := server.Serve(); err != nil {
log.Printf("failed to serve: %v\n", err)
panic(err)
}
log.Println("server running")
// Keep the server running
select {}
}
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ require (
github.com/tidwall/pretty v1.2.1 // indirect
github.com/ugorji/go/codec v1.2.7 // indirect
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
go.uber.org/multierr v1.10.0 // indirect
go.uber.org/zap v1.27.0 // indirect
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 // indirect
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 // indirect
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0
github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY=
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ=
go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
Expand Down
16 changes: 16 additions & 0 deletions internal/protocol/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ import (
"context"
"encoding/json"
"fmt"
"log"
"sync"
"time"

Expand Down Expand Up @@ -158,6 +159,7 @@ func NewProtocol(options *ProtocolOptions) *Protocol {

// Set up default handlers
p.SetNotificationHandler("notifications/cancelled", p.handleCancelledNotification)
p.SetNotificationHandler("notifications/message", p.messageNotificationHandler)
p.SetNotificationHandler("$/progress", p.handleProgressNotification)

return p
Expand Down Expand Up @@ -297,6 +299,20 @@ func (p *Protocol) handleRequest(ctx context.Context, request *transport.BaseJSO
}()
}

func (p *Protocol) messageNotificationHandler(notification *transport.BaseJSONRPCNotification) error {
var params struct {
Level string `json:"level" yaml:"level" mapstructure:"level"`
Logger string `json:"logger" yaml:"logger" mapstructure:"logger"`
Data interface{} `json:"data" yaml:"data" mapstructure:"data"`
}
if err := json.Unmarshal(notification.Params, &params); err != nil {
log.Println("failed to unmarshal log_example params:", err.Error())
return fmt.Errorf("failed to unmarshal log_example params: %w", err)
}
log.Printf("[%s] %s\n", params.Level, params.Data.(string))
return nil
}

func (p *Protocol) handleProgressNotification(notification *transport.BaseJSONRPCNotification) error {
var params struct {
Progress int64 `json:"progress"`
Expand Down
51 changes: 51 additions & 0 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"reflect"
"sort"
"strings"
"sync"

"github.com/invopop/jsonschema"
"github.com/metoro-io/mcp-golang/internal/datastructures"
Expand Down Expand Up @@ -100,6 +101,7 @@ func (c promptResponseSent) MarshalJSON() ([]byte, error) {
}

type Server struct {
mu sync.RWMutex
isRunning bool
transport transport.Transport
protocol *protocol.Protocol
Expand All @@ -110,6 +112,8 @@ type Server struct {
serverInstructions *string
serverName string
serverVersion string
capabilities ServerCapabilities
loggingLevel Level
}

type prompt struct {
Expand Down Expand Up @@ -142,6 +146,14 @@ func WithProtocol(protocol *protocol.Protocol) ServerOptions {
}
}

func WithLoggingCapability() ServerOptions {
return func(s *Server) {
s.capabilities.Logging = map[string]interface{}{
"enable": true,
}
}
}

// Beware: As of 2024-12-13, it looks like Claude does not support pagination yet
func WithPaginationLimit(limit int) ServerOptions {
return func(s *Server) {
Expand Down Expand Up @@ -175,6 +187,21 @@ func NewServer(transport transport.Transport, options ...ServerOptions) *Server
return server
}

func (s *Server) SendLogMessageNotification(level Level, logger string, data interface{}) error {
if !s.isRunning {
return nil
}
s.mu.RLock()
defer s.mu.RUnlock()
if level < s.loggingLevel {
return nil
}
if err := s.protocol.Notification("notifications/message", newLoggingMessageParams(level, logger, data)); err != nil {
return err
}
return nil
}

// RegisterTool registers a new tool with the server
func (s *Server) RegisterTool(name string, description string, handler any) error {
err := validateToolHandler(handler)
Expand Down Expand Up @@ -554,6 +581,7 @@ func (s *Server) Serve() error {
pr.SetRequestHandler("prompts/get", s.handlePromptCalls)
pr.SetRequestHandler("resources/list", s.handleListResources)
pr.SetRequestHandler("resources/read", s.handleResourceCalls)
pr.SetRequestHandler("logging/setLevel", s.handleSetLoggingLevel)
err := pr.Connect(s.transport)
if err != nil {
return err
Expand Down Expand Up @@ -651,6 +679,28 @@ func (s *Server) handleListTools(ctx context.Context, request *transport.BaseJSO
}, nil
}

// handleSetLoggingLevel sets the logging level for server
func (s *Server) handleSetLoggingLevel(ctx context.Context, request *transport.BaseJSONRPCRequest, _ protocol.RequestHandlerExtra) (transport.JsonRpcBody, error) {
type setLoggingLevelParams struct {
Level Level `json:"level"`
}
var params setLoggingLevelParams
if request.Params == nil {
params = setLoggingLevelParams{}
} else {
if err := json.Unmarshal(request.Params, &params); err != nil {
return nil, errors.Wrap(err, "Invalid params")
}
}

s.mu.Lock()
defer s.mu.Unlock()
if params.Level != LevelNil {
s.loggingLevel = params.Level
}
return struct{}{}, nil
}

func (s *Server) handleToolCalls(ctx context.Context, req *transport.BaseJSONRPCRequest, _ protocol.RequestHandlerExtra) (transport.JsonRpcBody, error) {
params := baseCallToolRequestParams{}
// Instantiate a struct of the type of the arguments
Expand Down Expand Up @@ -691,6 +741,7 @@ func (s *Server) generateCapabilities() ServerCapabilities {
ListChanged: &t,
}
}(),
Logging: s.capabilities.Logging,
}
}
func (s *Server) handleListPrompts(ctx context.Context, request *transport.BaseJSONRPCRequest, extra protocol.RequestHandlerExtra) (transport.JsonRpcBody, error) {
Expand Down
11 changes: 8 additions & 3 deletions transport/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,10 @@ type BaseJSONRPCNotification struct {
// Requires a Jsonrpc and Method
func (m *BaseJSONRPCNotification) UnmarshalJSON(data []byte) error {
required := struct {
Jsonrpc *string `json:"jsonrpc" yaml:"jsonrpc" mapstructure:"jsonrpc"`
Method *string `json:"method" yaml:"method" mapstructure:"method"`
Id *int64 `json:"id" yaml:"id" mapstructure:"id"`
Jsonrpc *string `json:"jsonrpc" yaml:"jsonrpc" mapstructure:"jsonrpc"`
Method *string `json:"method" yaml:"method" mapstructure:"method"`
Id *int64 `json:"id" yaml:"id" mapstructure:"id"`
Params *json.RawMessage `json:"params" yaml:"params" mapstructure:"params"`
}{}
err := json.Unmarshal(data, &required)
if err != nil {
Expand All @@ -115,8 +116,12 @@ func (m *BaseJSONRPCNotification) UnmarshalJSON(data []byte) error {
if required.Id != nil {
return errors.New("field id in BaseJSONRPCNotification: not allowed")
}
if required.Params == nil {
required.Params = new(json.RawMessage)
}
m.Jsonrpc = *required.Jsonrpc
m.Method = *required.Method
m.Params = *required.Params
return nil
}

Expand Down
Loading