diff --git a/README.md b/README.md index 74f4ec5..9590d80 100644 --- a/README.md +++ b/README.md @@ -4,16 +4,16 @@
-![GitHub stars](https://img.shields.io/github/stars/metoro-io/mcp-golang?style=social) -![GitHub forks](https://img.shields.io/github/forks/metoro-io/mcp-golang?style=social) -![GitHub issues](https://img.shields.io/github/issues/metoro-io/mcp-golang) -![GitHub pull requests](https://img.shields.io/github/issues-pr/metoro-io/mcp-golang) -![GitHub license](https://img.shields.io/github/license/metoro-io/mcp-golang) -![GitHub contributors](https://img.shields.io/github/contributors/metoro-io/mcp-golang) -![GitHub last commit](https://img.shields.io/github/last-commit/metoro-io/mcp-golang) -[![GoDoc](https://pkg.go.dev/badge/github.com/metoro-io/mcp-golang.svg)](https://pkg.go.dev/github.com/metoro-io/mcp-golang) -[![Go Report Card](https://goreportcard.com/badge/github.com/metoro-io/mcp-golang)](https://goreportcard.com/report/github.com/metoro-io/mcp-golang) -![Tests](https://github.com/metoro-io/mcp-golang/actions/workflows/go-test.yml/badge.svg) +![GitHub stars](https://img.shields.io/github/stars/rvoh-emccaleb/mcp-golang?style=social) +![GitHub forks](https://img.shields.io/github/forks/rvoh-emccaleb/mcp-golang?style=social) +![GitHub issues](https://img.shields.io/github/issues/rvoh-emccaleb/mcp-golang) +![GitHub pull requests](https://img.shields.io/github/issues-pr/rvoh-emccaleb/mcp-golang) +![GitHub license](https://img.shields.io/github/license/rvoh-emccaleb/mcp-golang) +![GitHub contributors](https://img.shields.io/github/contributors/rvoh-emccaleb/mcp-golang) +![GitHub last commit](https://img.shields.io/github/last-commit/rvoh-emccaleb/mcp-golang) +[![GoDoc](https://pkg.go.dev/badge/github.com/rvoh-emccaleb/mcp-golang.svg)](https://pkg.go.dev/github.com/rvoh-emccaleb/mcp-golang) +[![Go Report Card](https://goreportcard.com/badge/github.com/rvoh-emccaleb/mcp-golang)](https://goreportcard.com/report/github.com/rvoh-emccaleb/mcp-golang) +![Tests](https://github.com/rvoh-emccaleb/mcp-golang/actions/workflows/go-test.yml/badge.svg) @@ -37,7 +37,7 @@ Docs at [https://mcpgolang.com](https://mcpgolang.com) ## Example Usage -Install with `go get github.com/metoro-io/mcp-golang` +Install with `go get github.com/rvoh-emccaleb/mcp-golang` ### Server Example @@ -46,8 +46,8 @@ package main import ( "fmt" - "github.com/metoro-io/mcp-golang" - "github.com/metoro-io/mcp-golang/transport/stdio" + "github.com/rvoh-emccaleb/mcp-golang" + "github.com/rvoh-emccaleb/mcp-golang/transport/stdio" ) // Tool arguments are just structs, annotated with jsonschema tags @@ -121,8 +121,8 @@ package main import ( "context" "log" - mcp "github.com/metoro-io/mcp-golang" - "github.com/metoro-io/mcp-golang/transport/stdio" + mcp "github.com/rvoh-emccaleb/mcp-golang" + "github.com/rvoh-emccaleb/mcp-golang/transport/stdio" ) // Define type-safe arguments diff --git a/client.go b/client.go index 672e217..5ea27e0 100644 --- a/client.go +++ b/client.go @@ -4,9 +4,9 @@ import ( "context" "encoding/json" - "github.com/metoro-io/mcp-golang/internal/protocol" - "github.com/metoro-io/mcp-golang/transport" "github.com/pkg/errors" + "github.com/rvoh-emccaleb/mcp-golang/internal/protocol" + "github.com/rvoh-emccaleb/mcp-golang/transport" ) // Client represents an MCP client that can connect to and interact with MCP servers @@ -25,8 +25,21 @@ func NewClient(transport transport.Transport) *Client { } } +// A bit loosey goosey, but it works for now. +// See: https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/lifecycle/ +type InitializeRequestParams struct { + ProtocolVersion string `json:"protocolVersion"` + ClientInfo InitializeRequestClientInfo `json:"clientInfo"` + Capabilities map[string]interface{} `json:"capabilities"` +} + +type InitializeRequestClientInfo struct { + Name string `json:"name"` + Version string `json:"version"` +} + // Initialize connects to the server and retrieves its capabilities -func (c *Client) Initialize(ctx context.Context) (*InitializeResponse, error) { +func (c *Client) Initialize(ctx context.Context, params *InitializeRequestParams) (*InitializeResponse, error) { if c.initialized { return nil, errors.New("client already initialized") } @@ -36,8 +49,8 @@ func (c *Client) Initialize(ctx context.Context) (*InitializeResponse, error) { return nil, errors.Wrap(err, "failed to connect transport") } - // Make initialize request to server - response, err := c.protocol.Request(ctx, "initialize", map[string]interface{}{}, nil) + // Begin initialization handshake with server + response, err := c.protocol.Request(ctx, "initialize", params, nil) if err != nil { return nil, errors.Wrap(err, "failed to initialize") } @@ -53,8 +66,15 @@ func (c *Client) Initialize(ctx context.Context) (*InitializeResponse, error) { return nil, errors.Wrap(err, "failed to unmarshal initialize response") } + // Finish initialization handshake with server + err = c.protocol.Notification("notifications/initialized", map[string]interface{}{}) + if err != nil { + return nil, errors.Wrap(err, "failed to send initialized notification") + } + c.capabilities = &initResult.Capabilities c.initialized = true + return &initResult, nil } @@ -235,17 +255,13 @@ func (c *Client) ReadResource(ctx context.Context, uri string) (*ResourceRespons return nil, errors.New("invalid response type") } - var resourceResponse resourceResponseSent + var resourceResponse ResourceResponse err = json.Unmarshal(responseBytes, &resourceResponse) if err != nil { return nil, errors.Wrap(err, "failed to unmarshal resource response") } - if resourceResponse.Error != nil { - return nil, resourceResponse.Error - } - - return resourceResponse.Response, nil + return &resourceResponse, nil } // Ping sends a ping request to the server to check connectivity @@ -266,3 +282,13 @@ func (c *Client) Ping(ctx context.Context) error { func (c *Client) GetCapabilities() *ServerCapabilities { return c.capabilities } + +// Close cleans up resources used by the client, including the protocol and transport layers. +// It should be called when the client is no longer needed. +func (c *Client) Close() error { + if err := c.protocol.Close(); err != nil { + return errors.Wrap(err, "failed to close protocol") + } + + return nil +} diff --git a/content_api.go b/content_api.go index b111593..94e9989 100644 --- a/content_api.go +++ b/content_api.go @@ -85,14 +85,51 @@ type EmbeddedResource struct { // Custom JSON marshaling for EmbeddedResource func (c EmbeddedResource) MarshalJSON() ([]byte, error) { + type wrapper struct { + Type string `json:"embeddedResourceType"` + *TextResourceContents + *BlobResourceContents + } + + w := wrapper{Type: string(c.EmbeddedResourceType)} + switch c.EmbeddedResourceType { case embeddedResourceTypeBlob: - return json.Marshal(c.BlobResourceContents) + w.BlobResourceContents = c.BlobResourceContents case embeddedResourceTypeText: - return json.Marshal(c.TextResourceContents) + w.TextResourceContents = c.TextResourceContents default: return nil, fmt.Errorf("unknown embedded resource type: %s", c.EmbeddedResourceType) } + + return json.Marshal(w) +} + +func (c *EmbeddedResource) UnmarshalJSON(data []byte) error { + var wrapper struct { + Type string `json:"embeddedResourceType"` + Raw json.RawMessage `json:"-"` + } + + wrapper.Raw = data + if err := json.Unmarshal(data, &wrapper); err != nil { + return err + } + + c.EmbeddedResourceType = embeddedResourceType(wrapper.Type) + + switch c.EmbeddedResourceType { + case embeddedResourceTypeText: + c.TextResourceContents = new(TextResourceContents) + return json.Unmarshal(wrapper.Raw, c.TextResourceContents) + + case embeddedResourceTypeBlob: + c.BlobResourceContents = new(BlobResourceContents) + return json.Unmarshal(wrapper.Raw, c.BlobResourceContents) + + default: + return fmt.Errorf("unknown embedded resource type: %s", wrapper.Type) + } } type ContentType string diff --git a/docs/change-notifications.mdx b/docs/change-notifications.mdx index 82fd09d..5c9d8bd 100644 --- a/docs/change-notifications.mdx +++ b/docs/change-notifications.mdx @@ -24,8 +24,8 @@ package main import ( "fmt" - mcp_golang "github.com/metoro-io/mcp-golang" - "github.com/metoro-io/mcp-golang/transport/stdio" + mcp_golang "github.com/rvoh-emccaleb/mcp-golang" + "github.com/rvoh-emccaleb/mcp-golang/transport/stdio" "time" ) diff --git a/docs/client.mdx b/docs/client.mdx index 1a175c1..6e45bb9 100644 --- a/docs/client.mdx +++ b/docs/client.mdx @@ -12,7 +12,7 @@ The MCP client provides a simple and intuitive way to interact with MCP servers. Add the MCP Golang package to your project: ```bash -go get github.com/metoro-io/mcp-golang +go get github.com/rvoh-emccaleb/mcp-golang ``` ## Basic Usage @@ -21,8 +21,8 @@ Here's a simple example of creating and initializing an MCP client: ```go import ( - mcp "github.com/metoro-io/mcp-golang" - "github.com/metoro-io/mcp-golang/transport/stdio" + mcp "github.com/rvoh-emccaleb/mcp-golang" + "github.com/rvoh-emccaleb/mcp-golang/transport/stdio" ) // Create a transport (stdio in this example) @@ -197,7 +197,7 @@ if err != nil { ## Complete Example -For a complete working example, check out our [example client implementation](https://github.com/metoro-io/mcp-golang/tree/main/examples/client). +For a complete working example, check out our [example client implementation](https://github.com/rvoh-emccaleb/mcp-golang/tree/main/examples/client). ## Transport Options diff --git a/docs/contributing.mdx b/docs/contributing.mdx index 0bbc722..425b44c 100644 --- a/docs/contributing.mdx +++ b/docs/contributing.mdx @@ -19,7 +19,7 @@ To set up your development environment, follow these steps: 1. Clone the repository: ```bash -git clone https://github.com/metoro-io/mcp-golang.git +git clone https://github.com/rvoh-emccaleb/mcp-golang.git cd mcp-golang ``` @@ -210,6 +210,6 @@ When your PR merges into the main branch, it will be deployed automatically. ## Getting Help -- Check existing [GitHub issues](https://github.com/metoro-io/mcp-golang/issues) +- Check existing [GitHub issues](https://github.com/rvoh-emccaleb/mcp-golang/issues) - Join our [Discord community](https://discord.gg/33saRwE3pT) - Read the [Model Context Protocol specification](https://modelcontextprotocol.io/) diff --git a/docs/mint.json b/docs/mint.json index 67c435c..7947bfa 100644 --- a/docs/mint.json +++ b/docs/mint.json @@ -21,13 +21,13 @@ }, "topbarCtaButton": { "name": "Github Repo", - "url": "https://github.com/metoro-io/mcp-golang" + "url": "https://github.com/rvoh-emccaleb/mcp-golang" }, "anchors": [ { "name": "Github", "icon": "github", - "url": "https://github.com/metoro-io/mcp-golang" + "url": "https://github.com/rvoh-emccaleb/mcp-golang" }, { "name": "Discord Community", @@ -62,7 +62,7 @@ ], "footerSocials": { "x": "https://x.com/metoro_ai", - "github": "https://github.com/metoro-io/mcp-golang", + "github": "https://github.com/rvoh-emccaleb/mcp-golang", "website": "https://https://metoro.io/" } } diff --git a/docs/quickstart.mdx b/docs/quickstart.mdx index 03bdb40..a5148e3 100644 --- a/docs/quickstart.mdx +++ b/docs/quickstart.mdx @@ -8,7 +8,7 @@ description: 'Set up your first mcp-golang server' First, add mcp-golang to your project: ```bash -go get github.com/metoro-io/mcp-golang +go get github.com/rvoh-emccaleb/mcp-golang ``` ## Basic Usage @@ -20,8 +20,8 @@ package main import ( "fmt" - "github.com/metoro-io/mcp-golang" - "github.com/metoro-io/mcp-golang/transport/stdio" + "github.com/rvoh-emccaleb/mcp-golang" + "github.com/rvoh-emccaleb/mcp-golang/transport/stdio" ) type Content struct { @@ -93,8 +93,8 @@ package main import ( "context" "log" - "github.com/metoro-io/mcp-golang" - "github.com/metoro-io/mcp-golang/transport/http" + "github.com/rvoh-emccaleb/mcp-golang" + "github.com/rvoh-emccaleb/mcp-golang/transport/http" ) func main() { @@ -122,8 +122,8 @@ package main import ( "github.com/gin-gonic/gin" - "github.com/metoro-io/mcp-golang" - "github.com/metoro-io/mcp-golang/transport/http" + "github.com/rvoh-emccaleb/mcp-golang" + "github.com/rvoh-emccaleb/mcp-golang/transport/http" ) func main() { @@ -155,8 +155,8 @@ package main import ( "context" "log" - "github.com/metoro-io/mcp-golang" - "github.com/metoro-io/mcp-golang/transport/http" + "github.com/rvoh-emccaleb/mcp-golang" + "github.com/rvoh-emccaleb/mcp-golang/transport/http" ) func main() { @@ -183,4 +183,4 @@ func main() { - If you're interested in contributing to mcp-golang, check out [Development Guide](/development) for more detailed information - Join our [Discord Community](https://discord.gg/33saRwE3pT) for support -- Visit our [GitHub Repository](https://github.com/metoro-io/mcp-golang) to contribute +- Visit our [GitHub Repository](https://github.com/rvoh-emccaleb/mcp-golang) to contribute diff --git a/docs/tools.mdx b/docs/tools.mdx index 3854884..328f5f6 100644 --- a/docs/tools.mdx +++ b/docs/tools.mdx @@ -17,8 +17,8 @@ package main import ( "fmt" - "github.com/metoro-io/mcp-golang" - "github.com/metoro-io/mcp-golang/transport/stdio" + "github.com/rvoh-emccaleb/mcp-golang" + "github.com/rvoh-emccaleb/mcp-golang/transport/stdio" ) type HelloArguments struct { diff --git a/examples/basic_tool_server/basic_tool_server.go b/examples/basic_tool_server/basic_tool_server.go index b93e297..9a5a5e7 100644 --- a/examples/basic_tool_server/basic_tool_server.go +++ b/examples/basic_tool_server/basic_tool_server.go @@ -2,8 +2,8 @@ package main import ( "fmt" - "github.com/metoro-io/mcp-golang" - "github.com/metoro-io/mcp-golang/transport/stdio" + "github.com/rvoh-emccaleb/mcp-golang" + "github.com/rvoh-emccaleb/mcp-golang/transport/stdio" ) type Content struct { diff --git a/examples/client/main.go b/examples/client/main.go index e37efd9..375217c 100644 --- a/examples/client/main.go +++ b/examples/client/main.go @@ -6,8 +6,12 @@ import ( "context" - mcp_golang "github.com/metoro-io/mcp-golang" - "github.com/metoro-io/mcp-golang/transport/stdio" + mcp_golang "github.com/rvoh-emccaleb/mcp-golang" + "github.com/rvoh-emccaleb/mcp-golang/transport/stdio" +) + +const ( + ProtocolVersion = "2024-11-05" ) func main() { @@ -17,6 +21,7 @@ func main() { if err != nil { log.Fatalf("Failed to get stdin pipe: %v", err) } + stdout, err := cmd.StdoutPipe() if err != nil { log.Fatalf("Failed to get stdout pipe: %v", err) @@ -25,12 +30,23 @@ func main() { if err := cmd.Start(); err != nil { log.Fatalf("Failed to start server: %v", err) } + defer cmd.Process.Kill() clientTransport := stdio.NewStdioServerTransportWithIO(stdout, stdin) client := mcp_golang.NewClient(clientTransport) - if _, err := client.Initialize(context.Background()); err != nil { + if _, err := client.Initialize( + context.Background(), + &mcp_golang.InitializeRequestParams{ + ClientInfo: mcp_golang.InitializeRequestClientInfo{ + Name: "example-client", + Version: "0.1.0", + }, + ProtocolVersion: ProtocolVersion, + Capabilities: nil, + }, + ); err != nil { log.Fatalf("Failed to initialize client: %v", err) } diff --git a/examples/client/server/main.go b/examples/client/server/main.go index 87caa8b..fbf7b71 100644 --- a/examples/client/server/main.go +++ b/examples/client/server/main.go @@ -5,8 +5,8 @@ import ( "strings" "time" - mcp "github.com/metoro-io/mcp-golang" - "github.com/metoro-io/mcp-golang/transport/stdio" + mcp "github.com/rvoh-emccaleb/mcp-golang" + "github.com/rvoh-emccaleb/mcp-golang/transport/stdio" ) // HelloArgs represents the arguments for the hello tool diff --git a/examples/get_weather_tool_server/get_weather_tool_server.go b/examples/get_weather_tool_server/get_weather_tool_server.go index 096c542..bdbf952 100644 --- a/examples/get_weather_tool_server/get_weather_tool_server.go +++ b/examples/get_weather_tool_server/get_weather_tool_server.go @@ -2,8 +2,8 @@ package main import ( "fmt" - mcp_golang "github.com/metoro-io/mcp-golang" - "github.com/metoro-io/mcp-golang/transport/stdio" + mcp_golang "github.com/rvoh-emccaleb/mcp-golang" + "github.com/rvoh-emccaleb/mcp-golang/transport/stdio" "io" "net/http" ) diff --git a/examples/gin_example/main.go b/examples/gin_example/main.go index 2753592..c332d0e 100644 --- a/examples/gin_example/main.go +++ b/examples/gin_example/main.go @@ -7,8 +7,8 @@ import ( "time" "github.com/gin-gonic/gin" - mcp_golang "github.com/metoro-io/mcp-golang" - "github.com/metoro-io/mcp-golang/transport/http" + mcp_golang "github.com/rvoh-emccaleb/mcp-golang" + "github.com/rvoh-emccaleb/mcp-golang/transport/http" ) // TimeArgs defines the arguments for the time tool diff --git a/examples/http_example/auth_example_client/main.go b/examples/http_example/auth_example_client/main.go index 71e5830..1de55eb 100644 --- a/examples/http_example/auth_example_client/main.go +++ b/examples/http_example/auth_example_client/main.go @@ -3,15 +3,20 @@ package main import ( "context" "log" + "time" "github.com/davecgh/go-spew/spew" - mcp_golang "github.com/metoro-io/mcp-golang" - "github.com/metoro-io/mcp-golang/transport/http" + mcp_golang "github.com/rvoh-emccaleb/mcp-golang" + "github.com/rvoh-emccaleb/mcp-golang/transport/http" +) + +const ( + ProtocolVersion = "2024-11-05" ) func main() { // Create an HTTP transport that connects to the server - transport := http.NewHTTPClientTransport("/mcp") + transport := http.NewHTTPClientTransport("/mcp", 1*time.Millisecond) transport.WithBaseURL("http://localhost:8080/api/v1") // Public metoro token - not a leak transport.WithHeader("Authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJjdXN0b21lcklkIjoiOThlZDU1M2QtYzY4ZC00MDRhLWFhZjItNDM2ODllNWJiMGUzIiwiZW1haWwiOiJ0ZXN0QGNocmlzYmF0dGFyYmVlLmNvbSIsImV4cCI6MTgyMTI0NzIzN30.QeFzKsP1yO16pVol0mkAdt7qhJf6nTqBoqXqdWawBdE") @@ -20,7 +25,17 @@ func main() { client := mcp_golang.NewClient(transport) // Initialize the client - if resp, err := client.Initialize(context.Background()); err != nil { + if resp, err := client.Initialize( + context.Background(), + &mcp_golang.InitializeRequestParams{ + ClientInfo: mcp_golang.InitializeRequestClientInfo{ + Name: "example-client", + Version: "0.1.0", + }, + ProtocolVersion: ProtocolVersion, + Capabilities: nil, + }, + ); err != nil { log.Fatalf("Failed to initialize client: %v", err) } else { log.Printf("Initialized client: %v", spew.Sdump(resp)) diff --git a/examples/http_example/client/main.go b/examples/http_example/client/main.go index 622bfdd..fb307a1 100644 --- a/examples/http_example/client/main.go +++ b/examples/http_example/client/main.go @@ -6,21 +6,36 @@ import ( "time" "github.com/davecgh/go-spew/spew" - mcp_golang "github.com/metoro-io/mcp-golang" - "github.com/metoro-io/mcp-golang/transport/http" + mcp_golang "github.com/rvoh-emccaleb/mcp-golang" + "github.com/rvoh-emccaleb/mcp-golang/transport/http" +) + +const ( + ProtocolVersion = "2024-11-05" ) func main() { // Create an HTTP transport that connects to the server - transport := http.NewHTTPClientTransport("/mcp") + transport := http.NewHTTPClientTransport("/mcp", 1*time.Millisecond) transport.WithBaseURL("http://localhost:8081") // Create a new client with the transport client := mcp_golang.NewClient(transport) // Initialize the client - if resp, err := client.Initialize(context.Background()); err != nil { + if resp, err := client.Initialize( + context.Background(), + &mcp_golang.InitializeRequestParams{ + ClientInfo: mcp_golang.InitializeRequestClientInfo{ + Name: "example-client", + Version: "0.1.0", + }, + ProtocolVersion: ProtocolVersion, + Capabilities: nil, + }, + ); err != nil { log.Fatalf("Failed to initialize client: %v", err) + } else { log.Printf("Initialized client: %v", spew.Sdump(resp)) } diff --git a/examples/http_example/server/main.go b/examples/http_example/server/main.go index 563180c..d817a57 100644 --- a/examples/http_example/server/main.go +++ b/examples/http_example/server/main.go @@ -4,8 +4,8 @@ import ( "log" "time" - mcp_golang "github.com/metoro-io/mcp-golang" - "github.com/metoro-io/mcp-golang/transport/http" + mcp_golang "github.com/rvoh-emccaleb/mcp-golang" + "github.com/rvoh-emccaleb/mcp-golang/transport/http" ) // TimeArgs defines the arguments for the time tool diff --git a/examples/pagination_example/pagination_example.go b/examples/pagination_example/pagination_example.go index 4087929..cba6d19 100644 --- a/examples/pagination_example/pagination_example.go +++ b/examples/pagination_example/pagination_example.go @@ -2,8 +2,8 @@ package main import ( "fmt" - mcp_golang "github.com/metoro-io/mcp-golang" - "github.com/metoro-io/mcp-golang/transport/stdio" + mcp_golang "github.com/rvoh-emccaleb/mcp-golang" + "github.com/rvoh-emccaleb/mcp-golang/transport/stdio" ) // Arguments for our tools diff --git a/examples/readme_server/readme_server.go b/examples/readme_server/readme_server.go index 7e20743..741b34d 100644 --- a/examples/readme_server/readme_server.go +++ b/examples/readme_server/readme_server.go @@ -2,8 +2,8 @@ package main import ( "fmt" - "github.com/metoro-io/mcp-golang" - "github.com/metoro-io/mcp-golang/transport/stdio" + "github.com/rvoh-emccaleb/mcp-golang" + "github.com/rvoh-emccaleb/mcp-golang/transport/stdio" ) // Tool arguments are just structs, annotated with jsonschema tags diff --git a/examples/server/main.go b/examples/server/main.go index c12522b..20504b4 100644 --- a/examples/server/main.go +++ b/examples/server/main.go @@ -3,8 +3,8 @@ package main import ( "fmt" - mcp "github.com/metoro-io/mcp-golang" - "github.com/metoro-io/mcp-golang/transport/stdio" + mcp "github.com/rvoh-emccaleb/mcp-golang" + "github.com/rvoh-emccaleb/mcp-golang/transport/stdio" ) // HelloArgs represents the arguments for the hello tool diff --git a/examples/simple_tool_docs/simple_tool_docs.go b/examples/simple_tool_docs/simple_tool_docs.go index 881aec4..ebb2c2b 100644 --- a/examples/simple_tool_docs/simple_tool_docs.go +++ b/examples/simple_tool_docs/simple_tool_docs.go @@ -2,8 +2,8 @@ package main import ( "fmt" - mcp_golang "github.com/metoro-io/mcp-golang" - "github.com/metoro-io/mcp-golang/transport/stdio" + mcp_golang "github.com/rvoh-emccaleb/mcp-golang" + "github.com/rvoh-emccaleb/mcp-golang/transport/stdio" ) type HelloArguments struct { diff --git a/examples/sse_example/main.go b/examples/sse_example/main.go new file mode 100644 index 0000000..7401956 --- /dev/null +++ b/examples/sse_example/main.go @@ -0,0 +1,253 @@ +package main + +import ( + "bytes" + "context" + "fmt" + "io" + "log" + "net/http" + "strconv" + "strings" + "time" + + "github.com/davecgh/go-spew/spew" + mcp_golang "github.com/rvoh-emccaleb/mcp-golang" + "github.com/rvoh-emccaleb/mcp-golang/transport/sse" +) + +// TimeArgs defines the arguments for the time tool +type TimeArgs struct { + Format string `json:"format" jsonschema:"description=The time format to use"` +} + +const ( + baseEndpoint = "/mcp/sse" + protocolVersion = "2024-11-05" + serverPort = 8083 +) + +func main() { + // Add a root context that we can cancel + rootCtx, rootCancel := context.WithCancel(context.Background()) + defer rootCancel() + + sseTransport := sse.NewServerTransport(baseEndpoint) + + server := mcp_golang.NewServer( + sseTransport, + mcp_golang.WithName("mcp-golang-sse-example"), + mcp_golang.WithVersion("0.0.1"), + ) + + // Register a simple tool + err := server.RegisterTool("time", "Returns the current time in the specified format", func(ctx context.Context, args TimeArgs) (*mcp_golang.ToolResponse, error) { + format := args.Format + if format == "" { + format = time.RFC3339 + } + return mcp_golang.NewToolResponse(mcp_golang.NewTextContent(time.Now().Format(format))), nil + }) + if err != nil { + panic(err) + } + + // Handler for establishing a new SSE connection + http.HandleFunc(baseEndpoint, func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Create a context that's cancelled either by client disconnect or server shutdown + ctx, cancel := context.WithCancel(r.Context()) + go func() { + select { + case <-rootCtx.Done(): + cancel() + case <-r.Context().Done(): + cancel() + } + }() + defer cancel() + + connID, err := sseTransport.HandleSSEConnInitialize(w) + if err != nil { + log.Printf("error initializing sse connection: %v", err) + return + } + + log.Printf("New SSE connection established with ID: %d", connID) + start := time.Now() + + <-ctx.Done() // Wait for either client disconnect or server shutdown + + log.Printf("SSE connection closed after %v for connection with ID: %d", time.Since(start), connID) + sseTransport.RemoveSSEConnection(connID) + }) + + // Handler for receiving JSON-RPC messages + http.HandleFunc(baseEndpoint+"/", func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + log.Printf("Received POST request to %s", r.URL.Path) + + // Extract connection ID from path + parts := strings.Split(r.URL.Path, "/") + if len(parts) < 3 { + log.Printf("Error: missing connection ID in path") + http.Error(w, "missing connection ID in path", http.StatusBadRequest) + return + } + connIDStr := parts[len(parts)-1] + + connID, err := strconv.ParseInt(connIDStr, 10, 64) + if err != nil { + log.Printf("Error: invalid connection ID format: %s", connIDStr) + http.Error(w, fmt.Sprintf("invalid connection ID format: %s", connIDStr), http.StatusBadRequest) + return + } + + // Read and log the request body + body, err := io.ReadAll(r.Body) + if err != nil { + log.Printf("Error reading request body: %v", err) + http.Error(w, "failed to read request body", http.StatusBadRequest) + return + } + log.Printf("Request body: %s", string(body)) + r.Body = io.NopCloser(bytes.NewBuffer(body)) + + // Handle all MCP messages (e.g. initialize notification, list tools request, etc.) + if err := sseTransport.HandleMCPMessage(w, r, connID); err != nil { + log.Printf("error handling MCP message: %v", err) + return + } + }) + + // Sets the MCP protocol message handlers on our transport + // e.g. initialize, tools/list, tools/call, etc. + err = server.Serve() + if err != nil { + log.Fatalf("Server error: %v", err) + } + + // Start the HTTP server + log.Printf("Starting HTTP server with SSE on :%d...", serverPort) + httpServer := &http.Server{ + Addr: fmt.Sprintf(":%d", serverPort), + Handler: nil, + + // Timeouts + ReadTimeout: 10 * time.Second, // For initial connection and POST requests + WriteTimeout: 0, // No timeout for SSE writes + MaxHeaderBytes: 1 << 20, // 1MB + ReadHeaderTimeout: 10 * time.Second, // For initial connection and POST requests + IdleTimeout: 0, // No timeout for SSE connections + } + + done := make(chan struct{}) + + go func() { + err := httpServer.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + log.Printf("Server error: %v", err) + } + + close(done) + }() + + time.Sleep(1 * time.Second) + + // Create an HTTP transport that connects to the server + transport := sse.NewSSEClientTransport(baseEndpoint, 1*time.Millisecond) + transport.WithBaseURL(fmt.Sprintf("http://localhost:%d", serverPort)) + + // Create a new client with the transport + client := mcp_golang.NewClient(transport) + + // Initialize the client + if resp, err := client.Initialize( + context.Background(), + &mcp_golang.InitializeRequestParams{ + ClientInfo: mcp_golang.InitializeRequestClientInfo{ + Name: "example-client", + Version: "0.1.0", + }, + ProtocolVersion: protocolVersion, + Capabilities: nil, + }, + ); err != nil { + log.Fatalf("Failed to initialize client: %v", err) + + } else { + log.Printf("Initialized client: %v", spew.Sdump(resp)) + } + + // List available tools + tools, err := client.ListTools(context.Background(), nil) + if err != nil { + log.Fatalf("Failed to list tools: %v", err) + } + + log.Println("Available Tools:") + for _, tool := range tools.Tools { + desc := "" + if tool.Description != nil { + desc = *tool.Description + } + log.Printf("Tool: %s. Description: %s", tool.Name, desc) + } + + // Call the time tool with different formats + formats := []string{ + time.RFC3339, + "2006-01-02 15:04:05", + "Mon, 02 Jan 2006", + } + + for _, format := range formats { + args := map[string]interface{}{ + "format": format, + } + + response, err := client.CallTool(context.Background(), "time", args) + if err != nil { + log.Printf("Failed to call time tool: %v", err) + continue + } + + if len(response.Content) > 0 && response.Content[0].TextContent != nil { + log.Printf("Time in format %q: %s", format, response.Content[0].TextContent.Text) + } + } + + // When shutting down: + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // 1. First cancel all SSE connections and prevent new ones + log.Println("Cancelling all SSE connections...") + rootCancel() + + // 2. Close the transport (which should prevent new SSE connections) + log.Println("Closing SSE transport...") + if err := sseTransport.Close(); err != nil { + log.Printf("Error closing SSE transport: %v", err) + } + + // 3. Finally shutdown the HTTP server + log.Println("Shutting down HTTP server...") + if err := httpServer.Shutdown(ctx); err != nil { + log.Printf("Error during server shutdown: %v", err) + } + + select { + case <-done: + log.Println("Server shutdown completed") + case <-time.After(10 * time.Second): + log.Println("Server shutdown timed out") + } +} diff --git a/examples/updating_registrations_on_the_fly/updating_registrations_on_the_fly.go b/examples/updating_registrations_on_the_fly/updating_registrations_on_the_fly.go index a52edbe..c9af73f 100644 --- a/examples/updating_registrations_on_the_fly/updating_registrations_on_the_fly.go +++ b/examples/updating_registrations_on_the_fly/updating_registrations_on_the_fly.go @@ -2,8 +2,8 @@ package main import ( "fmt" - mcp_golang "github.com/metoro-io/mcp-golang" - "github.com/metoro-io/mcp-golang/transport/stdio" + mcp_golang "github.com/rvoh-emccaleb/mcp-golang" + "github.com/rvoh-emccaleb/mcp-golang/transport/stdio" "time" ) diff --git a/go.mod b/go.mod index 9e12af5..87d4405 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,10 @@ -module github.com/metoro-io/mcp-golang +module github.com/rvoh-emccaleb/mcp-golang -go 1.21 +go 1.24.0 require ( + github.com/davecgh/go-spew v1.1.1 + github.com/gin-gonic/gin v1.8.1 github.com/invopop/jsonschema v0.12.0 github.com/pkg/errors v0.9.1 github.com/stretchr/testify v1.9.0 @@ -12,9 +14,7 @@ require ( require ( github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/buger/jsonparser v1.1.1 // indirect - github.com/davecgh/go-spew v1.1.1 // indirect github.com/gin-contrib/sse v0.1.0 // indirect - github.com/gin-gonic/gin v1.8.1 // indirect github.com/go-playground/locales v0.14.0 // indirect github.com/go-playground/universal-translator v0.18.0 // indirect github.com/go-playground/validator/v10 v10.10.0 // indirect @@ -32,10 +32,10 @@ require ( github.com/tidwall/pretty v1.2.1 // indirect github.com/ugorji/go/codec v1.2.7 // indirect github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect - golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 // indirect - golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 // indirect - golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069 // indirect - golang.org/x/text v0.3.6 // indirect + golang.org/x/crypto v0.36.0 // indirect + golang.org/x/net v0.37.0 // indirect + golang.org/x/sys v0.31.0 // indirect + golang.org/x/text v0.23.0 // indirect google.golang.org/protobuf v1.28.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 309c382..9ed58ad 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,7 @@ github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= github.com/gin-gonic/gin v1.8.1 h1:4+fr/el88TOO3ewCmQr8cx/CtZ/umlIRIs5M4NTNjf8= github.com/gin-gonic/gin v1.8.1/go.mod h1:ji8BvRH1azfM+SYow9zQ6SZMvR8qOMZHmsCuWR9tTTk= +github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A= github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.14.0 h1:u50s323jtVGugKlcYeyzC0etD1HifMjqmJqb8WugfUU= github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs= @@ -20,10 +21,9 @@ github.com/go-playground/validator/v10 v10.10.0/go.mod h1:74x4gJWsvQexRdW8Pn3dXS github.com/goccy/go-json v0.9.7 h1:IcB+Aqpx/iMHu5Yooh7jEzJk1JZ7Pjtmys2ukPr7EeM= github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= -github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/invopop/jsonschema v0.12.0 h1:6ovsNSuvn9wEQVOyc72aycBMVQFKz7cPdMJn10CvzRI= github.com/invopop/jsonschema v0.12.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= @@ -31,9 +31,11 @@ github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnr github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w= github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY= @@ -53,6 +55,7 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= +github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8= github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= @@ -76,32 +79,32 @@ github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0 github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY= github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= -go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= -go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= -go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= -go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= -golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 h1:/UOmuWzQfxxo9UtlXMwuQU8CMgg1eZXqTRwkSQJWKOI= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 h1:qWPm9rbaAMKs8Bq/9LRpbMqxWRVUAQwMI9fVrssnTfw= +golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= +golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c= +golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069 h1:siQdpVirKtzPhKl3lZWozZraCFObP8S1v6PRp0bLrtU= golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= +golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= +golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.28.0/go.mod h1:dcIOrVd3mfQKTgrDVQHqCPMWy6lnhfhtX3hLXYVLfRw= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= diff --git a/integration_test.go b/integration_test.go index 30c7040..2ed31d9 100644 --- a/integration_test.go +++ b/integration_test.go @@ -15,7 +15,7 @@ import ( "testing" "time" - "github.com/metoro-io/mcp-golang/transport" + "github.com/rvoh-emccaleb/mcp-golang/transport" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -23,8 +23,8 @@ import ( const testServerCode = `package main import ( - mcp "github.com/metoro-io/mcp-golang" - "github.com/metoro-io/mcp-golang/transport/stdio" + mcp "github.com/rvoh-emccaleb/mcp-golang" + "github.com/rvoh-emccaleb/mcp-golang/transport/stdio" ) type EchoArgs struct { @@ -78,7 +78,7 @@ func TestServerIntegration(t *testing.T) { require.NoError(t, err, "Failed to initialize module: %s", string(output)) // Replace the dependency with the local version - cmd = exec.Command("go", "mod", "edit", "-replace", "github.com/metoro-io/mcp-golang="+currentDir) + cmd = exec.Command("go", "mod", "edit", "-replace", "github.com/rvoh-emccaleb/mcp-golang="+currentDir) cmd.Dir = tmpDir output, err = cmd.CombinedOutput() require.NoError(t, err, "Failed to replace dependency: %s", string(output)) diff --git a/internal/protocol/protocol.go b/internal/protocol/protocol.go index 3172ac4..18b2ff1 100644 --- a/internal/protocol/protocol.go +++ b/internal/protocol/protocol.go @@ -72,7 +72,7 @@ import ( "sync" "time" - "github.com/metoro-io/mcp-golang/transport" + "github.com/rvoh-emccaleb/mcp-golang/transport" ) const DefaultRequestTimeoutMsec = 60000 diff --git a/internal/protocol/protocol_test.go b/internal/protocol/protocol_test.go index 887f769..86332d4 100644 --- a/internal/protocol/protocol_test.go +++ b/internal/protocol/protocol_test.go @@ -7,8 +7,8 @@ import ( "testing" "time" - "github.com/metoro-io/mcp-golang/internal/testingutils" - "github.com/metoro-io/mcp-golang/transport" + "github.com/rvoh-emccaleb/mcp-golang/internal/testingutils" + "github.com/rvoh-emccaleb/mcp-golang/transport" ) // TestProtocol_Connect verifies the basic connection functionality of the Protocol. diff --git a/internal/testingutils/mock_transport.go b/internal/testingutils/mock_transport.go index 19196d9..d85eb30 100644 --- a/internal/testingutils/mock_transport.go +++ b/internal/testingutils/mock_transport.go @@ -4,7 +4,7 @@ import ( "context" "sync" - "github.com/metoro-io/mcp-golang/transport" + "github.com/rvoh-emccaleb/mcp-golang/transport" ) // MockTransport implements Transport interface for testing diff --git a/resource_template_response_types.go b/resource_template_response_types.go new file mode 100644 index 0000000..72f29f7 --- /dev/null +++ b/resource_template_response_types.go @@ -0,0 +1,51 @@ +package mcp_golang + +// The server's response to a resources/templates/list request from the client. +type ListResourceTemplatesResponse struct { + // Templates corresponds to the JSON schema field "templates". + Templates []*ResourceTemplateSchema `json:"templates" yaml:"templates" mapstructure:"templates"` + // NextCursor is a cursor for pagination. If not nil, there are more templates available. + NextCursor *string `json:"nextCursor,omitempty" yaml:"nextCursor,omitempty" mapstructure:"nextCursor,omitempty"` +} + +// A resource template that the server is capable of instantiating. +type ResourceTemplateSchema struct { + // Annotations corresponds to the JSON schema field "annotations". + Annotations *Annotations `json:"annotations,omitempty" yaml:"annotations,omitempty" mapstructure:"annotations,omitempty"` + + // A description of what this resource template represents. + // + // This can be used by clients to improve the LLM's understanding of available + // resource templates. It can be thought of like a "hint" to the model. + Description *string `json:"description,omitempty" yaml:"description,omitempty" mapstructure:"description,omitempty"` + + // The MIME type of resources created from this template. + MimeType *string `json:"mimeType,omitempty" yaml:"mimeType,omitempty" mapstructure:"mimeType,omitempty"` + + // A human-readable name for this resource template. + // + // This can be used by clients to populate UI elements. + Name string `json:"name" yaml:"name" mapstructure:"name"` + + // The URI of this resource template. + Uri string `json:"uri" yaml:"uri" mapstructure:"uri"` + + // Parameters that can be used to instantiate this template. + Parameters []ResourceTemplateParameter `json:"parameters" yaml:"parameters" mapstructure:"parameters"` +} + +// A parameter that can be used to instantiate a resource template. +type ResourceTemplateParameter struct { + // A description of what this parameter represents. + // + // This can be used by clients to improve the LLM's understanding of the parameter. + Description *string `json:"description,omitempty" yaml:"description,omitempty" mapstructure:"description,omitempty"` + + // A human-readable name for this parameter. + // + // This can be used by clients to populate UI elements. + Name string `json:"name" yaml:"name" mapstructure:"name"` + + // Whether this parameter is required when instantiating the template. + Required *bool `json:"required,omitempty" yaml:"required,omitempty" mapstructure:"required,omitempty"` +} diff --git a/server.go b/server.go index af6f1ec..a3dd71a 100644 --- a/server.go +++ b/server.go @@ -10,10 +10,10 @@ import ( "strings" "github.com/invopop/jsonschema" - "github.com/metoro-io/mcp-golang/internal/datastructures" - "github.com/metoro-io/mcp-golang/internal/protocol" - "github.com/metoro-io/mcp-golang/transport" "github.com/pkg/errors" + "github.com/rvoh-emccaleb/mcp-golang/internal/datastructures" + "github.com/rvoh-emccaleb/mcp-golang/internal/protocol" + "github.com/rvoh-emccaleb/mcp-golang/transport" ) // Here we define the actual MCP server that users will create and run @@ -107,6 +107,7 @@ type Server struct { tools *datastructures.SyncMap[string, *tool] prompts *datastructures.SyncMap[string, *prompt] resources *datastructures.SyncMap[string, *resource] + resourceTemplates *datastructures.SyncMap[string, *resourceTemplate] serverInstructions *string serverName string serverVersion string @@ -134,6 +135,14 @@ type resource struct { Handler func(context.Context) *resourceResponseSent } +type resourceTemplate struct { + Name string + Description string + Uri string + MimeType string + Parameters []ResourceTemplateParameter +} + type ServerOptions func(*Server) func WithProtocol(protocol *protocol.Protocol) ServerOptions { @@ -163,11 +172,12 @@ func WithVersion(version string) ServerOptions { func NewServer(transport transport.Transport, options ...ServerOptions) *Server { server := &Server{ - protocol: protocol.NewProtocol(nil), - transport: transport, - tools: new(datastructures.SyncMap[string, *tool]), - prompts: new(datastructures.SyncMap[string, *prompt]), - resources: new(datastructures.SyncMap[string, *resource]), + protocol: protocol.NewProtocol(nil), + transport: transport, + tools: new(datastructures.SyncMap[string, *tool]), + prompts: new(datastructures.SyncMap[string, *prompt]), + resources: new(datastructures.SyncMap[string, *resource]), + resourceTemplates: new(datastructures.SyncMap[string, *resourceTemplate]), } for _, option := range options { option(server) @@ -200,6 +210,24 @@ func (s *Server) sendToolListChangedNotification() error { return s.protocol.Notification("notifications/tools/list_changed", nil) } +// RegisterToolWithSchema registers a tool with a predefined JSON schema. +// This is an alternative to RegisterTool, which uses reflection on the handler to create the schema. +func (s *Server) RegisterToolWithSchema(name string, description string, handler any, schema *jsonschema.Schema) error { + err := validateToolHandler(handler) + if err != nil { + return err + } + + s.tools.Store(name, &tool{ + Name: name, + Description: description, + Handler: createWrappedToolHandler(handler), + ToolInputSchema: schema, // Use the provided schema directly + }) + + return s.sendToolListChangedNotification() +} + func (s *Server) CheckToolRegistered(name string) bool { _, ok := s.tools.Load(name) return ok @@ -244,8 +272,9 @@ func (s *Server) DeregisterResource(uri string) error { func createWrappedResourceHandler(userHandler any) func(ctx context.Context) *resourceResponseSent { handlerValue := reflect.ValueOf(userHandler) + handlerType := handlerValue.Type() + return func(ctx context.Context) *resourceResponseSent { - handlerType := handlerValue.Type() var args []reflect.Value if handlerType.NumIn() == 1 { args = []reflect.Value{reflect.ValueOf(ctx)} @@ -554,6 +583,15 @@ func (s *Server) Serve() error { pr.SetRequestHandler("prompts/get", s.handlePromptCalls) pr.SetRequestHandler("resources/list", s.handleListResources) pr.SetRequestHandler("resources/read", s.handleResourceCalls) + pr.SetRequestHandler("resources/templates/list", s.handleListResourceTemplates) + + // Note: All notifications in MCP, other than this one, are sent from the server to the client, and can use SSE. + // This is the only notification that is sent from the client to the server. In order to make things work + // smoothly from the client's perspective, we choose to handle this notification as a request instead of + // as a notification, because working with the HTTP protocol is inherently request/response oriented, and the client + // would otherwise have to guess as to how long to wait before assuming the server is initialized. + pr.SetRequestHandler("notifications/initialized", s.handleInitializedNotification) + err := pr.Connect(s.transport) if err != nil { return err @@ -576,6 +614,16 @@ func (s *Server) handleInitialize(ctx context.Context, request *transport.BaseJS }, nil } +// handleInitializedNotification is a request handler for the "notifications/initialized" notification. +// This is the only notification in MCP that is sent from the client to the server. In order to make things work +// smoothly from the client's perspective, we choose to handle this notification as a request instead of +// as a notification, because working with the HTTP protocol is inherently request/response oriented, and the client +// would otherwise have to guess as to how long to wait before assuming the server is initialized. We just return +// an empty response body to the client to indicate that the notification has been received. +func (s *Server) handleInitializedNotification(ctx context.Context, request *transport.BaseJSONRPCRequest, _ protocol.RequestHandlerExtra) (transport.JsonRpcBody, error) { + return map[string]interface{}{}, nil +} + func (s *Server) handleListTools(ctx context.Context, request *transport.BaseJSONRPCRequest, _ protocol.RequestHandlerExtra) (transport.JsonRpcBody, error) { type toolRequestParams struct { Cursor *string `json:"cursor"` @@ -698,9 +746,13 @@ func (s *Server) handleListPrompts(ctx context.Context, request *transport.BaseJ Cursor *string `json:"cursor"` } var params promptRequestParams - err := json.Unmarshal(request.Params, ¶ms) - if err != nil { - return nil, errors.Wrap(err, "failed to unmarshal arguments") + if request.Params == nil { + params = promptRequestParams{} + } else { + err := json.Unmarshal(request.Params, ¶ms) + if err != nil { + return nil, errors.Wrap(err, "failed to unmarshal arguments") + } } // Order by name for pagination @@ -762,9 +814,13 @@ func (s *Server) handleListResources(ctx context.Context, request *transport.Bas Cursor *string `json:"cursor"` } var params resourceRequestParams - err := json.Unmarshal(request.Params, ¶ms) - if err != nil { - return nil, errors.Wrap(err, "failed to unmarshal arguments") + if request.Params == nil { + params = resourceRequestParams{} + } else { + err := json.Unmarshal(request.Params, ¶ms) + if err != nil { + return nil, errors.Wrap(err, "failed to unmarshal arguments") + } } // Order by URI for pagination @@ -871,6 +927,78 @@ func (s *Server) handleResourceCalls(ctx context.Context, req *transport.BaseJSO return resourceToUse.Handler(ctx), nil } +func (s *Server) handleListResourceTemplates(ctx context.Context, request *transport.BaseJSONRPCRequest, extra protocol.RequestHandlerExtra) (transport.JsonRpcBody, error) { + type templateRequestParams struct { + Cursor *string `json:"cursor"` + } + var params templateRequestParams + if request.Params == nil { + params = templateRequestParams{} + } else { + err := json.Unmarshal(request.Params, ¶ms) + if err != nil { + return nil, errors.Wrap(err, "failed to unmarshal arguments") + } + } + + // Order by URI for pagination + var orderedTemplates []*resourceTemplate + s.resourceTemplates.Range(func(k string, t *resourceTemplate) bool { + orderedTemplates = append(orderedTemplates, t) + return true + }) + sort.Slice(orderedTemplates, func(i, j int) bool { + return orderedTemplates[i].Uri < orderedTemplates[j].Uri + }) + + startPosition := 0 + if params.Cursor != nil { + // Base64 decode the cursor + c, err := base64.StdEncoding.DecodeString(*params.Cursor) + if err != nil { + return nil, errors.Wrap(err, "failed to decode cursor") + } + cString := string(c) + // Iterate through the templates until we find an entry > the cursor + for i := 0; i < len(orderedTemplates); i++ { + if orderedTemplates[i].Uri > cString { + startPosition = i + break + } + } + } + endPosition := len(orderedTemplates) + if s.paginationLimit != nil { + // Make sure we don't go out of bounds + if len(orderedTemplates) > startPosition+*s.paginationLimit { + endPosition = startPosition + *s.paginationLimit + } + } + + templatesToReturn := make([]*ResourceTemplateSchema, 0) + for i := startPosition; i < endPosition; i++ { + t := orderedTemplates[i] + templatesToReturn = append(templatesToReturn, &ResourceTemplateSchema{ + Description: &t.Description, + MimeType: &t.MimeType, + Name: t.Name, + Uri: t.Uri, + Parameters: t.Parameters, + }) + } + + return ListResourceTemplatesResponse{ + Templates: templatesToReturn, + NextCursor: func() *string { + if s.paginationLimit != nil && len(templatesToReturn) >= *s.paginationLimit { + toString := base64.StdEncoding.EncodeToString([]byte(templatesToReturn[len(templatesToReturn)-1].Uri)) + return &toString + } + return nil + }(), + }, nil +} + func (s *Server) handlePing(ctx context.Context, request *transport.BaseJSONRPCRequest, extra protocol.RequestHandlerExtra) (transport.JsonRpcBody, error) { return map[string]interface{}{}, nil } @@ -927,3 +1055,28 @@ var ( CommentMap: nil, } ) + +func (s *Server) RegisterResourceTemplate(uri string, name string, description string, mimeType string, parameters []ResourceTemplateParameter) error { + s.resourceTemplates.Store(uri, &resourceTemplate{ + Name: name, + Description: description, + Uri: uri, + MimeType: mimeType, + Parameters: parameters, + }) + return s.sendResourceTemplateListChangedNotification() +} + +func (s *Server) DeregisterResourceTemplate(uri string) error { + s.resourceTemplates.Delete(uri) + return s.sendResourceTemplateListChangedNotification() +} + +func (s *Server) sendResourceTemplateListChangedNotification() error { + if !s.isRunning { + return nil + } + + // Note: Unsure if this one is in spec. + return s.protocol.Notification("notifications/resources/templates/list_changed", nil) +} diff --git a/server_test.go b/server_test.go index 2da0c82..8634d4d 100644 --- a/server_test.go +++ b/server_test.go @@ -4,9 +4,9 @@ import ( "context" "testing" - "github.com/metoro-io/mcp-golang/internal/protocol" - "github.com/metoro-io/mcp-golang/internal/testingutils" - "github.com/metoro-io/mcp-golang/transport" + "github.com/rvoh-emccaleb/mcp-golang/internal/protocol" + "github.com/rvoh-emccaleb/mcp-golang/internal/testingutils" + "github.com/rvoh-emccaleb/mcp-golang/transport" ) func TestServerListChangedNotifications(t *testing.T) { diff --git a/transport/base/base.go b/transport/base/base.go new file mode 100644 index 0000000..7baf145 --- /dev/null +++ b/transport/base/base.go @@ -0,0 +1,144 @@ +package base + +import ( + "context" + "encoding/json" + "fmt" + "sync" + + "github.com/rvoh-emccaleb/mcp-golang/transport" +) + +// Transport implements the common functionality for transports +type Transport struct { + messageHandler func(ctx context.Context, message *transport.BaseJsonRpcMessage) + errorHandler func(error) + closeHandler func() + mu sync.RWMutex + ResponseMap map[int64]chan *transport.BaseJsonRpcMessage +} + +// NewTransport creates a new base transport +func NewTransport() *Transport { + return &Transport{ + ResponseMap: make(map[int64]chan *transport.BaseJsonRpcMessage), + } +} + +// Send implements Transport.Send +func (t *Transport) Send(ctx context.Context, message *transport.BaseJsonRpcMessage) error { + key := message.JsonRpcResponse.Id + responseChannel := t.ResponseMap[int64(key)] + if responseChannel == nil { + return fmt.Errorf("no response channel found for key: %d", key) + } + responseChannel <- message + return nil +} + +// Close implements Transport.Close +func (t *Transport) Close() error { + if t.closeHandler != nil { + t.closeHandler() + } + return nil +} + +// SetCloseHandler implements Transport.SetCloseHandler +func (t *Transport) SetCloseHandler(handler func()) { + t.mu.Lock() + defer t.mu.Unlock() + t.closeHandler = handler +} + +// SetErrorHandler implements Transport.SetErrorHandler +func (t *Transport) SetErrorHandler(handler func(error)) { + t.mu.Lock() + defer t.mu.Unlock() + t.errorHandler = handler +} + +// SetMessageHandler implements Transport.SetMessageHandler +func (t *Transport) SetMessageHandler(handler func(ctx context.Context, message *transport.BaseJsonRpcMessage)) { + t.mu.Lock() + defer t.mu.Unlock() + t.messageHandler = handler +} + +// HandleMessage processes an incoming message and returns a response +func (t *Transport) HandleMessage(ctx context.Context, body []byte) (*transport.BaseJsonRpcMessage, error) { + // Try to unmarshal as a request first + var request transport.BaseJSONRPCRequest + if err := json.Unmarshal(body, &request); err == nil { + // Create a response channel for this request + t.mu.Lock() + var key int64 = 0 + for key < 1000000 { + if _, ok := t.ResponseMap[key]; !ok { + break + } + key = key + 1 + } + t.ResponseMap[key] = make(chan *transport.BaseJsonRpcMessage) + t.mu.Unlock() + + originalID := request.Id + request.Id = transport.RequestId(key) + t.mu.RLock() + handler := t.messageHandler + t.mu.RUnlock() + + if handler != nil { + handler(ctx, transport.NewBaseMessageRequest(&request)) + } + + // Block until the response is received + responseToUse := <-t.ResponseMap[key] + delete(t.ResponseMap, key) + + // Restore the original client ID in the response + responseToUse.JsonRpcResponse.Id = originalID + return responseToUse, nil + } + + // Try as a notification + var notification transport.BaseJSONRPCNotification + if err := json.Unmarshal(body, ¬ification); err == nil { + t.mu.RLock() + handler := t.messageHandler + t.mu.RUnlock() + + if handler != nil { + handler(ctx, transport.NewBaseMessageNotification(¬ification)) + } + return nil, nil + } + + // Try as a response + var response transport.BaseJSONRPCResponse + if err := json.Unmarshal(body, &response); err == nil { + t.mu.RLock() + handler := t.messageHandler + t.mu.RUnlock() + + if handler != nil { + handler(ctx, transport.NewBaseMessageResponse(&response)) + } + return nil, nil + } + + // Try as an error + var errorResponse transport.BaseJSONRPCError + if err := json.Unmarshal(body, &errorResponse); err == nil { + t.mu.RLock() + handler := t.messageHandler + t.mu.RUnlock() + + if handler != nil { + handler(ctx, transport.NewBaseMessageError(&errorResponse)) + } + return nil, nil + } + + return nil, fmt.Errorf("failed to unmarshal JSON-RPC message, unrecognized type") +} diff --git a/transport/http/common.go b/transport/http/common.go index 30e6d72..49d3af2 100644 --- a/transport/http/common.go +++ b/transport/http/common.go @@ -7,7 +7,7 @@ import ( "io" "sync" - "github.com/metoro-io/mcp-golang/transport" + "github.com/rvoh-emccaleb/mcp-golang/transport" ) // baseTransport implements the common functionality for HTTP-based transports diff --git a/transport/http/gin.go b/transport/http/gin.go index bbc6c7c..43c29ba 100644 --- a/transport/http/gin.go +++ b/transport/http/gin.go @@ -7,7 +7,7 @@ import ( "net/http" "github.com/gin-gonic/gin" - "github.com/metoro-io/mcp-golang/transport" + "github.com/rvoh-emccaleb/mcp-golang/transport" ) // GinTransport implements a stateless HTTP transport for MCP using Gin diff --git a/transport/http/http.go b/transport/http/http.go index 570ae1b..51e6c19 100644 --- a/transport/http/http.go +++ b/transport/http/http.go @@ -5,22 +5,14 @@ import ( "encoding/json" "fmt" "net/http" - "sync" - - "github.com/metoro-io/mcp-golang/transport" ) // HTTPTransport implements a stateless HTTP transport for MCP type HTTPTransport struct { *baseTransport - server *http.Server - endpoint string - messageHandler func(ctx context.Context, message *transport.BaseJsonRpcMessage) - errorHandler func(error) - closeHandler func() - mu sync.RWMutex - addr string - responseMap map[int64]chan *transport.BaseJsonRpcMessage + server *http.Server + endpoint string + addr string } // NewHTTPTransport creates a new HTTP transport that listens on the specified endpoint @@ -29,7 +21,6 @@ func NewHTTPTransport(endpoint string) *HTTPTransport { baseTransport: newBaseTransport(), endpoint: endpoint, addr: ":8080", // Default port - responseMap: make(map[int64]chan *transport.BaseJsonRpcMessage), } } @@ -52,17 +43,6 @@ func (t *HTTPTransport) Start(ctx context.Context) error { return t.server.ListenAndServe() } -// Send implements Transport.Send -func (t *HTTPTransport) Send(ctx context.Context, message *transport.BaseJsonRpcMessage) error { - key := message.JsonRpcResponse.Id - responseChannel := t.responseMap[int64(key)] - if responseChannel == nil { - return fmt.Errorf("no response channel found for key: %d", key) - } - responseChannel <- message - return nil -} - // Close implements Transport.Close func (t *HTTPTransport) Close() error { if t.server != nil { @@ -70,31 +50,8 @@ func (t *HTTPTransport) Close() error { return err } } - if t.closeHandler != nil { - t.closeHandler() - } - return nil -} - -// SetCloseHandler implements Transport.SetCloseHandler -func (t *HTTPTransport) SetCloseHandler(handler func()) { - t.mu.Lock() - defer t.mu.Unlock() - t.closeHandler = handler -} - -// SetErrorHandler implements Transport.SetErrorHandler -func (t *HTTPTransport) SetErrorHandler(handler func(error)) { - t.mu.Lock() - defer t.mu.Unlock() - t.errorHandler = handler -} -// SetMessageHandler implements Transport.SetMessageHandler -func (t *HTTPTransport) SetMessageHandler(handler func(ctx context.Context, message *transport.BaseJsonRpcMessage)) { - t.mu.Lock() - defer t.mu.Unlock() - t.messageHandler = handler + return t.baseTransport.Close() } func (t *HTTPTransport) handleRequest(w http.ResponseWriter, r *http.Request) { diff --git a/transport/http/http_client.go b/transport/http/http_client.go index f01416a..f17474a 100644 --- a/transport/http/http_client.go +++ b/transport/http/http_client.go @@ -4,32 +4,45 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "net/http" "sync" + "time" - "github.com/metoro-io/mcp-golang/transport" + "github.com/rvoh-emccaleb/mcp-golang/transport" ) // HTTPClientTransport implements a client-side HTTP transport for MCP type HTTPClientTransport struct { - baseURL string - endpoint string - messageHandler func(ctx context.Context, message *transport.BaseJsonRpcMessage) - errorHandler func(error) - closeHandler func() - mu sync.RWMutex - client *http.Client - headers map[string]string + baseURL string + endpoint string + messageHandler func(ctx context.Context, message *transport.BaseJsonRpcMessage) + errorHandler func(error) + closeHandler func() + mu sync.RWMutex + client *http.Client + notificationClient *http.Client + headers map[string]string } // NewHTTPClientTransport creates a new HTTP client transport that connects to the specified endpoint -func NewHTTPClientTransport(endpoint string) *HTTPClientTransport { +func NewHTTPClientTransport(endpoint string, notificationTimeout time.Duration) *HTTPClientTransport { + if notificationTimeout <= 0 { + notificationTimeout = 1 * time.Millisecond // This is flaky, but it works for local demos. + } + return &HTTPClientTransport{ endpoint: endpoint, client: &http.Client{}, - headers: make(map[string]string), + notificationClient: &http.Client{ + Transport: &http.Transport{ + DisableKeepAlives: true, + }, + Timeout: notificationTimeout, + }, + headers: make(map[string]string), } } @@ -68,6 +81,27 @@ func (t *HTTPClientTransport) Send(ctx context.Context, message *transport.BaseJ req.Header.Set(key, value) } + // Note: The client usually doesn't send notifications. Really it's only used + // for the "notifications/initialized" method. The server may or may not be implemented + // to return a response, so we have to rely on having a long enough timeout to ensure the + // server has time to process the notification. This is inherently flaky, and should be + // improved upon. + if message.Type == transport.BaseMessageTypeJSONRPCNotificationType { + resp, err := t.notificationClient.Do(req) + if err != nil && !errors.Is(err, context.DeadlineExceeded) { + if t.errorHandler != nil { + t.errorHandler(fmt.Errorf("notification error: %w", err)) + } + } + if resp != nil { + defer resp.Body.Close() + } + + return nil + } + + // For non-notifications, continue with normal synchronous request + resp, err := t.client.Do(req) if err != nil { return fmt.Errorf("failed to send request: %w", err) diff --git a/transport/sse/client.go b/transport/sse/client.go new file mode 100644 index 0000000..180bab7 --- /dev/null +++ b/transport/sse/client.go @@ -0,0 +1,345 @@ +package sse + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "github.com/rvoh-emccaleb/mcp-golang/transport" +) + +// SSEClientTransport implements a client-side SSE transport for MCP +type SSEClientTransport struct { + baseURL string + endpoint string + postEndpoint string + messageHandler func(ctx context.Context, message *transport.BaseJsonRpcMessage) + errorHandler func(error) + closeHandler func() + mu sync.RWMutex + client *http.Client + notificationClient *http.Client + headers map[string]string + done chan struct{} +} + +// NewSSEClientTransport creates a new SSE client transport that connects to the specified endpoint +func NewSSEClientTransport(endpoint string, notificationTimeout time.Duration) *SSEClientTransport { + if notificationTimeout <= 0 { + notificationTimeout = 1 * time.Millisecond // This is flaky, but it works for local demos. + } + + return &SSEClientTransport{ + endpoint: endpoint, + client: &http.Client{}, + notificationClient: &http.Client{ + Transport: &http.Transport{ + DisableKeepAlives: true, + }, + Timeout: notificationTimeout, + }, + headers: make(map[string]string), + done: make(chan struct{}), + } +} + +// WithBaseURL sets the base URL to connect to +func (t *SSEClientTransport) WithBaseURL(baseURL string) *SSEClientTransport { + t.baseURL = baseURL + return t +} + +// WithHeader adds a header to the request +func (t *SSEClientTransport) WithHeader(key, value string) *SSEClientTransport { + t.headers[key] = value + return t +} + +// Start implements Transport.Start +func (t *SSEClientTransport) Start(ctx context.Context) error { + url, err := url.JoinPath(t.baseURL, t.endpoint) + if err != nil { + return fmt.Errorf("failed to construct URL: %w", err) + } + + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + // Set required SSE headers + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("Cache-Control", "no-cache") + req.Header.Set("Connection", "keep-alive") + for key, value := range t.headers { + req.Header.Set(key, value) + } + + endpointChan := make(chan struct{}, 1) + + // Establish our SSE connection and start reading SSE messages in the background + go func() { + // Response should return once we receive headers (body can have data written to it over time). + resp, err := t.client.Do(req) + if err != nil { + if t.errorHandler != nil { + t.errorHandler(fmt.Errorf("failed to establish SSE connection: %w", err)) + } + return + } + + // Validate SSE connection + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + if t.errorHandler != nil { + t.errorHandler(fmt.Errorf("server returned error status %d: %s", resp.StatusCode, string(body))) + } + resp.Body.Close() + return + } + + contentType := resp.Header.Get("Content-Type") + if !strings.Contains(contentType, "text/event-stream") { + if t.errorHandler != nil { + t.errorHandler(fmt.Errorf("invalid content type for SSE connection: %s", contentType)) + } + resp.Body.Close() + return + } + + // Start reading and storing SSE messages + t.readSSEMessages(resp.Body, endpointChan) + }() + + // Wait for either the endpoint event or timeout + select { + case <-endpointChan: + return nil + case <-time.After(10 * time.Second): + return fmt.Errorf("timeout waiting for endpoint event") + case <-ctx.Done(): + return fmt.Errorf("context cancelled while waiting for endpoint event: %w", ctx.Err()) + } +} + +// readSSEMessages reads SSE messages from the response body and processes them +func (t *SSEClientTransport) readSSEMessages(body io.ReadCloser, endpointChan chan<- struct{}) { + defer body.Close() + + reader := bufio.NewReader(body) + buffer := make([]byte, 4096) + var messageBuffer bytes.Buffer + + for { + select { + case <-t.done: + return + default: + n, err := reader.Read(buffer) + if err != nil { + if err != io.EOF { + if t.errorHandler != nil { + t.errorHandler(fmt.Errorf("error reading SSE stream: %w", err)) + } + } + return + } + + messageBuffer.Write(buffer[:n]) + + // Process complete messages (terminated by double newline) + for { + // Check if we have at least one complete message (terminated by double newline) + content := messageBuffer.String() + messageEnd := strings.Index(content, "\n\n") + if messageEnd == -1 { + // No complete message yet, keep reading + break + } + + message := content[:messageEnd] + + messageBuffer.Reset() + messageBuffer.WriteString(content[messageEnd+2:]) // 2 newlines + + // Parse the complete message + msg, err := t.parseSSEMessageFromString(message, endpointChan) + if err != nil { + if t.errorHandler != nil { + t.errorHandler(fmt.Errorf("error parsing SSE message: %w", err)) + } + continue + } + + // Handle the message based on its type. + // Note: The initial endpoint event doesn't require handling, here. + // We only need to handle the response and notification types. + switch msg.Type { + case transport.BaseMessageTypeJSONRPCResponseType, + transport.BaseMessageTypeJSONRPCNotificationType, + transport.BaseMessageTypeJSONRPCErrorType: + + t.mu.RLock() + handler := t.messageHandler + t.mu.RUnlock() + + if handler != nil { + handler(context.Background(), msg) + } + } + } + } + } +} + +// parseSSEMessageFromString parses a complete SSE message from a string +func (t *SSEClientTransport) parseSSEMessageFromString(message string, endpointChan chan<- struct{}) (*transport.BaseJsonRpcMessage, error) { + lines := strings.Split(message, "\n") + if len(lines) < 2 { + return nil, fmt.Errorf("invalid SSE message format: too few lines") + } + + // Parse the event type + eventType := strings.TrimPrefix(lines[0], "event: ") + eventType = strings.TrimSpace(eventType) + + // Parse the data line + data := strings.TrimPrefix(lines[1], "data: ") + data = strings.TrimSpace(data) + + // Initialize message + var msg transport.BaseJsonRpcMessage + + // Handle endpoint events differently + if eventType == "endpoint" { + t.postEndpoint = data + if endpointChan != nil { + endpointChan <- struct{}{} + } + return &msg, nil + } + + // For message events, try to unmarshal as JSON RPC + var response transport.BaseJSONRPCResponse + if err := json.Unmarshal([]byte(data), &response); err == nil { + msg.Type = transport.BaseMessageTypeJSONRPCResponseType + msg.JsonRpcResponse = &response + return &msg, nil + } + + // Try as an error + var errorResponse transport.BaseJSONRPCError + if err := json.Unmarshal([]byte(data), &errorResponse); err == nil { + msg.Type = transport.BaseMessageTypeJSONRPCErrorType + msg.JsonRpcError = &errorResponse + return &msg, nil + } + + // Try as a notification + var notification transport.BaseJSONRPCNotification + if err := json.Unmarshal([]byte(data), ¬ification); err == nil { + msg.Type = transport.BaseMessageTypeJSONRPCNotificationType + msg.JsonRpcNotification = ¬ification + return &msg, nil + } + + return nil, fmt.Errorf("unrecognized message type: %s", eventType) +} + +// Send implements Transport.Send +func (t *SSEClientTransport) Send(ctx context.Context, message *transport.BaseJsonRpcMessage) error { + jsonData, err := json.Marshal(message) + if err != nil { + return fmt.Errorf("failed to marshal message: %w", err) + } + + if t.postEndpoint == "" { + return fmt.Errorf("post endpoint not set. sse connection not established") + } + + url, err := url.JoinPath(t.baseURL, t.postEndpoint) + if err != nil { + return fmt.Errorf("failed to construct URL: %w", err) + } + + req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + for key, value := range t.headers { + req.Header.Set(key, value) + } + + // Handle notifications differently with a shorter timeout + if message.Type == transport.BaseMessageTypeJSONRPCNotificationType { + resp, err := t.notificationClient.Do(req) + if err != nil && !errors.Is(err, context.DeadlineExceeded) { + t.mu.RLock() + if t.errorHandler != nil { + t.errorHandler(fmt.Errorf("notification error: %w", err)) + } + t.mu.RUnlock() + } + if resp != nil { + defer resp.Body.Close() + } + return nil + } + + // For non-notifications, continue with normal synchronous request + resp, err := t.client.Do(req) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("server returned error: %s (status: %d)", string(body), resp.StatusCode) + } + + // Response content we care about will come in over the SSE stream. + // We are done, here. + + return nil +} + +// Close implements Transport.Close +func (t *SSEClientTransport) Close() error { + close(t.done) + if t.closeHandler != nil { + t.closeHandler() + } + return nil +} + +// SetCloseHandler implements Transport.SetCloseHandler +func (t *SSEClientTransport) SetCloseHandler(handler func()) { + t.mu.Lock() + defer t.mu.Unlock() + t.closeHandler = handler +} + +// SetErrorHandler implements Transport.SetErrorHandler +func (t *SSEClientTransport) SetErrorHandler(handler func(error)) { + t.mu.Lock() + defer t.mu.Unlock() + t.errorHandler = handler +} + +// SetMessageHandler implements Transport.SetMessageHandler +func (t *SSEClientTransport) SetMessageHandler(handler func(ctx context.Context, message *transport.BaseJsonRpcMessage)) { + t.mu.Lock() + defer t.mu.Unlock() + t.messageHandler = handler +} diff --git a/transport/sse/internal/sse/sse.go b/transport/sse/internal/sse/sse.go deleted file mode 100644 index d42d804..0000000 --- a/transport/sse/internal/sse/sse.go +++ /dev/null @@ -1,244 +0,0 @@ -// /* -// Package mcp implements Server-Sent Events (SSE) transport for JSON-RPC communication. -// -// SSE Transport Overview: -// This implementation provides a bidirectional communication channel between client and server: -// - Server to Client: Uses Server-Sent Events (SSE) for real-time message streaming -// - Client to Server: Uses HTTP POST requests for sending messages -// -// Key Features: -// 1. Bidirectional Communication: -// - SSE for server-to-client streaming (one-way, real-time updates) -// - HTTP POST endpoints for client-to-server messages -// -// 2. Session Management: -// - Unique session IDs for each connection -// - Proper connection lifecycle management -// - Automatic cleanup on connection close -// -// 3. Message Handling: -// - JSON-RPC message format support -// - Automatic message type detection (request vs response) -// - Built-in error handling and reporting -// - Message size limits for security -// -// 4. Security Features: -// - Content-type validation -// - Message size limits (4MB default) -// - Error handling for malformed messages -// -// Usage Example: -// -// // Create a new SSE transport -// transport, err := NewSSETransport("/messages", responseWriter) -// if err != nil { -// log.Fatal(err) -// } -// -// // Set up message handling -// transport.SetMessageHandler(func(msg JSONRPCMessage) { -// // Handle incoming messages -// }) -// -// // Start the SSE connection -// if err := transport.Start(context.Background()); err != nil { -// log.Fatal(err) -// } -// -// // Send a message -// msg := JSONRPCResponse{ -// Jsonrpc: "2.0", -// Result: Result{...}, -// Id: 1, -// } -// if err := transport.Send(msg); err != nil { -// log.Fatal(err) -// } -// -// */ -package sse - -// -//import ( -// "context" -// "encoding/json" -// "fmt" -// "github.com/metoro-io/mcp-golang/transport" -// "net/http" -// "sync" -// -// "github.com/google/uuid" -//) -// -//const ( -// maxMessageSize = 4 * 1024 * 1024 // 4MB -//) -// -//// SSETransport implements a Server-Sent Events transport for JSON-RPC messages -//type SSETransport struct { -// endpoint string -// sessionID string -// writer http.ResponseWriter -// flusher http.Flusher -// mu sync.Mutex -// isConnected bool -// -// // Callbacks -// closeHandler func() -// errorHandler func(error) -// messageHandler func(message *transport.BaseJsonRpcMessage) -//} -// -//// NewSSETransport creates a new SSE transport with the given endpoint and response writer -//func NewSSETransport(endpoint string, w http.ResponseWriter) (*SSETransport, error) { -// flusher, ok := w.(http.Flusher) -// if !ok { -// return nil, fmt.Errorf("streaming not supported") -// } -// -// return &SSETransport{ -// endpoint: endpoint, -// sessionID: uuid.New().String(), -// writer: w, -// flusher: flusher, -// }, nil -//} -// -//// Start initializes the SSE connection -//func (t *SSETransport) Start(ctx context.Context) error { -// t.mu.Lock() -// defer t.mu.Unlock() -// -// if t.isConnected { -// return fmt.Errorf("SSE transport already started") -// } -// -// // Set SSE headers -// h := t.writer.Header() -// h.Set("Content-Type", "text/event-stream") -// h.Set("Cache-Control", "no-cache") -// h.Set("Connection", "keep-alive") -// h.Set("Access-Control-Allow-Origin", "*") -// -// // Send the endpoint event -// endpointURL := fmt.Sprintf("%s?sessionId=%s", t.endpoint, t.sessionID) -// if err := t.writeEvent("endpoint", endpointURL); err != nil { -// return err -// } -// -// t.isConnected = true -// -// // Handle context cancellation -// go func() { -// <-ctx.Done() -// t.Close() -// }() -// -// return nil -//} -// -//// HandleMessage processes an incoming message -//func (t *SSETransport) HandleMessage(msg []byte) error { -// var rpcMsg map[string]interface{} -// if err := json.Unmarshal(msg, &rpcMsg); err != nil { -// if t.errorHandler != nil { -// t.errorHandler(err) -// } -// return err -// } -// -// // Parse as a JSONRPCMessage -// var jsonrpcMsg JSONRPCMessage -// if _, ok := rpcMsg["method"]; ok { -// var req JSONRPCRequest -// if err := json.Unmarshal(msg, &req); err != nil { -// if t.errorHandler != nil { -// t.errorHandler(err) -// } -// return err -// } -// jsonrpcMsg = &req -// } else { -// var resp JSONRPCResponse -// if err := json.Unmarshal(msg, &resp); err != nil { -// if t.errorHandler != nil { -// t.errorHandler(err) -// } -// return err -// } -// jsonrpcMsg = &resp -// } -// -// if t.messageHandler != nil { -// t.messageHandler(jsonrpcMsg) -// } -// return nil -//} -// -//// Send sends a message over the SSE connection -//func (t *SSETransport) Send(msg JSONRPCMessage) error { -// t.mu.Lock() -// defer t.mu.Unlock() -// -// if !t.isConnected { -// return fmt.Errorf("not connected") -// } -// -// data, err := json.Marshal(msg) -// if err != nil { -// return err -// } -// -// return t.writeEvent("message", string(data)) -//} -// -//// Close closes the SSE connection -//func (t *SSETransport) Close() error { -// t.mu.Lock() -// defer t.mu.Unlock() -// -// if !t.isConnected { -// return nil -// } -// -// t.isConnected = false -// if t.closeHandler != nil { -// t.closeHandler() -// } -// return nil -//} -// -//// SetCloseHandler sets the callback for when the connection is closed -//func (t *SSETransport) SetCloseHandler(handler func()) { -// t.mu.Lock() -// defer t.mu.Unlock() -// t.closeHandler = handler -//} -// -//// SetErrorHandler sets the callback for when an error occurs -//func (t *SSETransport) SetErrorHandler(handler func(error)) { -// t.mu.Lock() -// defer t.mu.Unlock() -// t.errorHandler = handler -//} -// -//// SetMessageHandler sets the callback for when a message is received -//func (t *SSETransport) SetMessageHandler(handler func(JSONRPCMessage)) { -// t.mu.Lock() -// defer t.mu.Unlock() -// t.messageHandler = handler -//} -// -//// SessionID returns the unique session identifier for this transport -//func (t *SSETransport) SessionID() string { -// return t.sessionID -//} -// -//// writeEvent writes an SSE event with the given event type and data -//func (t *SSETransport) writeEvent(event, data string) error { -// if _, err := fmt.Fprintf(t.writer, "event: %s\ndata: %s\n\n", event, data); err != nil { -// return err -// } -// t.flusher.Flush() -// return nil -//} diff --git a/transport/sse/server.go b/transport/sse/server.go new file mode 100644 index 0000000..eba71c2 --- /dev/null +++ b/transport/sse/server.go @@ -0,0 +1,332 @@ +package sse + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "sync" + "sync/atomic" + + "github.com/rvoh-emccaleb/mcp-golang/transport" + "github.com/rvoh-emccaleb/mcp-golang/transport/base" +) + +var ( + ErrConnectionRemoved = errors.New("sse connection removed") +) + +// sseConnection represents a single SSE client sseConnection, and provides a way to send messages to the client. +type sseConnection struct { + id int64 + writer http.ResponseWriter + flusher http.Flusher +} + +// ServerTransport implements server-side SSE transport +type ServerTransport struct { + *base.Transport // shared transport for handling all messages + baseEndpoint string // connection IDs are appended to this endpoint for SSE connections + mu sync.RWMutex + sseConns map[int64]*sseConnection // map of connection IDs to connections + nextConnID int64 // atomic counter for generating connection IDs + chunkSize int // size of chunks for writing SSE messages +} + +// Option is a function that configures a ServerTransport +type Option func(*ServerTransport) + +// WithChunkSize sets the chunk size for writing SSE messages +func WithChunkSize(size int) Option { + return func(t *ServerTransport) { + if size > 0 { + t.chunkSize = size + } + } +} + +// NewServerTransport creates a new SSE server transport +func NewServerTransport(endpoint string, options ...Option) *ServerTransport { + t := &ServerTransport{ + Transport: base.NewTransport(), + baseEndpoint: endpoint, + sseConns: make(map[int64]*sseConnection), + chunkSize: 1024, // default to 1KB chunks + } + for _, opt := range options { + opt(t) + } + return t +} + +// HandleSSEConnInitialize handles a new SSE connection request +func (t *ServerTransport) HandleSSEConnInitialize(w http.ResponseWriter) (int64, error) { + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return 0, fmt.Errorf("streaming is not supported: provided http.ResponseWriter does not implement http.Flusher") + } + + // Generate a unique endpoint URI for this connection + connID := atomic.AddInt64(&t.nextConnID, 1) + endpointURI := fmt.Sprintf("%s/%d", t.baseEndpoint, connID) + + sseConn := &sseConnection{ + id: connID, + writer: w, + flusher: flusher, + } + + t.addSSEConnection(sseConn) + + // Set headers for SSE before writing any data + h := sseConn.writer.Header() + h.Set("Content-Type", "text/event-stream") + h.Set("Cache-Control", "no-cache") + h.Set("Connection", "keep-alive") + h.Set("Access-Control-Allow-Origin", "*") + sseConn.writer.WriteHeader(http.StatusOK) + + // Send the initial endpoint event as required by the MCP specification + err := t.writeEndpointEvent(sseConn, endpointURI) + if err != nil { + // We can't change the status code now, so bubble the error up, allowing request + // handling to continue without sending a response to the client. + return 0, err + } + + return connID, nil +} + +// writeEndpointEvent writes the endpoint URI over SSE +func (t *ServerTransport) writeEndpointEvent(sseConn *sseConnection, endpointURI string) error { + _, err := sseConn.writer.Write(fmt.Appendf(nil, "event: endpoint\ndata: %s\n\n", endpointURI)) + if err != nil { + t.RemoveSSEConnection(sseConn.id) + return fmt.Errorf( + "failed to write endpoint event for connection with id %d: %w. %w", + sseConn.id, + err, + ErrConnectionRemoved, + ) + } + + sseConn.flusher.Flush() + + return nil +} + +// HandleMCPMessage handles incoming MCP (JSON-RPC) messages. +func (t *ServerTransport) HandleMCPMessage(w http.ResponseWriter, r *http.Request, connID int64) error { + // First check if we have an active connection for this ID + sseConn, ok := t.getSSEConnection(connID) + if !ok { + errMsg := fmt.Sprintf("no active connection found for id: %d", connID) + http.Error(w, errMsg, http.StatusNotFound) + return errors.New(errMsg) + } + + if r.Header.Get("Content-Type") != "application/json" { + errMsg := fmt.Sprintf("unsupported content type: %s", r.Header.Get("Content-Type")) + http.Error(w, errMsg, http.StatusUnsupportedMediaType) + return errors.New(errMsg) + } + + body, err := io.ReadAll(r.Body) + if err != nil { + errMsg := "failed to read request body" + http.Error(w, errMsg, http.StatusBadRequest) + return fmt.Errorf("%s: %w", errMsg, err) + } + defer r.Body.Close() + + var receivedMsg transport.BaseJsonRpcMessage + if err := json.Unmarshal(body, &receivedMsg); err != nil { + errMsg := "failed to parse request body as a JSON-RPC message" + http.Error(w, errMsg, http.StatusBadRequest) + return fmt.Errorf("%s: %w", errMsg, err) + } + + // Send 200 OK response back to the POST request after validating the message. + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte{}); err != nil { + return fmt.Errorf("failed to write response: %w", err) + } + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + // Note: Errors from this point on must be sent over the SSE connection. + + response, err := t.HandleMessage(r.Context(), body) + if err != nil { + switch receivedMsg.Type { + case transport.BaseMessageTypeJSONRPCRequestType: + newErr := t.writeMessageEvent(sseConn, transport.NewBaseMessageError(&transport.BaseJSONRPCError{ + Jsonrpc: "2.0", + Id: receivedMsg.JsonRpcRequest.Id, + Error: transport.BaseJSONRPCErrorInner{ + Code: -32603, // Internal error + Message: fmt.Sprintf("error handling request: %v", err), + }, + })) + if newErr != nil { + return fmt.Errorf("failed to write error sse message after encountering '%w' error when handling the request: %w", err, newErr) + } + + return fmt.Errorf("error handling request: %w", err) + + case transport.BaseMessageTypeJSONRPCNotificationType: + return fmt.Errorf("error handling notification: %w", err) + + default: + return fmt.Errorf("error handling unknown message type %s: %w", receivedMsg.Type, err) + } + } + + // Only send response for requests (not for notifications) + if receivedMsg.Type == transport.BaseMessageTypeJSONRPCRequestType { + err := t.writeMessageEvent(sseConn, response) + if err != nil { + if errors.Is(err, ErrConnectionRemoved) { + return err + } + + // Can try to write an error message event, but if that fails, we're out of options. + errorMsg := transport.NewBaseMessageError(&transport.BaseJSONRPCError{ + Jsonrpc: "2.0", + Id: receivedMsg.JsonRpcRequest.Id, + Error: transport.BaseJSONRPCErrorInner{ + Code: -32603, // Internal error + Message: fmt.Sprintf("failed to write message event: %v", err), + }, + }) + + newErr := t.writeMessageEvent(sseConn, errorMsg) + if newErr != nil { + return fmt.Errorf("failed to write error message event after encountering '%w' error: %w", err, newErr) + } + + return err + } + } + + return nil +} + +// writeMessageEvent writes a JSON-RPC message over SSE +func (t *ServerTransport) writeMessageEvent(sseConn *sseConnection, message *transport.BaseJsonRpcMessage) error { + data, err := json.Marshal(message) + if err != nil { + return fmt.Errorf("failed to marshal message: %w", err) + } + + // Write the event header + if _, err := io.WriteString(sseConn.writer, "event: message\ndata: "); err != nil { + t.RemoveSSEConnection(sseConn.id) + return fmt.Errorf( + "failed to write message event for connection with id %d: %w. %w", + sseConn.id, + err, + ErrConnectionRemoved, + ) + } + + // Use the configured chunk size + for i := 0; i < len(data); i += t.chunkSize { + end := i + t.chunkSize + if end > len(data) { + end = len(data) + } + + if _, err := sseConn.writer.Write(data[i:end]); err != nil { + t.RemoveSSEConnection(sseConn.id) + return fmt.Errorf( + "failed to write message chunk for connection with id %d: %w. %w", + sseConn.id, + err, + ErrConnectionRemoved, + ) + } + sseConn.flusher.Flush() + } + + // Write the final newlines + if _, err := io.WriteString(sseConn.writer, "\n\n"); err != nil { + t.RemoveSSEConnection(sseConn.id) + return fmt.Errorf( + "failed to write message termination for connection with id %d: %w. %w", + sseConn.id, + err, + ErrConnectionRemoved, + ) + } + + sseConn.flusher.Flush() + return nil +} + +// Start implements Transport.Start +func (t *ServerTransport) Start(ctx context.Context) error { + // Nothing to do here as connections are established via HandleSSERequest + return nil +} + +// Send implements Transport.Send. +func (t *ServerTransport) Send(ctx context.Context, message *transport.BaseJsonRpcMessage) error { + key := message.JsonRpcResponse.Id + responseChannel := t.ResponseMap[int64(key)] + if responseChannel == nil { + return fmt.Errorf("no response channel found for key: %d", key) + } + responseChannel <- message + return nil +} + +// Close implements Transport.Close +func (t *ServerTransport) Close() error { + t.mu.Lock() + defer t.mu.Unlock() + + // Clear all SSE connections + t.sseConns = make(map[int64]*sseConnection) + + // Close the base transport + return t.Transport.Close() +} + +// SetMessageHandler implements Transport.SetMessageHandler +func (t *ServerTransport) SetMessageHandler(handler func(ctx context.Context, message *transport.BaseJsonRpcMessage)) { + t.Transport.SetMessageHandler(handler) +} + +// SetErrorHandler implements Transport.SetErrorHandler +func (t *ServerTransport) SetErrorHandler(handler func(error)) { + t.Transport.SetErrorHandler(handler) +} + +// SetCloseHandler implements Transport.SetCloseHandler +func (t *ServerTransport) SetCloseHandler(handler func()) { + t.Transport.SetCloseHandler(handler) +} + +func (t *ServerTransport) addSSEConnection(sseConn *sseConnection) { + t.mu.Lock() + defer t.mu.Unlock() + t.sseConns[sseConn.id] = sseConn +} + +func (t *ServerTransport) getSSEConnection(connID int64) (*sseConnection, bool) { + t.mu.RLock() + defer t.mu.RUnlock() + sseConn, ok := t.sseConns[connID] + return sseConn, ok +} + +// RemoveSSEConnection removes an SSE connection from the transport +func (t *ServerTransport) RemoveSSEConnection(connID int64) { + t.mu.Lock() + defer t.mu.Unlock() + delete(t.sseConns, connID) +} diff --git a/transport/sse/sse_server.go b/transport/sse/sse_server.go deleted file mode 100644 index e3fa67a..0000000 --- a/transport/sse/sse_server.go +++ /dev/null @@ -1,82 +0,0 @@ -package sse - -// -//import ( -// "context" -// "fmt" -// sse2 "github.com/metoro-io/mcp-golang/transport/sse/internal/sse" -// "io" -// "net/http" -//) -// -//// SSEServerTransport implements a server-side SSE transport -//type SSEServerTransport struct { -// transport *sse2.SSETransport -//} -// -//// NewSSEServerTransport creates a new SSE server transport -//func NewSSEServerTransport(endpoint string, w http.ResponseWriter) (*SSEServerTransport, error) { -// transport, err := sse2.NewSSETransport(endpoint, w) -// if err != nil { -// return nil, err -// } -// -// return &SSEServerTransport{ -// transport: transport, -// }, nil -//} -// -//// Start initializes the SSE connection -//func (s *SSEServerTransport) Start(ctx context.Context) error { -// return s.transport.Start(ctx) -//} -// -//// HandlePostMessage processes an incoming POST request containing a JSON-RPC message -//func (s *SSEServerTransport) HandlePostMessage(r *http.Request) error { -// if r.Method != http.MethodPost { -// return fmt.Errorf("method not allowed: %s", r.Method) -// } -// -// contentType := r.Header.Get("Content-Type") -// if contentType != "application/json" { -// return fmt.Errorf("unsupported Content type: %s", contentType) -// } -// -// body, err := io.ReadAll(io.LimitReader(r.Body, sse2.maxMessageSize)) -// if err != nil { -// return fmt.Errorf("failed to read request body: %w", err) -// } -// defer r.Body.Close() -// -// return s.transport.HandleMessage(body) -//} -// -//// Send sends a message over the SSE connection -//func (s *SSEServerTransport) Send(msg JSONRPCMessage) error { -// return s.transport.Send(msg) -//} -// -//// Close closes the SSE connection -//func (s *SSEServerTransport) Close() error { -// return s.transport.Close() -//} -// -//// SetCloseHandler sets the callback for when the connection is closed -//func (s *SSEServerTransport) SetCloseHandler(handler func()) { -// s.transport.SetCloseHandler(handler) -//} -// -//// SetErrorHandler sets the callback for when an error occurs -//func (s *SSEServerTransport) SetErrorHandler(handler func(error)) { -// s.transport.SetErrorHandler(handler) -//} -// -//// SetMessageHandler sets the callback for when a message is received -//func (s *SSEServerTransport) SetMessageHandler(handler func(JSONRPCMessage)) { -// s.transport.SetMessageHandler(handler) -//} -// -//// SessionID returns the unique session identifier for this transport -//func (s *SSEServerTransport) SessionID() string { -// return s.transport.SessionID() -//} diff --git a/transport/sse/sse_server_test.go b/transport/sse/sse_server_test.go deleted file mode 100644 index 101dff0..0000000 --- a/transport/sse/sse_server_test.go +++ /dev/null @@ -1,126 +0,0 @@ -package sse - -// -//import ( -// "bytes" -// "context" -// "encoding/json" -// "github.com/metoro-io/mcp-golang" -// "net/http" -// "net/http/httptest" -// "strings" -// "testing" -// -// "github.com/stretchr/testify/assert" -//) -// -//func TestSSEServerTransport(t *testing.T) { -// t.Run("basic message handling", func(t *testing.T) { -// w := httptest.NewRecorder() -// transport, err := NewSSEServerTransport("/messages", w) -// assert.NoError(t, err) -// -// var receivedMsg JSONRPCMessage -// transport.SetMessageHandler(func(msg JSONRPCMessage) { -// receivedMsg = msg -// }) -// -// ctx := context.Background() -// err = transport.Start(ctx) -// assert.NoError(t, err) -// -// // Verify SSE headers -// headers := w.Header() -// assert.Equal(t, "text/event-stream", headers.Get("Content-Type")) -// assert.Equal(t, "no-cache", headers.Get("Cache-Control")) -// assert.Equal(t, "keep-alive", headers.Get("Connection")) -// -// // Verify endpoint event was sent -// body := w.Body.String() -// assert.Contains(t, body, "event: endpoint") -// assert.Contains(t, body, "/messages?sessionId=") -// -// // Test message handling -// msg := JSONRPCRequest{ -// Jsonrpc: "2.0", -// Method: "test", -// Id: 1, -// } -// msgBytes, err := json.Marshal(msg) -// assert.NoError(t, err) -// -// httpReq := httptest.NewRequest(http.MethodPost, "/messages", bytes.NewReader(msgBytes)) -// httpReq.Header.Set("Content-Type", "application/json") -// err = transport.HandlePostMessage(httpReq) -// assert.NoError(t, err) -// -// // Verify received message -// rpcReq, ok := receivedMsg.(*JSONRPCRequest) -// assert.True(t, ok) -// assert.Equal(t, "2.0", rpcReq.Jsonrpc) -// assert.Equal(t, mcp.RequestId(1), rpcReq.Id) -// -// err = transport.Close() -// assert.NoError(t, err) -// }) -// -// t.Run("send message", func(t *testing.T) { -// w := httptest.NewRecorder() -// transport, err := NewSSEServerTransport("/messages", w) -// assert.NoError(t, err) -// -// ctx := context.Background() -// err = transport.Start(ctx) -// assert.NoError(t, err) -// -// msg := JSONRPCResponse{ -// Jsonrpc: "2.0", -// Result: Result{AdditionalProperties: map[string]interface{}{"status": "ok"}}, -// Id: 1, -// } -// -// err = transport.Send(msg) -// assert.NoError(t, err) -// -// // Verify output contains the message -// body := w.Body.String() -// assert.Contains(t, body, `event: message`) -// assert.Contains(t, body, `"result":{"AdditionalProperties":{"status":"ok"}}`) -// }) -// -// t.Run("error handling", func(t *testing.T) { -// w := httptest.NewRecorder() -// transport, err := NewSSEServerTransport("/messages", w) -// assert.NoError(t, err) -// -// var receivedErr error -// transport.SetErrorHandler(func(err error) { -// receivedErr = err -// }) -// -// ctx := context.Background() -// err = transport.Start(ctx) -// assert.NoError(t, err) -// -// // Test invalid JSON -// req := httptest.NewRequest(http.MethodPost, "/messages", strings.NewReader("invalid json")) -// req.Header.Set("Content-Type", "application/json") -// err = transport.HandlePostMessage(req) -// assert.Error(t, err) -// assert.NotNil(t, receivedErr) -// assert.Contains(t, receivedErr.Error(), "invalid") -// -// // Test invalid Content type -// req = httptest.NewRequest(http.MethodPost, "/messages", strings.NewReader("{}")) -// req.Header.Set("Content-Type", "text/plain") -// err = transport.HandlePostMessage(req) -// assert.Error(t, err) -// assert.Contains(t, err.Error(), "unsupported Content type") -// -// // Test invalid method -// req = httptest.NewRequest(http.MethodGet, "/messages", nil) -// err = transport.HandlePostMessage(req) -// assert.Error(t, err) -// assert.Contains(t, err.Error(), "method not allowed") -// }) -//} diff --git a/transport/stdio/internal/stdio/stdio.go b/transport/stdio/internal/stdio/stdio.go index 9b85698..3f08aa2 100644 --- a/transport/stdio/internal/stdio/stdio.go +++ b/transport/stdio/internal/stdio/stdio.go @@ -60,7 +60,7 @@ package stdio import ( "encoding/json" "errors" - "github.com/metoro-io/mcp-golang/transport" + "github.com/rvoh-emccaleb/mcp-golang/transport" "sync" ) diff --git a/transport/stdio/internal/stdio/stdio_test.go b/transport/stdio/internal/stdio/stdio_test.go index 7412d5b..5106b71 100644 --- a/transport/stdio/internal/stdio/stdio_test.go +++ b/transport/stdio/internal/stdio/stdio_test.go @@ -1,7 +1,7 @@ package stdio import ( - "github.com/metoro-io/mcp-golang/transport" + "github.com/rvoh-emccaleb/mcp-golang/transport" "testing" "github.com/stretchr/testify/assert" diff --git a/transport/stdio/stdio_server.go b/transport/stdio/stdio_server.go index 36d52f2..aa839a4 100644 --- a/transport/stdio/stdio_server.go +++ b/transport/stdio/stdio_server.go @@ -9,8 +9,8 @@ import ( "os" "sync" - "github.com/metoro-io/mcp-golang/transport" - "github.com/metoro-io/mcp-golang/transport/stdio/internal/stdio" + "github.com/rvoh-emccaleb/mcp-golang/transport" + "github.com/rvoh-emccaleb/mcp-golang/transport/stdio/internal/stdio" ) // StdioServerTransport implements server-side transport for stdio communication diff --git a/transport/stdio/stdio_server_test.go b/transport/stdio/stdio_server_test.go index d8674ba..25243b3 100644 --- a/transport/stdio/stdio_server_test.go +++ b/transport/stdio/stdio_server_test.go @@ -7,7 +7,7 @@ import ( "testing" "time" - "github.com/metoro-io/mcp-golang/transport" + "github.com/rvoh-emccaleb/mcp-golang/transport" "github.com/stretchr/testify/assert" ) diff --git a/transport/types.go b/transport/types.go index d2627f8..7f615f7 100644 --- a/transport/types.go +++ b/transport/types.go @@ -3,6 +3,7 @@ package transport import ( "encoding/json" "errors" + "fmt" ) type JSONRPCMessage interface{} @@ -193,6 +194,43 @@ func (m *BaseJsonRpcMessage) MarshalJSON() ([]byte, error) { } } +// UnmarshalJSON implements json.Unmarshaler for BaseJsonRpcMessage +func (m *BaseJsonRpcMessage) UnmarshalJSON(data []byte) error { + // Try as a request + var request BaseJSONRPCRequest + if err := json.Unmarshal(data, &request); err == nil { + m.Type = BaseMessageTypeJSONRPCRequestType + m.JsonRpcRequest = &request + return nil + } + + // Try as a notification + var notification BaseJSONRPCNotification + if err := json.Unmarshal(data, ¬ification); err == nil { + m.Type = BaseMessageTypeJSONRPCNotificationType + m.JsonRpcNotification = ¬ification + return nil + } + + // Try as a response + var response BaseJSONRPCResponse + if err := json.Unmarshal(data, &response); err == nil { + m.Type = BaseMessageTypeJSONRPCResponseType + m.JsonRpcResponse = &response + return nil + } + + // Try as an error + var errorResponse BaseJSONRPCError + if err := json.Unmarshal(data, &errorResponse); err == nil { + m.Type = BaseMessageTypeJSONRPCErrorType + m.JsonRpcError = &errorResponse + return nil + } + + return fmt.Errorf("failed to unmarshal JSON-RPC message, unrecognized type") +} + func NewBaseMessageNotification(notification *BaseJSONRPCNotification) *BaseJsonRpcMessage { return &BaseJsonRpcMessage{ Type: BaseMessageTypeJSONRPCNotificationType,