diff --git a/README.md b/README.md
index 74f4ec5..9590d80 100644
--- a/README.md
+++ b/README.md
@@ -4,16 +4,16 @@
-
-
-
-
-
-
-
-[](https://pkg.go.dev/github.com/metoro-io/mcp-golang)
-[](https://goreportcard.com/report/github.com/metoro-io/mcp-golang)
-
+
+
+
+
+
+
+
+[](https://pkg.go.dev/github.com/rvoh-emccaleb/mcp-golang)
+[](https://goreportcard.com/report/github.com/rvoh-emccaleb/mcp-golang)
+
@@ -37,7 +37,7 @@ Docs at [https://mcpgolang.com](https://mcpgolang.com)
## Example Usage
-Install with `go get github.com/metoro-io/mcp-golang`
+Install with `go get github.com/rvoh-emccaleb/mcp-golang`
### Server Example
@@ -46,8 +46,8 @@ package main
import (
"fmt"
- "github.com/metoro-io/mcp-golang"
- "github.com/metoro-io/mcp-golang/transport/stdio"
+ "github.com/rvoh-emccaleb/mcp-golang"
+ "github.com/rvoh-emccaleb/mcp-golang/transport/stdio"
)
// Tool arguments are just structs, annotated with jsonschema tags
@@ -121,8 +121,8 @@ package main
import (
"context"
"log"
- mcp "github.com/metoro-io/mcp-golang"
- "github.com/metoro-io/mcp-golang/transport/stdio"
+ mcp "github.com/rvoh-emccaleb/mcp-golang"
+ "github.com/rvoh-emccaleb/mcp-golang/transport/stdio"
)
// Define type-safe arguments
diff --git a/client.go b/client.go
index 672e217..5ea27e0 100644
--- a/client.go
+++ b/client.go
@@ -4,9 +4,9 @@ import (
"context"
"encoding/json"
- "github.com/metoro-io/mcp-golang/internal/protocol"
- "github.com/metoro-io/mcp-golang/transport"
"github.com/pkg/errors"
+ "github.com/rvoh-emccaleb/mcp-golang/internal/protocol"
+ "github.com/rvoh-emccaleb/mcp-golang/transport"
)
// Client represents an MCP client that can connect to and interact with MCP servers
@@ -25,8 +25,21 @@ func NewClient(transport transport.Transport) *Client {
}
}
+// A bit loosey goosey, but it works for now.
+// See: https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/lifecycle/
+type InitializeRequestParams struct {
+ ProtocolVersion string `json:"protocolVersion"`
+ ClientInfo InitializeRequestClientInfo `json:"clientInfo"`
+ Capabilities map[string]interface{} `json:"capabilities"`
+}
+
+type InitializeRequestClientInfo struct {
+ Name string `json:"name"`
+ Version string `json:"version"`
+}
+
// Initialize connects to the server and retrieves its capabilities
-func (c *Client) Initialize(ctx context.Context) (*InitializeResponse, error) {
+func (c *Client) Initialize(ctx context.Context, params *InitializeRequestParams) (*InitializeResponse, error) {
if c.initialized {
return nil, errors.New("client already initialized")
}
@@ -36,8 +49,8 @@ func (c *Client) Initialize(ctx context.Context) (*InitializeResponse, error) {
return nil, errors.Wrap(err, "failed to connect transport")
}
- // Make initialize request to server
- response, err := c.protocol.Request(ctx, "initialize", map[string]interface{}{}, nil)
+ // Begin initialization handshake with server
+ response, err := c.protocol.Request(ctx, "initialize", params, nil)
if err != nil {
return nil, errors.Wrap(err, "failed to initialize")
}
@@ -53,8 +66,15 @@ func (c *Client) Initialize(ctx context.Context) (*InitializeResponse, error) {
return nil, errors.Wrap(err, "failed to unmarshal initialize response")
}
+ // Finish initialization handshake with server
+ err = c.protocol.Notification("notifications/initialized", map[string]interface{}{})
+ if err != nil {
+ return nil, errors.Wrap(err, "failed to send initialized notification")
+ }
+
c.capabilities = &initResult.Capabilities
c.initialized = true
+
return &initResult, nil
}
@@ -235,17 +255,13 @@ func (c *Client) ReadResource(ctx context.Context, uri string) (*ResourceRespons
return nil, errors.New("invalid response type")
}
- var resourceResponse resourceResponseSent
+ var resourceResponse ResourceResponse
err = json.Unmarshal(responseBytes, &resourceResponse)
if err != nil {
return nil, errors.Wrap(err, "failed to unmarshal resource response")
}
- if resourceResponse.Error != nil {
- return nil, resourceResponse.Error
- }
-
- return resourceResponse.Response, nil
+ return &resourceResponse, nil
}
// Ping sends a ping request to the server to check connectivity
@@ -266,3 +282,13 @@ func (c *Client) Ping(ctx context.Context) error {
func (c *Client) GetCapabilities() *ServerCapabilities {
return c.capabilities
}
+
+// Close cleans up resources used by the client, including the protocol and transport layers.
+// It should be called when the client is no longer needed.
+func (c *Client) Close() error {
+ if err := c.protocol.Close(); err != nil {
+ return errors.Wrap(err, "failed to close protocol")
+ }
+
+ return nil
+}
diff --git a/content_api.go b/content_api.go
index b111593..94e9989 100644
--- a/content_api.go
+++ b/content_api.go
@@ -85,14 +85,51 @@ type EmbeddedResource struct {
// Custom JSON marshaling for EmbeddedResource
func (c EmbeddedResource) MarshalJSON() ([]byte, error) {
+ type wrapper struct {
+ Type string `json:"embeddedResourceType"`
+ *TextResourceContents
+ *BlobResourceContents
+ }
+
+ w := wrapper{Type: string(c.EmbeddedResourceType)}
+
switch c.EmbeddedResourceType {
case embeddedResourceTypeBlob:
- return json.Marshal(c.BlobResourceContents)
+ w.BlobResourceContents = c.BlobResourceContents
case embeddedResourceTypeText:
- return json.Marshal(c.TextResourceContents)
+ w.TextResourceContents = c.TextResourceContents
default:
return nil, fmt.Errorf("unknown embedded resource type: %s", c.EmbeddedResourceType)
}
+
+ return json.Marshal(w)
+}
+
+func (c *EmbeddedResource) UnmarshalJSON(data []byte) error {
+ var wrapper struct {
+ Type string `json:"embeddedResourceType"`
+ Raw json.RawMessage `json:"-"`
+ }
+
+ wrapper.Raw = data
+ if err := json.Unmarshal(data, &wrapper); err != nil {
+ return err
+ }
+
+ c.EmbeddedResourceType = embeddedResourceType(wrapper.Type)
+
+ switch c.EmbeddedResourceType {
+ case embeddedResourceTypeText:
+ c.TextResourceContents = new(TextResourceContents)
+ return json.Unmarshal(wrapper.Raw, c.TextResourceContents)
+
+ case embeddedResourceTypeBlob:
+ c.BlobResourceContents = new(BlobResourceContents)
+ return json.Unmarshal(wrapper.Raw, c.BlobResourceContents)
+
+ default:
+ return fmt.Errorf("unknown embedded resource type: %s", wrapper.Type)
+ }
}
type ContentType string
diff --git a/docs/change-notifications.mdx b/docs/change-notifications.mdx
index 82fd09d..5c9d8bd 100644
--- a/docs/change-notifications.mdx
+++ b/docs/change-notifications.mdx
@@ -24,8 +24,8 @@ package main
import (
"fmt"
- mcp_golang "github.com/metoro-io/mcp-golang"
- "github.com/metoro-io/mcp-golang/transport/stdio"
+ mcp_golang "github.com/rvoh-emccaleb/mcp-golang"
+ "github.com/rvoh-emccaleb/mcp-golang/transport/stdio"
"time"
)
diff --git a/docs/client.mdx b/docs/client.mdx
index 1a175c1..6e45bb9 100644
--- a/docs/client.mdx
+++ b/docs/client.mdx
@@ -12,7 +12,7 @@ The MCP client provides a simple and intuitive way to interact with MCP servers.
Add the MCP Golang package to your project:
```bash
-go get github.com/metoro-io/mcp-golang
+go get github.com/rvoh-emccaleb/mcp-golang
```
## Basic Usage
@@ -21,8 +21,8 @@ Here's a simple example of creating and initializing an MCP client:
```go
import (
- mcp "github.com/metoro-io/mcp-golang"
- "github.com/metoro-io/mcp-golang/transport/stdio"
+ mcp "github.com/rvoh-emccaleb/mcp-golang"
+ "github.com/rvoh-emccaleb/mcp-golang/transport/stdio"
)
// Create a transport (stdio in this example)
@@ -197,7 +197,7 @@ if err != nil {
## Complete Example
-For a complete working example, check out our [example client implementation](https://github.com/metoro-io/mcp-golang/tree/main/examples/client).
+For a complete working example, check out our [example client implementation](https://github.com/rvoh-emccaleb/mcp-golang/tree/main/examples/client).
## Transport Options
diff --git a/docs/contributing.mdx b/docs/contributing.mdx
index 0bbc722..425b44c 100644
--- a/docs/contributing.mdx
+++ b/docs/contributing.mdx
@@ -19,7 +19,7 @@ To set up your development environment, follow these steps:
1. Clone the repository:
```bash
-git clone https://github.com/metoro-io/mcp-golang.git
+git clone https://github.com/rvoh-emccaleb/mcp-golang.git
cd mcp-golang
```
@@ -210,6 +210,6 @@ When your PR merges into the main branch, it will be deployed automatically.
## Getting Help
-- Check existing [GitHub issues](https://github.com/metoro-io/mcp-golang/issues)
+- Check existing [GitHub issues](https://github.com/rvoh-emccaleb/mcp-golang/issues)
- Join our [Discord community](https://discord.gg/33saRwE3pT)
- Read the [Model Context Protocol specification](https://modelcontextprotocol.io/)
diff --git a/docs/mint.json b/docs/mint.json
index 67c435c..7947bfa 100644
--- a/docs/mint.json
+++ b/docs/mint.json
@@ -21,13 +21,13 @@
},
"topbarCtaButton": {
"name": "Github Repo",
- "url": "https://github.com/metoro-io/mcp-golang"
+ "url": "https://github.com/rvoh-emccaleb/mcp-golang"
},
"anchors": [
{
"name": "Github",
"icon": "github",
- "url": "https://github.com/metoro-io/mcp-golang"
+ "url": "https://github.com/rvoh-emccaleb/mcp-golang"
},
{
"name": "Discord Community",
@@ -62,7 +62,7 @@
],
"footerSocials": {
"x": "https://x.com/metoro_ai",
- "github": "https://github.com/metoro-io/mcp-golang",
+ "github": "https://github.com/rvoh-emccaleb/mcp-golang",
"website": "https://https://metoro.io/"
}
}
diff --git a/docs/quickstart.mdx b/docs/quickstart.mdx
index 03bdb40..a5148e3 100644
--- a/docs/quickstart.mdx
+++ b/docs/quickstart.mdx
@@ -8,7 +8,7 @@ description: 'Set up your first mcp-golang server'
First, add mcp-golang to your project:
```bash
-go get github.com/metoro-io/mcp-golang
+go get github.com/rvoh-emccaleb/mcp-golang
```
## Basic Usage
@@ -20,8 +20,8 @@ package main
import (
"fmt"
- "github.com/metoro-io/mcp-golang"
- "github.com/metoro-io/mcp-golang/transport/stdio"
+ "github.com/rvoh-emccaleb/mcp-golang"
+ "github.com/rvoh-emccaleb/mcp-golang/transport/stdio"
)
type Content struct {
@@ -93,8 +93,8 @@ package main
import (
"context"
"log"
- "github.com/metoro-io/mcp-golang"
- "github.com/metoro-io/mcp-golang/transport/http"
+ "github.com/rvoh-emccaleb/mcp-golang"
+ "github.com/rvoh-emccaleb/mcp-golang/transport/http"
)
func main() {
@@ -122,8 +122,8 @@ package main
import (
"github.com/gin-gonic/gin"
- "github.com/metoro-io/mcp-golang"
- "github.com/metoro-io/mcp-golang/transport/http"
+ "github.com/rvoh-emccaleb/mcp-golang"
+ "github.com/rvoh-emccaleb/mcp-golang/transport/http"
)
func main() {
@@ -155,8 +155,8 @@ package main
import (
"context"
"log"
- "github.com/metoro-io/mcp-golang"
- "github.com/metoro-io/mcp-golang/transport/http"
+ "github.com/rvoh-emccaleb/mcp-golang"
+ "github.com/rvoh-emccaleb/mcp-golang/transport/http"
)
func main() {
@@ -183,4 +183,4 @@ func main() {
- If you're interested in contributing to mcp-golang, check out [Development Guide](/development) for more detailed information
- Join our [Discord Community](https://discord.gg/33saRwE3pT) for support
-- Visit our [GitHub Repository](https://github.com/metoro-io/mcp-golang) to contribute
+- Visit our [GitHub Repository](https://github.com/rvoh-emccaleb/mcp-golang) to contribute
diff --git a/docs/tools.mdx b/docs/tools.mdx
index 3854884..328f5f6 100644
--- a/docs/tools.mdx
+++ b/docs/tools.mdx
@@ -17,8 +17,8 @@ package main
import (
"fmt"
- "github.com/metoro-io/mcp-golang"
- "github.com/metoro-io/mcp-golang/transport/stdio"
+ "github.com/rvoh-emccaleb/mcp-golang"
+ "github.com/rvoh-emccaleb/mcp-golang/transport/stdio"
)
type HelloArguments struct {
diff --git a/examples/basic_tool_server/basic_tool_server.go b/examples/basic_tool_server/basic_tool_server.go
index b93e297..9a5a5e7 100644
--- a/examples/basic_tool_server/basic_tool_server.go
+++ b/examples/basic_tool_server/basic_tool_server.go
@@ -2,8 +2,8 @@ package main
import (
"fmt"
- "github.com/metoro-io/mcp-golang"
- "github.com/metoro-io/mcp-golang/transport/stdio"
+ "github.com/rvoh-emccaleb/mcp-golang"
+ "github.com/rvoh-emccaleb/mcp-golang/transport/stdio"
)
type Content struct {
diff --git a/examples/client/main.go b/examples/client/main.go
index e37efd9..375217c 100644
--- a/examples/client/main.go
+++ b/examples/client/main.go
@@ -6,8 +6,12 @@ import (
"context"
- mcp_golang "github.com/metoro-io/mcp-golang"
- "github.com/metoro-io/mcp-golang/transport/stdio"
+ mcp_golang "github.com/rvoh-emccaleb/mcp-golang"
+ "github.com/rvoh-emccaleb/mcp-golang/transport/stdio"
+)
+
+const (
+ ProtocolVersion = "2024-11-05"
)
func main() {
@@ -17,6 +21,7 @@ func main() {
if err != nil {
log.Fatalf("Failed to get stdin pipe: %v", err)
}
+
stdout, err := cmd.StdoutPipe()
if err != nil {
log.Fatalf("Failed to get stdout pipe: %v", err)
@@ -25,12 +30,23 @@ func main() {
if err := cmd.Start(); err != nil {
log.Fatalf("Failed to start server: %v", err)
}
+
defer cmd.Process.Kill()
clientTransport := stdio.NewStdioServerTransportWithIO(stdout, stdin)
client := mcp_golang.NewClient(clientTransport)
- if _, err := client.Initialize(context.Background()); err != nil {
+ if _, err := client.Initialize(
+ context.Background(),
+ &mcp_golang.InitializeRequestParams{
+ ClientInfo: mcp_golang.InitializeRequestClientInfo{
+ Name: "example-client",
+ Version: "0.1.0",
+ },
+ ProtocolVersion: ProtocolVersion,
+ Capabilities: nil,
+ },
+ ); err != nil {
log.Fatalf("Failed to initialize client: %v", err)
}
diff --git a/examples/client/server/main.go b/examples/client/server/main.go
index 87caa8b..fbf7b71 100644
--- a/examples/client/server/main.go
+++ b/examples/client/server/main.go
@@ -5,8 +5,8 @@ import (
"strings"
"time"
- mcp "github.com/metoro-io/mcp-golang"
- "github.com/metoro-io/mcp-golang/transport/stdio"
+ mcp "github.com/rvoh-emccaleb/mcp-golang"
+ "github.com/rvoh-emccaleb/mcp-golang/transport/stdio"
)
// HelloArgs represents the arguments for the hello tool
diff --git a/examples/get_weather_tool_server/get_weather_tool_server.go b/examples/get_weather_tool_server/get_weather_tool_server.go
index 096c542..bdbf952 100644
--- a/examples/get_weather_tool_server/get_weather_tool_server.go
+++ b/examples/get_weather_tool_server/get_weather_tool_server.go
@@ -2,8 +2,8 @@ package main
import (
"fmt"
- mcp_golang "github.com/metoro-io/mcp-golang"
- "github.com/metoro-io/mcp-golang/transport/stdio"
+ mcp_golang "github.com/rvoh-emccaleb/mcp-golang"
+ "github.com/rvoh-emccaleb/mcp-golang/transport/stdio"
"io"
"net/http"
)
diff --git a/examples/gin_example/main.go b/examples/gin_example/main.go
index 2753592..c332d0e 100644
--- a/examples/gin_example/main.go
+++ b/examples/gin_example/main.go
@@ -7,8 +7,8 @@ import (
"time"
"github.com/gin-gonic/gin"
- mcp_golang "github.com/metoro-io/mcp-golang"
- "github.com/metoro-io/mcp-golang/transport/http"
+ mcp_golang "github.com/rvoh-emccaleb/mcp-golang"
+ "github.com/rvoh-emccaleb/mcp-golang/transport/http"
)
// TimeArgs defines the arguments for the time tool
diff --git a/examples/http_example/auth_example_client/main.go b/examples/http_example/auth_example_client/main.go
index 71e5830..1de55eb 100644
--- a/examples/http_example/auth_example_client/main.go
+++ b/examples/http_example/auth_example_client/main.go
@@ -3,15 +3,20 @@ package main
import (
"context"
"log"
+ "time"
"github.com/davecgh/go-spew/spew"
- mcp_golang "github.com/metoro-io/mcp-golang"
- "github.com/metoro-io/mcp-golang/transport/http"
+ mcp_golang "github.com/rvoh-emccaleb/mcp-golang"
+ "github.com/rvoh-emccaleb/mcp-golang/transport/http"
+)
+
+const (
+ ProtocolVersion = "2024-11-05"
)
func main() {
// Create an HTTP transport that connects to the server
- transport := http.NewHTTPClientTransport("/mcp")
+ transport := http.NewHTTPClientTransport("/mcp", 1*time.Millisecond)
transport.WithBaseURL("http://localhost:8080/api/v1")
// Public metoro token - not a leak
transport.WithHeader("Authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJjdXN0b21lcklkIjoiOThlZDU1M2QtYzY4ZC00MDRhLWFhZjItNDM2ODllNWJiMGUzIiwiZW1haWwiOiJ0ZXN0QGNocmlzYmF0dGFyYmVlLmNvbSIsImV4cCI6MTgyMTI0NzIzN30.QeFzKsP1yO16pVol0mkAdt7qhJf6nTqBoqXqdWawBdE")
@@ -20,7 +25,17 @@ func main() {
client := mcp_golang.NewClient(transport)
// Initialize the client
- if resp, err := client.Initialize(context.Background()); err != nil {
+ if resp, err := client.Initialize(
+ context.Background(),
+ &mcp_golang.InitializeRequestParams{
+ ClientInfo: mcp_golang.InitializeRequestClientInfo{
+ Name: "example-client",
+ Version: "0.1.0",
+ },
+ ProtocolVersion: ProtocolVersion,
+ Capabilities: nil,
+ },
+ ); err != nil {
log.Fatalf("Failed to initialize client: %v", err)
} else {
log.Printf("Initialized client: %v", spew.Sdump(resp))
diff --git a/examples/http_example/client/main.go b/examples/http_example/client/main.go
index 622bfdd..fb307a1 100644
--- a/examples/http_example/client/main.go
+++ b/examples/http_example/client/main.go
@@ -6,21 +6,36 @@ import (
"time"
"github.com/davecgh/go-spew/spew"
- mcp_golang "github.com/metoro-io/mcp-golang"
- "github.com/metoro-io/mcp-golang/transport/http"
+ mcp_golang "github.com/rvoh-emccaleb/mcp-golang"
+ "github.com/rvoh-emccaleb/mcp-golang/transport/http"
+)
+
+const (
+ ProtocolVersion = "2024-11-05"
)
func main() {
// Create an HTTP transport that connects to the server
- transport := http.NewHTTPClientTransport("/mcp")
+ transport := http.NewHTTPClientTransport("/mcp", 1*time.Millisecond)
transport.WithBaseURL("http://localhost:8081")
// Create a new client with the transport
client := mcp_golang.NewClient(transport)
// Initialize the client
- if resp, err := client.Initialize(context.Background()); err != nil {
+ if resp, err := client.Initialize(
+ context.Background(),
+ &mcp_golang.InitializeRequestParams{
+ ClientInfo: mcp_golang.InitializeRequestClientInfo{
+ Name: "example-client",
+ Version: "0.1.0",
+ },
+ ProtocolVersion: ProtocolVersion,
+ Capabilities: nil,
+ },
+ ); err != nil {
log.Fatalf("Failed to initialize client: %v", err)
+
} else {
log.Printf("Initialized client: %v", spew.Sdump(resp))
}
diff --git a/examples/http_example/server/main.go b/examples/http_example/server/main.go
index 563180c..d817a57 100644
--- a/examples/http_example/server/main.go
+++ b/examples/http_example/server/main.go
@@ -4,8 +4,8 @@ import (
"log"
"time"
- mcp_golang "github.com/metoro-io/mcp-golang"
- "github.com/metoro-io/mcp-golang/transport/http"
+ mcp_golang "github.com/rvoh-emccaleb/mcp-golang"
+ "github.com/rvoh-emccaleb/mcp-golang/transport/http"
)
// TimeArgs defines the arguments for the time tool
diff --git a/examples/pagination_example/pagination_example.go b/examples/pagination_example/pagination_example.go
index 4087929..cba6d19 100644
--- a/examples/pagination_example/pagination_example.go
+++ b/examples/pagination_example/pagination_example.go
@@ -2,8 +2,8 @@ package main
import (
"fmt"
- mcp_golang "github.com/metoro-io/mcp-golang"
- "github.com/metoro-io/mcp-golang/transport/stdio"
+ mcp_golang "github.com/rvoh-emccaleb/mcp-golang"
+ "github.com/rvoh-emccaleb/mcp-golang/transport/stdio"
)
// Arguments for our tools
diff --git a/examples/readme_server/readme_server.go b/examples/readme_server/readme_server.go
index 7e20743..741b34d 100644
--- a/examples/readme_server/readme_server.go
+++ b/examples/readme_server/readme_server.go
@@ -2,8 +2,8 @@ package main
import (
"fmt"
- "github.com/metoro-io/mcp-golang"
- "github.com/metoro-io/mcp-golang/transport/stdio"
+ "github.com/rvoh-emccaleb/mcp-golang"
+ "github.com/rvoh-emccaleb/mcp-golang/transport/stdio"
)
// Tool arguments are just structs, annotated with jsonschema tags
diff --git a/examples/server/main.go b/examples/server/main.go
index c12522b..20504b4 100644
--- a/examples/server/main.go
+++ b/examples/server/main.go
@@ -3,8 +3,8 @@ package main
import (
"fmt"
- mcp "github.com/metoro-io/mcp-golang"
- "github.com/metoro-io/mcp-golang/transport/stdio"
+ mcp "github.com/rvoh-emccaleb/mcp-golang"
+ "github.com/rvoh-emccaleb/mcp-golang/transport/stdio"
)
// HelloArgs represents the arguments for the hello tool
diff --git a/examples/simple_tool_docs/simple_tool_docs.go b/examples/simple_tool_docs/simple_tool_docs.go
index 881aec4..ebb2c2b 100644
--- a/examples/simple_tool_docs/simple_tool_docs.go
+++ b/examples/simple_tool_docs/simple_tool_docs.go
@@ -2,8 +2,8 @@ package main
import (
"fmt"
- mcp_golang "github.com/metoro-io/mcp-golang"
- "github.com/metoro-io/mcp-golang/transport/stdio"
+ mcp_golang "github.com/rvoh-emccaleb/mcp-golang"
+ "github.com/rvoh-emccaleb/mcp-golang/transport/stdio"
)
type HelloArguments struct {
diff --git a/examples/sse_example/main.go b/examples/sse_example/main.go
new file mode 100644
index 0000000..7401956
--- /dev/null
+++ b/examples/sse_example/main.go
@@ -0,0 +1,253 @@
+package main
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "io"
+ "log"
+ "net/http"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/davecgh/go-spew/spew"
+ mcp_golang "github.com/rvoh-emccaleb/mcp-golang"
+ "github.com/rvoh-emccaleb/mcp-golang/transport/sse"
+)
+
+// TimeArgs defines the arguments for the time tool
+type TimeArgs struct {
+ Format string `json:"format" jsonschema:"description=The time format to use"`
+}
+
+const (
+ baseEndpoint = "/mcp/sse"
+ protocolVersion = "2024-11-05"
+ serverPort = 8083
+)
+
+func main() {
+ // Add a root context that we can cancel
+ rootCtx, rootCancel := context.WithCancel(context.Background())
+ defer rootCancel()
+
+ sseTransport := sse.NewServerTransport(baseEndpoint)
+
+ server := mcp_golang.NewServer(
+ sseTransport,
+ mcp_golang.WithName("mcp-golang-sse-example"),
+ mcp_golang.WithVersion("0.0.1"),
+ )
+
+ // Register a simple tool
+ err := server.RegisterTool("time", "Returns the current time in the specified format", func(ctx context.Context, args TimeArgs) (*mcp_golang.ToolResponse, error) {
+ format := args.Format
+ if format == "" {
+ format = time.RFC3339
+ }
+ return mcp_golang.NewToolResponse(mcp_golang.NewTextContent(time.Now().Format(format))), nil
+ })
+ if err != nil {
+ panic(err)
+ }
+
+ // Handler for establishing a new SSE connection
+ http.HandleFunc(baseEndpoint, func(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodGet {
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+
+ // Create a context that's cancelled either by client disconnect or server shutdown
+ ctx, cancel := context.WithCancel(r.Context())
+ go func() {
+ select {
+ case <-rootCtx.Done():
+ cancel()
+ case <-r.Context().Done():
+ cancel()
+ }
+ }()
+ defer cancel()
+
+ connID, err := sseTransport.HandleSSEConnInitialize(w)
+ if err != nil {
+ log.Printf("error initializing sse connection: %v", err)
+ return
+ }
+
+ log.Printf("New SSE connection established with ID: %d", connID)
+ start := time.Now()
+
+ <-ctx.Done() // Wait for either client disconnect or server shutdown
+
+ log.Printf("SSE connection closed after %v for connection with ID: %d", time.Since(start), connID)
+ sseTransport.RemoveSSEConnection(connID)
+ })
+
+ // Handler for receiving JSON-RPC messages
+ http.HandleFunc(baseEndpoint+"/", func(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost {
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+ log.Printf("Received POST request to %s", r.URL.Path)
+
+ // Extract connection ID from path
+ parts := strings.Split(r.URL.Path, "/")
+ if len(parts) < 3 {
+ log.Printf("Error: missing connection ID in path")
+ http.Error(w, "missing connection ID in path", http.StatusBadRequest)
+ return
+ }
+ connIDStr := parts[len(parts)-1]
+
+ connID, err := strconv.ParseInt(connIDStr, 10, 64)
+ if err != nil {
+ log.Printf("Error: invalid connection ID format: %s", connIDStr)
+ http.Error(w, fmt.Sprintf("invalid connection ID format: %s", connIDStr), http.StatusBadRequest)
+ return
+ }
+
+ // Read and log the request body
+ body, err := io.ReadAll(r.Body)
+ if err != nil {
+ log.Printf("Error reading request body: %v", err)
+ http.Error(w, "failed to read request body", http.StatusBadRequest)
+ return
+ }
+ log.Printf("Request body: %s", string(body))
+ r.Body = io.NopCloser(bytes.NewBuffer(body))
+
+ // Handle all MCP messages (e.g. initialize notification, list tools request, etc.)
+ if err := sseTransport.HandleMCPMessage(w, r, connID); err != nil {
+ log.Printf("error handling MCP message: %v", err)
+ return
+ }
+ })
+
+ // Sets the MCP protocol message handlers on our transport
+ // e.g. initialize, tools/list, tools/call, etc.
+ err = server.Serve()
+ if err != nil {
+ log.Fatalf("Server error: %v", err)
+ }
+
+ // Start the HTTP server
+ log.Printf("Starting HTTP server with SSE on :%d...", serverPort)
+ httpServer := &http.Server{
+ Addr: fmt.Sprintf(":%d", serverPort),
+ Handler: nil,
+
+ // Timeouts
+ ReadTimeout: 10 * time.Second, // For initial connection and POST requests
+ WriteTimeout: 0, // No timeout for SSE writes
+ MaxHeaderBytes: 1 << 20, // 1MB
+ ReadHeaderTimeout: 10 * time.Second, // For initial connection and POST requests
+ IdleTimeout: 0, // No timeout for SSE connections
+ }
+
+ done := make(chan struct{})
+
+ go func() {
+ err := httpServer.ListenAndServe()
+ if err != nil && err != http.ErrServerClosed {
+ log.Printf("Server error: %v", err)
+ }
+
+ close(done)
+ }()
+
+ time.Sleep(1 * time.Second)
+
+ // Create an HTTP transport that connects to the server
+ transport := sse.NewSSEClientTransport(baseEndpoint, 1*time.Millisecond)
+ transport.WithBaseURL(fmt.Sprintf("http://localhost:%d", serverPort))
+
+ // Create a new client with the transport
+ client := mcp_golang.NewClient(transport)
+
+ // Initialize the client
+ if resp, err := client.Initialize(
+ context.Background(),
+ &mcp_golang.InitializeRequestParams{
+ ClientInfo: mcp_golang.InitializeRequestClientInfo{
+ Name: "example-client",
+ Version: "0.1.0",
+ },
+ ProtocolVersion: protocolVersion,
+ Capabilities: nil,
+ },
+ ); err != nil {
+ log.Fatalf("Failed to initialize client: %v", err)
+
+ } else {
+ log.Printf("Initialized client: %v", spew.Sdump(resp))
+ }
+
+ // List available tools
+ tools, err := client.ListTools(context.Background(), nil)
+ if err != nil {
+ log.Fatalf("Failed to list tools: %v", err)
+ }
+
+ log.Println("Available Tools:")
+ for _, tool := range tools.Tools {
+ desc := ""
+ if tool.Description != nil {
+ desc = *tool.Description
+ }
+ log.Printf("Tool: %s. Description: %s", tool.Name, desc)
+ }
+
+ // Call the time tool with different formats
+ formats := []string{
+ time.RFC3339,
+ "2006-01-02 15:04:05",
+ "Mon, 02 Jan 2006",
+ }
+
+ for _, format := range formats {
+ args := map[string]interface{}{
+ "format": format,
+ }
+
+ response, err := client.CallTool(context.Background(), "time", args)
+ if err != nil {
+ log.Printf("Failed to call time tool: %v", err)
+ continue
+ }
+
+ if len(response.Content) > 0 && response.Content[0].TextContent != nil {
+ log.Printf("Time in format %q: %s", format, response.Content[0].TextContent.Text)
+ }
+ }
+
+ // When shutting down:
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+
+ // 1. First cancel all SSE connections and prevent new ones
+ log.Println("Cancelling all SSE connections...")
+ rootCancel()
+
+ // 2. Close the transport (which should prevent new SSE connections)
+ log.Println("Closing SSE transport...")
+ if err := sseTransport.Close(); err != nil {
+ log.Printf("Error closing SSE transport: %v", err)
+ }
+
+ // 3. Finally shutdown the HTTP server
+ log.Println("Shutting down HTTP server...")
+ if err := httpServer.Shutdown(ctx); err != nil {
+ log.Printf("Error during server shutdown: %v", err)
+ }
+
+ select {
+ case <-done:
+ log.Println("Server shutdown completed")
+ case <-time.After(10 * time.Second):
+ log.Println("Server shutdown timed out")
+ }
+}
diff --git a/examples/updating_registrations_on_the_fly/updating_registrations_on_the_fly.go b/examples/updating_registrations_on_the_fly/updating_registrations_on_the_fly.go
index a52edbe..c9af73f 100644
--- a/examples/updating_registrations_on_the_fly/updating_registrations_on_the_fly.go
+++ b/examples/updating_registrations_on_the_fly/updating_registrations_on_the_fly.go
@@ -2,8 +2,8 @@ package main
import (
"fmt"
- mcp_golang "github.com/metoro-io/mcp-golang"
- "github.com/metoro-io/mcp-golang/transport/stdio"
+ mcp_golang "github.com/rvoh-emccaleb/mcp-golang"
+ "github.com/rvoh-emccaleb/mcp-golang/transport/stdio"
"time"
)
diff --git a/go.mod b/go.mod
index 9e12af5..87d4405 100644
--- a/go.mod
+++ b/go.mod
@@ -1,8 +1,10 @@
-module github.com/metoro-io/mcp-golang
+module github.com/rvoh-emccaleb/mcp-golang
-go 1.21
+go 1.24.0
require (
+ github.com/davecgh/go-spew v1.1.1
+ github.com/gin-gonic/gin v1.8.1
github.com/invopop/jsonschema v0.12.0
github.com/pkg/errors v0.9.1
github.com/stretchr/testify v1.9.0
@@ -12,9 +14,7 @@ require (
require (
github.com/bahlo/generic-list-go v0.2.0 // indirect
github.com/buger/jsonparser v1.1.1 // indirect
- github.com/davecgh/go-spew v1.1.1 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
- github.com/gin-gonic/gin v1.8.1 // indirect
github.com/go-playground/locales v0.14.0 // indirect
github.com/go-playground/universal-translator v0.18.0 // indirect
github.com/go-playground/validator/v10 v10.10.0 // indirect
@@ -32,10 +32,10 @@ require (
github.com/tidwall/pretty v1.2.1 // indirect
github.com/ugorji/go/codec v1.2.7 // indirect
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
- golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 // indirect
- golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 // indirect
- golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069 // indirect
- golang.org/x/text v0.3.6 // indirect
+ golang.org/x/crypto v0.36.0 // indirect
+ golang.org/x/net v0.37.0 // indirect
+ golang.org/x/sys v0.31.0 // indirect
+ golang.org/x/text v0.23.0 // indirect
google.golang.org/protobuf v1.28.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
diff --git a/go.sum b/go.sum
index 309c382..9ed58ad 100644
--- a/go.sum
+++ b/go.sum
@@ -10,6 +10,7 @@ github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
github.com/gin-gonic/gin v1.8.1 h1:4+fr/el88TOO3ewCmQr8cx/CtZ/umlIRIs5M4NTNjf8=
github.com/gin-gonic/gin v1.8.1/go.mod h1:ji8BvRH1azfM+SYow9zQ6SZMvR8qOMZHmsCuWR9tTTk=
+github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A=
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.0 h1:u50s323jtVGugKlcYeyzC0etD1HifMjqmJqb8WugfUU=
github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs=
@@ -20,10 +21,9 @@ github.com/go-playground/validator/v10 v10.10.0/go.mod h1:74x4gJWsvQexRdW8Pn3dXS
github.com/goccy/go-json v0.9.7 h1:IcB+Aqpx/iMHu5Yooh7jEzJk1JZ7Pjtmys2ukPr7EeM=
github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
+github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
-github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
-github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/invopop/jsonschema v0.12.0 h1:6ovsNSuvn9wEQVOyc72aycBMVQFKz7cPdMJn10CvzRI=
github.com/invopop/jsonschema v0.12.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
@@ -31,9 +31,11 @@ github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnr
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
+github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
+github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w=
github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY=
@@ -53,6 +55,7 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
+github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
@@ -76,32 +79,32 @@ github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0
github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY=
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
-go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
-go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
-go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
-go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
-golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 h1:/UOmuWzQfxxo9UtlXMwuQU8CMgg1eZXqTRwkSQJWKOI=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
-golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 h1:qWPm9rbaAMKs8Bq/9LRpbMqxWRVUAQwMI9fVrssnTfw=
+golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
+golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
+golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
+golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069 h1:siQdpVirKtzPhKl3lZWozZraCFObP8S1v6PRp0bLrtU=
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
+golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
-golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
+golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
+golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
-golang.org/x/tools v0.28.0/go.mod h1:dcIOrVd3mfQKTgrDVQHqCPMWy6lnhfhtX3hLXYVLfRw=
+golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
-gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
diff --git a/integration_test.go b/integration_test.go
index 30c7040..2ed31d9 100644
--- a/integration_test.go
+++ b/integration_test.go
@@ -15,7 +15,7 @@ import (
"testing"
"time"
- "github.com/metoro-io/mcp-golang/transport"
+ "github.com/rvoh-emccaleb/mcp-golang/transport"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -23,8 +23,8 @@ import (
const testServerCode = `package main
import (
- mcp "github.com/metoro-io/mcp-golang"
- "github.com/metoro-io/mcp-golang/transport/stdio"
+ mcp "github.com/rvoh-emccaleb/mcp-golang"
+ "github.com/rvoh-emccaleb/mcp-golang/transport/stdio"
)
type EchoArgs struct {
@@ -78,7 +78,7 @@ func TestServerIntegration(t *testing.T) {
require.NoError(t, err, "Failed to initialize module: %s", string(output))
// Replace the dependency with the local version
- cmd = exec.Command("go", "mod", "edit", "-replace", "github.com/metoro-io/mcp-golang="+currentDir)
+ cmd = exec.Command("go", "mod", "edit", "-replace", "github.com/rvoh-emccaleb/mcp-golang="+currentDir)
cmd.Dir = tmpDir
output, err = cmd.CombinedOutput()
require.NoError(t, err, "Failed to replace dependency: %s", string(output))
diff --git a/internal/protocol/protocol.go b/internal/protocol/protocol.go
index 3172ac4..18b2ff1 100644
--- a/internal/protocol/protocol.go
+++ b/internal/protocol/protocol.go
@@ -72,7 +72,7 @@ import (
"sync"
"time"
- "github.com/metoro-io/mcp-golang/transport"
+ "github.com/rvoh-emccaleb/mcp-golang/transport"
)
const DefaultRequestTimeoutMsec = 60000
diff --git a/internal/protocol/protocol_test.go b/internal/protocol/protocol_test.go
index 887f769..86332d4 100644
--- a/internal/protocol/protocol_test.go
+++ b/internal/protocol/protocol_test.go
@@ -7,8 +7,8 @@ import (
"testing"
"time"
- "github.com/metoro-io/mcp-golang/internal/testingutils"
- "github.com/metoro-io/mcp-golang/transport"
+ "github.com/rvoh-emccaleb/mcp-golang/internal/testingutils"
+ "github.com/rvoh-emccaleb/mcp-golang/transport"
)
// TestProtocol_Connect verifies the basic connection functionality of the Protocol.
diff --git a/internal/testingutils/mock_transport.go b/internal/testingutils/mock_transport.go
index 19196d9..d85eb30 100644
--- a/internal/testingutils/mock_transport.go
+++ b/internal/testingutils/mock_transport.go
@@ -4,7 +4,7 @@ import (
"context"
"sync"
- "github.com/metoro-io/mcp-golang/transport"
+ "github.com/rvoh-emccaleb/mcp-golang/transport"
)
// MockTransport implements Transport interface for testing
diff --git a/resource_template_response_types.go b/resource_template_response_types.go
new file mode 100644
index 0000000..72f29f7
--- /dev/null
+++ b/resource_template_response_types.go
@@ -0,0 +1,51 @@
+package mcp_golang
+
+// The server's response to a resources/templates/list request from the client.
+type ListResourceTemplatesResponse struct {
+ // Templates corresponds to the JSON schema field "templates".
+ Templates []*ResourceTemplateSchema `json:"templates" yaml:"templates" mapstructure:"templates"`
+ // NextCursor is a cursor for pagination. If not nil, there are more templates available.
+ NextCursor *string `json:"nextCursor,omitempty" yaml:"nextCursor,omitempty" mapstructure:"nextCursor,omitempty"`
+}
+
+// A resource template that the server is capable of instantiating.
+type ResourceTemplateSchema struct {
+ // Annotations corresponds to the JSON schema field "annotations".
+ Annotations *Annotations `json:"annotations,omitempty" yaml:"annotations,omitempty" mapstructure:"annotations,omitempty"`
+
+ // A description of what this resource template represents.
+ //
+ // This can be used by clients to improve the LLM's understanding of available
+ // resource templates. It can be thought of like a "hint" to the model.
+ Description *string `json:"description,omitempty" yaml:"description,omitempty" mapstructure:"description,omitempty"`
+
+ // The MIME type of resources created from this template.
+ MimeType *string `json:"mimeType,omitempty" yaml:"mimeType,omitempty" mapstructure:"mimeType,omitempty"`
+
+ // A human-readable name for this resource template.
+ //
+ // This can be used by clients to populate UI elements.
+ Name string `json:"name" yaml:"name" mapstructure:"name"`
+
+ // The URI of this resource template.
+ Uri string `json:"uri" yaml:"uri" mapstructure:"uri"`
+
+ // Parameters that can be used to instantiate this template.
+ Parameters []ResourceTemplateParameter `json:"parameters" yaml:"parameters" mapstructure:"parameters"`
+}
+
+// A parameter that can be used to instantiate a resource template.
+type ResourceTemplateParameter struct {
+ // A description of what this parameter represents.
+ //
+ // This can be used by clients to improve the LLM's understanding of the parameter.
+ Description *string `json:"description,omitempty" yaml:"description,omitempty" mapstructure:"description,omitempty"`
+
+ // A human-readable name for this parameter.
+ //
+ // This can be used by clients to populate UI elements.
+ Name string `json:"name" yaml:"name" mapstructure:"name"`
+
+ // Whether this parameter is required when instantiating the template.
+ Required *bool `json:"required,omitempty" yaml:"required,omitempty" mapstructure:"required,omitempty"`
+}
diff --git a/server.go b/server.go
index af6f1ec..a3dd71a 100644
--- a/server.go
+++ b/server.go
@@ -10,10 +10,10 @@ import (
"strings"
"github.com/invopop/jsonschema"
- "github.com/metoro-io/mcp-golang/internal/datastructures"
- "github.com/metoro-io/mcp-golang/internal/protocol"
- "github.com/metoro-io/mcp-golang/transport"
"github.com/pkg/errors"
+ "github.com/rvoh-emccaleb/mcp-golang/internal/datastructures"
+ "github.com/rvoh-emccaleb/mcp-golang/internal/protocol"
+ "github.com/rvoh-emccaleb/mcp-golang/transport"
)
// Here we define the actual MCP server that users will create and run
@@ -107,6 +107,7 @@ type Server struct {
tools *datastructures.SyncMap[string, *tool]
prompts *datastructures.SyncMap[string, *prompt]
resources *datastructures.SyncMap[string, *resource]
+ resourceTemplates *datastructures.SyncMap[string, *resourceTemplate]
serverInstructions *string
serverName string
serverVersion string
@@ -134,6 +135,14 @@ type resource struct {
Handler func(context.Context) *resourceResponseSent
}
+type resourceTemplate struct {
+ Name string
+ Description string
+ Uri string
+ MimeType string
+ Parameters []ResourceTemplateParameter
+}
+
type ServerOptions func(*Server)
func WithProtocol(protocol *protocol.Protocol) ServerOptions {
@@ -163,11 +172,12 @@ func WithVersion(version string) ServerOptions {
func NewServer(transport transport.Transport, options ...ServerOptions) *Server {
server := &Server{
- protocol: protocol.NewProtocol(nil),
- transport: transport,
- tools: new(datastructures.SyncMap[string, *tool]),
- prompts: new(datastructures.SyncMap[string, *prompt]),
- resources: new(datastructures.SyncMap[string, *resource]),
+ protocol: protocol.NewProtocol(nil),
+ transport: transport,
+ tools: new(datastructures.SyncMap[string, *tool]),
+ prompts: new(datastructures.SyncMap[string, *prompt]),
+ resources: new(datastructures.SyncMap[string, *resource]),
+ resourceTemplates: new(datastructures.SyncMap[string, *resourceTemplate]),
}
for _, option := range options {
option(server)
@@ -200,6 +210,24 @@ func (s *Server) sendToolListChangedNotification() error {
return s.protocol.Notification("notifications/tools/list_changed", nil)
}
+// RegisterToolWithSchema registers a tool with a predefined JSON schema.
+// This is an alternative to RegisterTool, which uses reflection on the handler to create the schema.
+func (s *Server) RegisterToolWithSchema(name string, description string, handler any, schema *jsonschema.Schema) error {
+ err := validateToolHandler(handler)
+ if err != nil {
+ return err
+ }
+
+ s.tools.Store(name, &tool{
+ Name: name,
+ Description: description,
+ Handler: createWrappedToolHandler(handler),
+ ToolInputSchema: schema, // Use the provided schema directly
+ })
+
+ return s.sendToolListChangedNotification()
+}
+
func (s *Server) CheckToolRegistered(name string) bool {
_, ok := s.tools.Load(name)
return ok
@@ -244,8 +272,9 @@ func (s *Server) DeregisterResource(uri string) error {
func createWrappedResourceHandler(userHandler any) func(ctx context.Context) *resourceResponseSent {
handlerValue := reflect.ValueOf(userHandler)
+ handlerType := handlerValue.Type()
+
return func(ctx context.Context) *resourceResponseSent {
- handlerType := handlerValue.Type()
var args []reflect.Value
if handlerType.NumIn() == 1 {
args = []reflect.Value{reflect.ValueOf(ctx)}
@@ -554,6 +583,15 @@ func (s *Server) Serve() error {
pr.SetRequestHandler("prompts/get", s.handlePromptCalls)
pr.SetRequestHandler("resources/list", s.handleListResources)
pr.SetRequestHandler("resources/read", s.handleResourceCalls)
+ pr.SetRequestHandler("resources/templates/list", s.handleListResourceTemplates)
+
+ // Note: All notifications in MCP, other than this one, are sent from the server to the client, and can use SSE.
+ // This is the only notification that is sent from the client to the server. In order to make things work
+ // smoothly from the client's perspective, we choose to handle this notification as a request instead of
+ // as a notification, because working with the HTTP protocol is inherently request/response oriented, and the client
+ // would otherwise have to guess as to how long to wait before assuming the server is initialized.
+ pr.SetRequestHandler("notifications/initialized", s.handleInitializedNotification)
+
err := pr.Connect(s.transport)
if err != nil {
return err
@@ -576,6 +614,16 @@ func (s *Server) handleInitialize(ctx context.Context, request *transport.BaseJS
}, nil
}
+// handleInitializedNotification is a request handler for the "notifications/initialized" notification.
+// This is the only notification in MCP that is sent from the client to the server. In order to make things work
+// smoothly from the client's perspective, we choose to handle this notification as a request instead of
+// as a notification, because working with the HTTP protocol is inherently request/response oriented, and the client
+// would otherwise have to guess as to how long to wait before assuming the server is initialized. We just return
+// an empty response body to the client to indicate that the notification has been received.
+func (s *Server) handleInitializedNotification(ctx context.Context, request *transport.BaseJSONRPCRequest, _ protocol.RequestHandlerExtra) (transport.JsonRpcBody, error) {
+ return map[string]interface{}{}, nil
+}
+
func (s *Server) handleListTools(ctx context.Context, request *transport.BaseJSONRPCRequest, _ protocol.RequestHandlerExtra) (transport.JsonRpcBody, error) {
type toolRequestParams struct {
Cursor *string `json:"cursor"`
@@ -698,9 +746,13 @@ func (s *Server) handleListPrompts(ctx context.Context, request *transport.BaseJ
Cursor *string `json:"cursor"`
}
var params promptRequestParams
- err := json.Unmarshal(request.Params, ¶ms)
- if err != nil {
- return nil, errors.Wrap(err, "failed to unmarshal arguments")
+ if request.Params == nil {
+ params = promptRequestParams{}
+ } else {
+ err := json.Unmarshal(request.Params, ¶ms)
+ if err != nil {
+ return nil, errors.Wrap(err, "failed to unmarshal arguments")
+ }
}
// Order by name for pagination
@@ -762,9 +814,13 @@ func (s *Server) handleListResources(ctx context.Context, request *transport.Bas
Cursor *string `json:"cursor"`
}
var params resourceRequestParams
- err := json.Unmarshal(request.Params, ¶ms)
- if err != nil {
- return nil, errors.Wrap(err, "failed to unmarshal arguments")
+ if request.Params == nil {
+ params = resourceRequestParams{}
+ } else {
+ err := json.Unmarshal(request.Params, ¶ms)
+ if err != nil {
+ return nil, errors.Wrap(err, "failed to unmarshal arguments")
+ }
}
// Order by URI for pagination
@@ -871,6 +927,78 @@ func (s *Server) handleResourceCalls(ctx context.Context, req *transport.BaseJSO
return resourceToUse.Handler(ctx), nil
}
+func (s *Server) handleListResourceTemplates(ctx context.Context, request *transport.BaseJSONRPCRequest, extra protocol.RequestHandlerExtra) (transport.JsonRpcBody, error) {
+ type templateRequestParams struct {
+ Cursor *string `json:"cursor"`
+ }
+ var params templateRequestParams
+ if request.Params == nil {
+ params = templateRequestParams{}
+ } else {
+ err := json.Unmarshal(request.Params, ¶ms)
+ if err != nil {
+ return nil, errors.Wrap(err, "failed to unmarshal arguments")
+ }
+ }
+
+ // Order by URI for pagination
+ var orderedTemplates []*resourceTemplate
+ s.resourceTemplates.Range(func(k string, t *resourceTemplate) bool {
+ orderedTemplates = append(orderedTemplates, t)
+ return true
+ })
+ sort.Slice(orderedTemplates, func(i, j int) bool {
+ return orderedTemplates[i].Uri < orderedTemplates[j].Uri
+ })
+
+ startPosition := 0
+ if params.Cursor != nil {
+ // Base64 decode the cursor
+ c, err := base64.StdEncoding.DecodeString(*params.Cursor)
+ if err != nil {
+ return nil, errors.Wrap(err, "failed to decode cursor")
+ }
+ cString := string(c)
+ // Iterate through the templates until we find an entry > the cursor
+ for i := 0; i < len(orderedTemplates); i++ {
+ if orderedTemplates[i].Uri > cString {
+ startPosition = i
+ break
+ }
+ }
+ }
+ endPosition := len(orderedTemplates)
+ if s.paginationLimit != nil {
+ // Make sure we don't go out of bounds
+ if len(orderedTemplates) > startPosition+*s.paginationLimit {
+ endPosition = startPosition + *s.paginationLimit
+ }
+ }
+
+ templatesToReturn := make([]*ResourceTemplateSchema, 0)
+ for i := startPosition; i < endPosition; i++ {
+ t := orderedTemplates[i]
+ templatesToReturn = append(templatesToReturn, &ResourceTemplateSchema{
+ Description: &t.Description,
+ MimeType: &t.MimeType,
+ Name: t.Name,
+ Uri: t.Uri,
+ Parameters: t.Parameters,
+ })
+ }
+
+ return ListResourceTemplatesResponse{
+ Templates: templatesToReturn,
+ NextCursor: func() *string {
+ if s.paginationLimit != nil && len(templatesToReturn) >= *s.paginationLimit {
+ toString := base64.StdEncoding.EncodeToString([]byte(templatesToReturn[len(templatesToReturn)-1].Uri))
+ return &toString
+ }
+ return nil
+ }(),
+ }, nil
+}
+
func (s *Server) handlePing(ctx context.Context, request *transport.BaseJSONRPCRequest, extra protocol.RequestHandlerExtra) (transport.JsonRpcBody, error) {
return map[string]interface{}{}, nil
}
@@ -927,3 +1055,28 @@ var (
CommentMap: nil,
}
)
+
+func (s *Server) RegisterResourceTemplate(uri string, name string, description string, mimeType string, parameters []ResourceTemplateParameter) error {
+ s.resourceTemplates.Store(uri, &resourceTemplate{
+ Name: name,
+ Description: description,
+ Uri: uri,
+ MimeType: mimeType,
+ Parameters: parameters,
+ })
+ return s.sendResourceTemplateListChangedNotification()
+}
+
+func (s *Server) DeregisterResourceTemplate(uri string) error {
+ s.resourceTemplates.Delete(uri)
+ return s.sendResourceTemplateListChangedNotification()
+}
+
+func (s *Server) sendResourceTemplateListChangedNotification() error {
+ if !s.isRunning {
+ return nil
+ }
+
+ // Note: Unsure if this one is in spec.
+ return s.protocol.Notification("notifications/resources/templates/list_changed", nil)
+}
diff --git a/server_test.go b/server_test.go
index 2da0c82..8634d4d 100644
--- a/server_test.go
+++ b/server_test.go
@@ -4,9 +4,9 @@ import (
"context"
"testing"
- "github.com/metoro-io/mcp-golang/internal/protocol"
- "github.com/metoro-io/mcp-golang/internal/testingutils"
- "github.com/metoro-io/mcp-golang/transport"
+ "github.com/rvoh-emccaleb/mcp-golang/internal/protocol"
+ "github.com/rvoh-emccaleb/mcp-golang/internal/testingutils"
+ "github.com/rvoh-emccaleb/mcp-golang/transport"
)
func TestServerListChangedNotifications(t *testing.T) {
diff --git a/transport/base/base.go b/transport/base/base.go
new file mode 100644
index 0000000..7baf145
--- /dev/null
+++ b/transport/base/base.go
@@ -0,0 +1,144 @@
+package base
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "sync"
+
+ "github.com/rvoh-emccaleb/mcp-golang/transport"
+)
+
+// Transport implements the common functionality for transports
+type Transport struct {
+ messageHandler func(ctx context.Context, message *transport.BaseJsonRpcMessage)
+ errorHandler func(error)
+ closeHandler func()
+ mu sync.RWMutex
+ ResponseMap map[int64]chan *transport.BaseJsonRpcMessage
+}
+
+// NewTransport creates a new base transport
+func NewTransport() *Transport {
+ return &Transport{
+ ResponseMap: make(map[int64]chan *transport.BaseJsonRpcMessage),
+ }
+}
+
+// Send implements Transport.Send
+func (t *Transport) Send(ctx context.Context, message *transport.BaseJsonRpcMessage) error {
+ key := message.JsonRpcResponse.Id
+ responseChannel := t.ResponseMap[int64(key)]
+ if responseChannel == nil {
+ return fmt.Errorf("no response channel found for key: %d", key)
+ }
+ responseChannel <- message
+ return nil
+}
+
+// Close implements Transport.Close
+func (t *Transport) Close() error {
+ if t.closeHandler != nil {
+ t.closeHandler()
+ }
+ return nil
+}
+
+// SetCloseHandler implements Transport.SetCloseHandler
+func (t *Transport) SetCloseHandler(handler func()) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.closeHandler = handler
+}
+
+// SetErrorHandler implements Transport.SetErrorHandler
+func (t *Transport) SetErrorHandler(handler func(error)) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.errorHandler = handler
+}
+
+// SetMessageHandler implements Transport.SetMessageHandler
+func (t *Transport) SetMessageHandler(handler func(ctx context.Context, message *transport.BaseJsonRpcMessage)) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.messageHandler = handler
+}
+
+// HandleMessage processes an incoming message and returns a response
+func (t *Transport) HandleMessage(ctx context.Context, body []byte) (*transport.BaseJsonRpcMessage, error) {
+ // Try to unmarshal as a request first
+ var request transport.BaseJSONRPCRequest
+ if err := json.Unmarshal(body, &request); err == nil {
+ // Create a response channel for this request
+ t.mu.Lock()
+ var key int64 = 0
+ for key < 1000000 {
+ if _, ok := t.ResponseMap[key]; !ok {
+ break
+ }
+ key = key + 1
+ }
+ t.ResponseMap[key] = make(chan *transport.BaseJsonRpcMessage)
+ t.mu.Unlock()
+
+ originalID := request.Id
+ request.Id = transport.RequestId(key)
+ t.mu.RLock()
+ handler := t.messageHandler
+ t.mu.RUnlock()
+
+ if handler != nil {
+ handler(ctx, transport.NewBaseMessageRequest(&request))
+ }
+
+ // Block until the response is received
+ responseToUse := <-t.ResponseMap[key]
+ delete(t.ResponseMap, key)
+
+ // Restore the original client ID in the response
+ responseToUse.JsonRpcResponse.Id = originalID
+ return responseToUse, nil
+ }
+
+ // Try as a notification
+ var notification transport.BaseJSONRPCNotification
+ if err := json.Unmarshal(body, ¬ification); err == nil {
+ t.mu.RLock()
+ handler := t.messageHandler
+ t.mu.RUnlock()
+
+ if handler != nil {
+ handler(ctx, transport.NewBaseMessageNotification(¬ification))
+ }
+ return nil, nil
+ }
+
+ // Try as a response
+ var response transport.BaseJSONRPCResponse
+ if err := json.Unmarshal(body, &response); err == nil {
+ t.mu.RLock()
+ handler := t.messageHandler
+ t.mu.RUnlock()
+
+ if handler != nil {
+ handler(ctx, transport.NewBaseMessageResponse(&response))
+ }
+ return nil, nil
+ }
+
+ // Try as an error
+ var errorResponse transport.BaseJSONRPCError
+ if err := json.Unmarshal(body, &errorResponse); err == nil {
+ t.mu.RLock()
+ handler := t.messageHandler
+ t.mu.RUnlock()
+
+ if handler != nil {
+ handler(ctx, transport.NewBaseMessageError(&errorResponse))
+ }
+ return nil, nil
+ }
+
+ return nil, fmt.Errorf("failed to unmarshal JSON-RPC message, unrecognized type")
+}
diff --git a/transport/http/common.go b/transport/http/common.go
index 30e6d72..49d3af2 100644
--- a/transport/http/common.go
+++ b/transport/http/common.go
@@ -7,7 +7,7 @@ import (
"io"
"sync"
- "github.com/metoro-io/mcp-golang/transport"
+ "github.com/rvoh-emccaleb/mcp-golang/transport"
)
// baseTransport implements the common functionality for HTTP-based transports
diff --git a/transport/http/gin.go b/transport/http/gin.go
index bbc6c7c..43c29ba 100644
--- a/transport/http/gin.go
+++ b/transport/http/gin.go
@@ -7,7 +7,7 @@ import (
"net/http"
"github.com/gin-gonic/gin"
- "github.com/metoro-io/mcp-golang/transport"
+ "github.com/rvoh-emccaleb/mcp-golang/transport"
)
// GinTransport implements a stateless HTTP transport for MCP using Gin
diff --git a/transport/http/http.go b/transport/http/http.go
index 570ae1b..51e6c19 100644
--- a/transport/http/http.go
+++ b/transport/http/http.go
@@ -5,22 +5,14 @@ import (
"encoding/json"
"fmt"
"net/http"
- "sync"
-
- "github.com/metoro-io/mcp-golang/transport"
)
// HTTPTransport implements a stateless HTTP transport for MCP
type HTTPTransport struct {
*baseTransport
- server *http.Server
- endpoint string
- messageHandler func(ctx context.Context, message *transport.BaseJsonRpcMessage)
- errorHandler func(error)
- closeHandler func()
- mu sync.RWMutex
- addr string
- responseMap map[int64]chan *transport.BaseJsonRpcMessage
+ server *http.Server
+ endpoint string
+ addr string
}
// NewHTTPTransport creates a new HTTP transport that listens on the specified endpoint
@@ -29,7 +21,6 @@ func NewHTTPTransport(endpoint string) *HTTPTransport {
baseTransport: newBaseTransport(),
endpoint: endpoint,
addr: ":8080", // Default port
- responseMap: make(map[int64]chan *transport.BaseJsonRpcMessage),
}
}
@@ -52,17 +43,6 @@ func (t *HTTPTransport) Start(ctx context.Context) error {
return t.server.ListenAndServe()
}
-// Send implements Transport.Send
-func (t *HTTPTransport) Send(ctx context.Context, message *transport.BaseJsonRpcMessage) error {
- key := message.JsonRpcResponse.Id
- responseChannel := t.responseMap[int64(key)]
- if responseChannel == nil {
- return fmt.Errorf("no response channel found for key: %d", key)
- }
- responseChannel <- message
- return nil
-}
-
// Close implements Transport.Close
func (t *HTTPTransport) Close() error {
if t.server != nil {
@@ -70,31 +50,8 @@ func (t *HTTPTransport) Close() error {
return err
}
}
- if t.closeHandler != nil {
- t.closeHandler()
- }
- return nil
-}
-
-// SetCloseHandler implements Transport.SetCloseHandler
-func (t *HTTPTransport) SetCloseHandler(handler func()) {
- t.mu.Lock()
- defer t.mu.Unlock()
- t.closeHandler = handler
-}
-
-// SetErrorHandler implements Transport.SetErrorHandler
-func (t *HTTPTransport) SetErrorHandler(handler func(error)) {
- t.mu.Lock()
- defer t.mu.Unlock()
- t.errorHandler = handler
-}
-// SetMessageHandler implements Transport.SetMessageHandler
-func (t *HTTPTransport) SetMessageHandler(handler func(ctx context.Context, message *transport.BaseJsonRpcMessage)) {
- t.mu.Lock()
- defer t.mu.Unlock()
- t.messageHandler = handler
+ return t.baseTransport.Close()
}
func (t *HTTPTransport) handleRequest(w http.ResponseWriter, r *http.Request) {
diff --git a/transport/http/http_client.go b/transport/http/http_client.go
index f01416a..f17474a 100644
--- a/transport/http/http_client.go
+++ b/transport/http/http_client.go
@@ -4,32 +4,45 @@ import (
"bytes"
"context"
"encoding/json"
+ "errors"
"fmt"
"io"
"net/http"
"sync"
+ "time"
- "github.com/metoro-io/mcp-golang/transport"
+ "github.com/rvoh-emccaleb/mcp-golang/transport"
)
// HTTPClientTransport implements a client-side HTTP transport for MCP
type HTTPClientTransport struct {
- baseURL string
- endpoint string
- messageHandler func(ctx context.Context, message *transport.BaseJsonRpcMessage)
- errorHandler func(error)
- closeHandler func()
- mu sync.RWMutex
- client *http.Client
- headers map[string]string
+ baseURL string
+ endpoint string
+ messageHandler func(ctx context.Context, message *transport.BaseJsonRpcMessage)
+ errorHandler func(error)
+ closeHandler func()
+ mu sync.RWMutex
+ client *http.Client
+ notificationClient *http.Client
+ headers map[string]string
}
// NewHTTPClientTransport creates a new HTTP client transport that connects to the specified endpoint
-func NewHTTPClientTransport(endpoint string) *HTTPClientTransport {
+func NewHTTPClientTransport(endpoint string, notificationTimeout time.Duration) *HTTPClientTransport {
+ if notificationTimeout <= 0 {
+ notificationTimeout = 1 * time.Millisecond // This is flaky, but it works for local demos.
+ }
+
return &HTTPClientTransport{
endpoint: endpoint,
client: &http.Client{},
- headers: make(map[string]string),
+ notificationClient: &http.Client{
+ Transport: &http.Transport{
+ DisableKeepAlives: true,
+ },
+ Timeout: notificationTimeout,
+ },
+ headers: make(map[string]string),
}
}
@@ -68,6 +81,27 @@ func (t *HTTPClientTransport) Send(ctx context.Context, message *transport.BaseJ
req.Header.Set(key, value)
}
+ // Note: The client usually doesn't send notifications. Really it's only used
+ // for the "notifications/initialized" method. The server may or may not be implemented
+ // to return a response, so we have to rely on having a long enough timeout to ensure the
+ // server has time to process the notification. This is inherently flaky, and should be
+ // improved upon.
+ if message.Type == transport.BaseMessageTypeJSONRPCNotificationType {
+ resp, err := t.notificationClient.Do(req)
+ if err != nil && !errors.Is(err, context.DeadlineExceeded) {
+ if t.errorHandler != nil {
+ t.errorHandler(fmt.Errorf("notification error: %w", err))
+ }
+ }
+ if resp != nil {
+ defer resp.Body.Close()
+ }
+
+ return nil
+ }
+
+ // For non-notifications, continue with normal synchronous request
+
resp, err := t.client.Do(req)
if err != nil {
return fmt.Errorf("failed to send request: %w", err)
diff --git a/transport/sse/client.go b/transport/sse/client.go
new file mode 100644
index 0000000..180bab7
--- /dev/null
+++ b/transport/sse/client.go
@@ -0,0 +1,345 @@
+package sse
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/rvoh-emccaleb/mcp-golang/transport"
+)
+
+// SSEClientTransport implements a client-side SSE transport for MCP
+type SSEClientTransport struct {
+ baseURL string
+ endpoint string
+ postEndpoint string
+ messageHandler func(ctx context.Context, message *transport.BaseJsonRpcMessage)
+ errorHandler func(error)
+ closeHandler func()
+ mu sync.RWMutex
+ client *http.Client
+ notificationClient *http.Client
+ headers map[string]string
+ done chan struct{}
+}
+
+// NewSSEClientTransport creates a new SSE client transport that connects to the specified endpoint
+func NewSSEClientTransport(endpoint string, notificationTimeout time.Duration) *SSEClientTransport {
+ if notificationTimeout <= 0 {
+ notificationTimeout = 1 * time.Millisecond // This is flaky, but it works for local demos.
+ }
+
+ return &SSEClientTransport{
+ endpoint: endpoint,
+ client: &http.Client{},
+ notificationClient: &http.Client{
+ Transport: &http.Transport{
+ DisableKeepAlives: true,
+ },
+ Timeout: notificationTimeout,
+ },
+ headers: make(map[string]string),
+ done: make(chan struct{}),
+ }
+}
+
+// WithBaseURL sets the base URL to connect to
+func (t *SSEClientTransport) WithBaseURL(baseURL string) *SSEClientTransport {
+ t.baseURL = baseURL
+ return t
+}
+
+// WithHeader adds a header to the request
+func (t *SSEClientTransport) WithHeader(key, value string) *SSEClientTransport {
+ t.headers[key] = value
+ return t
+}
+
+// Start implements Transport.Start
+func (t *SSEClientTransport) Start(ctx context.Context) error {
+ url, err := url.JoinPath(t.baseURL, t.endpoint)
+ if err != nil {
+ return fmt.Errorf("failed to construct URL: %w", err)
+ }
+
+ req, err := http.NewRequest(http.MethodGet, url, nil)
+ if err != nil {
+ return fmt.Errorf("failed to create request: %w", err)
+ }
+
+ // Set required SSE headers
+ req.Header.Set("Accept", "text/event-stream")
+ req.Header.Set("Cache-Control", "no-cache")
+ req.Header.Set("Connection", "keep-alive")
+ for key, value := range t.headers {
+ req.Header.Set(key, value)
+ }
+
+ endpointChan := make(chan struct{}, 1)
+
+ // Establish our SSE connection and start reading SSE messages in the background
+ go func() {
+ // Response should return once we receive headers (body can have data written to it over time).
+ resp, err := t.client.Do(req)
+ if err != nil {
+ if t.errorHandler != nil {
+ t.errorHandler(fmt.Errorf("failed to establish SSE connection: %w", err))
+ }
+ return
+ }
+
+ // Validate SSE connection
+ if resp.StatusCode != http.StatusOK {
+ body, _ := io.ReadAll(resp.Body)
+ if t.errorHandler != nil {
+ t.errorHandler(fmt.Errorf("server returned error status %d: %s", resp.StatusCode, string(body)))
+ }
+ resp.Body.Close()
+ return
+ }
+
+ contentType := resp.Header.Get("Content-Type")
+ if !strings.Contains(contentType, "text/event-stream") {
+ if t.errorHandler != nil {
+ t.errorHandler(fmt.Errorf("invalid content type for SSE connection: %s", contentType))
+ }
+ resp.Body.Close()
+ return
+ }
+
+ // Start reading and storing SSE messages
+ t.readSSEMessages(resp.Body, endpointChan)
+ }()
+
+ // Wait for either the endpoint event or timeout
+ select {
+ case <-endpointChan:
+ return nil
+ case <-time.After(10 * time.Second):
+ return fmt.Errorf("timeout waiting for endpoint event")
+ case <-ctx.Done():
+ return fmt.Errorf("context cancelled while waiting for endpoint event: %w", ctx.Err())
+ }
+}
+
+// readSSEMessages reads SSE messages from the response body and processes them
+func (t *SSEClientTransport) readSSEMessages(body io.ReadCloser, endpointChan chan<- struct{}) {
+ defer body.Close()
+
+ reader := bufio.NewReader(body)
+ buffer := make([]byte, 4096)
+ var messageBuffer bytes.Buffer
+
+ for {
+ select {
+ case <-t.done:
+ return
+ default:
+ n, err := reader.Read(buffer)
+ if err != nil {
+ if err != io.EOF {
+ if t.errorHandler != nil {
+ t.errorHandler(fmt.Errorf("error reading SSE stream: %w", err))
+ }
+ }
+ return
+ }
+
+ messageBuffer.Write(buffer[:n])
+
+ // Process complete messages (terminated by double newline)
+ for {
+ // Check if we have at least one complete message (terminated by double newline)
+ content := messageBuffer.String()
+ messageEnd := strings.Index(content, "\n\n")
+ if messageEnd == -1 {
+ // No complete message yet, keep reading
+ break
+ }
+
+ message := content[:messageEnd]
+
+ messageBuffer.Reset()
+ messageBuffer.WriteString(content[messageEnd+2:]) // 2 newlines
+
+ // Parse the complete message
+ msg, err := t.parseSSEMessageFromString(message, endpointChan)
+ if err != nil {
+ if t.errorHandler != nil {
+ t.errorHandler(fmt.Errorf("error parsing SSE message: %w", err))
+ }
+ continue
+ }
+
+ // Handle the message based on its type.
+ // Note: The initial endpoint event doesn't require handling, here.
+ // We only need to handle the response and notification types.
+ switch msg.Type {
+ case transport.BaseMessageTypeJSONRPCResponseType,
+ transport.BaseMessageTypeJSONRPCNotificationType,
+ transport.BaseMessageTypeJSONRPCErrorType:
+
+ t.mu.RLock()
+ handler := t.messageHandler
+ t.mu.RUnlock()
+
+ if handler != nil {
+ handler(context.Background(), msg)
+ }
+ }
+ }
+ }
+ }
+}
+
+// parseSSEMessageFromString parses a complete SSE message from a string
+func (t *SSEClientTransport) parseSSEMessageFromString(message string, endpointChan chan<- struct{}) (*transport.BaseJsonRpcMessage, error) {
+ lines := strings.Split(message, "\n")
+ if len(lines) < 2 {
+ return nil, fmt.Errorf("invalid SSE message format: too few lines")
+ }
+
+ // Parse the event type
+ eventType := strings.TrimPrefix(lines[0], "event: ")
+ eventType = strings.TrimSpace(eventType)
+
+ // Parse the data line
+ data := strings.TrimPrefix(lines[1], "data: ")
+ data = strings.TrimSpace(data)
+
+ // Initialize message
+ var msg transport.BaseJsonRpcMessage
+
+ // Handle endpoint events differently
+ if eventType == "endpoint" {
+ t.postEndpoint = data
+ if endpointChan != nil {
+ endpointChan <- struct{}{}
+ }
+ return &msg, nil
+ }
+
+ // For message events, try to unmarshal as JSON RPC
+ var response transport.BaseJSONRPCResponse
+ if err := json.Unmarshal([]byte(data), &response); err == nil {
+ msg.Type = transport.BaseMessageTypeJSONRPCResponseType
+ msg.JsonRpcResponse = &response
+ return &msg, nil
+ }
+
+ // Try as an error
+ var errorResponse transport.BaseJSONRPCError
+ if err := json.Unmarshal([]byte(data), &errorResponse); err == nil {
+ msg.Type = transport.BaseMessageTypeJSONRPCErrorType
+ msg.JsonRpcError = &errorResponse
+ return &msg, nil
+ }
+
+ // Try as a notification
+ var notification transport.BaseJSONRPCNotification
+ if err := json.Unmarshal([]byte(data), ¬ification); err == nil {
+ msg.Type = transport.BaseMessageTypeJSONRPCNotificationType
+ msg.JsonRpcNotification = ¬ification
+ return &msg, nil
+ }
+
+ return nil, fmt.Errorf("unrecognized message type: %s", eventType)
+}
+
+// Send implements Transport.Send
+func (t *SSEClientTransport) Send(ctx context.Context, message *transport.BaseJsonRpcMessage) error {
+ jsonData, err := json.Marshal(message)
+ if err != nil {
+ return fmt.Errorf("failed to marshal message: %w", err)
+ }
+
+ if t.postEndpoint == "" {
+ return fmt.Errorf("post endpoint not set. sse connection not established")
+ }
+
+ url, err := url.JoinPath(t.baseURL, t.postEndpoint)
+ if err != nil {
+ return fmt.Errorf("failed to construct URL: %w", err)
+ }
+
+ req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonData))
+ if err != nil {
+ return fmt.Errorf("failed to create request: %w", err)
+ }
+ req.Header.Set("Content-Type", "application/json")
+ for key, value := range t.headers {
+ req.Header.Set(key, value)
+ }
+
+ // Handle notifications differently with a shorter timeout
+ if message.Type == transport.BaseMessageTypeJSONRPCNotificationType {
+ resp, err := t.notificationClient.Do(req)
+ if err != nil && !errors.Is(err, context.DeadlineExceeded) {
+ t.mu.RLock()
+ if t.errorHandler != nil {
+ t.errorHandler(fmt.Errorf("notification error: %w", err))
+ }
+ t.mu.RUnlock()
+ }
+ if resp != nil {
+ defer resp.Body.Close()
+ }
+ return nil
+ }
+
+ // For non-notifications, continue with normal synchronous request
+ resp, err := t.client.Do(req)
+ if err != nil {
+ return fmt.Errorf("failed to send request: %w", err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ body, _ := io.ReadAll(resp.Body)
+ return fmt.Errorf("server returned error: %s (status: %d)", string(body), resp.StatusCode)
+ }
+
+ // Response content we care about will come in over the SSE stream.
+ // We are done, here.
+
+ return nil
+}
+
+// Close implements Transport.Close
+func (t *SSEClientTransport) Close() error {
+ close(t.done)
+ if t.closeHandler != nil {
+ t.closeHandler()
+ }
+ return nil
+}
+
+// SetCloseHandler implements Transport.SetCloseHandler
+func (t *SSEClientTransport) SetCloseHandler(handler func()) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.closeHandler = handler
+}
+
+// SetErrorHandler implements Transport.SetErrorHandler
+func (t *SSEClientTransport) SetErrorHandler(handler func(error)) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.errorHandler = handler
+}
+
+// SetMessageHandler implements Transport.SetMessageHandler
+func (t *SSEClientTransport) SetMessageHandler(handler func(ctx context.Context, message *transport.BaseJsonRpcMessage)) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.messageHandler = handler
+}
diff --git a/transport/sse/internal/sse/sse.go b/transport/sse/internal/sse/sse.go
deleted file mode 100644
index d42d804..0000000
--- a/transport/sse/internal/sse/sse.go
+++ /dev/null
@@ -1,244 +0,0 @@
-// /*
-// Package mcp implements Server-Sent Events (SSE) transport for JSON-RPC communication.
-//
-// SSE Transport Overview:
-// This implementation provides a bidirectional communication channel between client and server:
-// - Server to Client: Uses Server-Sent Events (SSE) for real-time message streaming
-// - Client to Server: Uses HTTP POST requests for sending messages
-//
-// Key Features:
-// 1. Bidirectional Communication:
-// - SSE for server-to-client streaming (one-way, real-time updates)
-// - HTTP POST endpoints for client-to-server messages
-//
-// 2. Session Management:
-// - Unique session IDs for each connection
-// - Proper connection lifecycle management
-// - Automatic cleanup on connection close
-//
-// 3. Message Handling:
-// - JSON-RPC message format support
-// - Automatic message type detection (request vs response)
-// - Built-in error handling and reporting
-// - Message size limits for security
-//
-// 4. Security Features:
-// - Content-type validation
-// - Message size limits (4MB default)
-// - Error handling for malformed messages
-//
-// Usage Example:
-//
-// // Create a new SSE transport
-// transport, err := NewSSETransport("/messages", responseWriter)
-// if err != nil {
-// log.Fatal(err)
-// }
-//
-// // Set up message handling
-// transport.SetMessageHandler(func(msg JSONRPCMessage) {
-// // Handle incoming messages
-// })
-//
-// // Start the SSE connection
-// if err := transport.Start(context.Background()); err != nil {
-// log.Fatal(err)
-// }
-//
-// // Send a message
-// msg := JSONRPCResponse{
-// Jsonrpc: "2.0",
-// Result: Result{...},
-// Id: 1,
-// }
-// if err := transport.Send(msg); err != nil {
-// log.Fatal(err)
-// }
-//
-// */
-package sse
-
-//
-//import (
-// "context"
-// "encoding/json"
-// "fmt"
-// "github.com/metoro-io/mcp-golang/transport"
-// "net/http"
-// "sync"
-//
-// "github.com/google/uuid"
-//)
-//
-//const (
-// maxMessageSize = 4 * 1024 * 1024 // 4MB
-//)
-//
-//// SSETransport implements a Server-Sent Events transport for JSON-RPC messages
-//type SSETransport struct {
-// endpoint string
-// sessionID string
-// writer http.ResponseWriter
-// flusher http.Flusher
-// mu sync.Mutex
-// isConnected bool
-//
-// // Callbacks
-// closeHandler func()
-// errorHandler func(error)
-// messageHandler func(message *transport.BaseJsonRpcMessage)
-//}
-//
-//// NewSSETransport creates a new SSE transport with the given endpoint and response writer
-//func NewSSETransport(endpoint string, w http.ResponseWriter) (*SSETransport, error) {
-// flusher, ok := w.(http.Flusher)
-// if !ok {
-// return nil, fmt.Errorf("streaming not supported")
-// }
-//
-// return &SSETransport{
-// endpoint: endpoint,
-// sessionID: uuid.New().String(),
-// writer: w,
-// flusher: flusher,
-// }, nil
-//}
-//
-//// Start initializes the SSE connection
-//func (t *SSETransport) Start(ctx context.Context) error {
-// t.mu.Lock()
-// defer t.mu.Unlock()
-//
-// if t.isConnected {
-// return fmt.Errorf("SSE transport already started")
-// }
-//
-// // Set SSE headers
-// h := t.writer.Header()
-// h.Set("Content-Type", "text/event-stream")
-// h.Set("Cache-Control", "no-cache")
-// h.Set("Connection", "keep-alive")
-// h.Set("Access-Control-Allow-Origin", "*")
-//
-// // Send the endpoint event
-// endpointURL := fmt.Sprintf("%s?sessionId=%s", t.endpoint, t.sessionID)
-// if err := t.writeEvent("endpoint", endpointURL); err != nil {
-// return err
-// }
-//
-// t.isConnected = true
-//
-// // Handle context cancellation
-// go func() {
-// <-ctx.Done()
-// t.Close()
-// }()
-//
-// return nil
-//}
-//
-//// HandleMessage processes an incoming message
-//func (t *SSETransport) HandleMessage(msg []byte) error {
-// var rpcMsg map[string]interface{}
-// if err := json.Unmarshal(msg, &rpcMsg); err != nil {
-// if t.errorHandler != nil {
-// t.errorHandler(err)
-// }
-// return err
-// }
-//
-// // Parse as a JSONRPCMessage
-// var jsonrpcMsg JSONRPCMessage
-// if _, ok := rpcMsg["method"]; ok {
-// var req JSONRPCRequest
-// if err := json.Unmarshal(msg, &req); err != nil {
-// if t.errorHandler != nil {
-// t.errorHandler(err)
-// }
-// return err
-// }
-// jsonrpcMsg = &req
-// } else {
-// var resp JSONRPCResponse
-// if err := json.Unmarshal(msg, &resp); err != nil {
-// if t.errorHandler != nil {
-// t.errorHandler(err)
-// }
-// return err
-// }
-// jsonrpcMsg = &resp
-// }
-//
-// if t.messageHandler != nil {
-// t.messageHandler(jsonrpcMsg)
-// }
-// return nil
-//}
-//
-//// Send sends a message over the SSE connection
-//func (t *SSETransport) Send(msg JSONRPCMessage) error {
-// t.mu.Lock()
-// defer t.mu.Unlock()
-//
-// if !t.isConnected {
-// return fmt.Errorf("not connected")
-// }
-//
-// data, err := json.Marshal(msg)
-// if err != nil {
-// return err
-// }
-//
-// return t.writeEvent("message", string(data))
-//}
-//
-//// Close closes the SSE connection
-//func (t *SSETransport) Close() error {
-// t.mu.Lock()
-// defer t.mu.Unlock()
-//
-// if !t.isConnected {
-// return nil
-// }
-//
-// t.isConnected = false
-// if t.closeHandler != nil {
-// t.closeHandler()
-// }
-// return nil
-//}
-//
-//// SetCloseHandler sets the callback for when the connection is closed
-//func (t *SSETransport) SetCloseHandler(handler func()) {
-// t.mu.Lock()
-// defer t.mu.Unlock()
-// t.closeHandler = handler
-//}
-//
-//// SetErrorHandler sets the callback for when an error occurs
-//func (t *SSETransport) SetErrorHandler(handler func(error)) {
-// t.mu.Lock()
-// defer t.mu.Unlock()
-// t.errorHandler = handler
-//}
-//
-//// SetMessageHandler sets the callback for when a message is received
-//func (t *SSETransport) SetMessageHandler(handler func(JSONRPCMessage)) {
-// t.mu.Lock()
-// defer t.mu.Unlock()
-// t.messageHandler = handler
-//}
-//
-//// SessionID returns the unique session identifier for this transport
-//func (t *SSETransport) SessionID() string {
-// return t.sessionID
-//}
-//
-//// writeEvent writes an SSE event with the given event type and data
-//func (t *SSETransport) writeEvent(event, data string) error {
-// if _, err := fmt.Fprintf(t.writer, "event: %s\ndata: %s\n\n", event, data); err != nil {
-// return err
-// }
-// t.flusher.Flush()
-// return nil
-//}
diff --git a/transport/sse/server.go b/transport/sse/server.go
new file mode 100644
index 0000000..eba71c2
--- /dev/null
+++ b/transport/sse/server.go
@@ -0,0 +1,332 @@
+package sse
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "sync"
+ "sync/atomic"
+
+ "github.com/rvoh-emccaleb/mcp-golang/transport"
+ "github.com/rvoh-emccaleb/mcp-golang/transport/base"
+)
+
+var (
+ ErrConnectionRemoved = errors.New("sse connection removed")
+)
+
+// sseConnection represents a single SSE client sseConnection, and provides a way to send messages to the client.
+type sseConnection struct {
+ id int64
+ writer http.ResponseWriter
+ flusher http.Flusher
+}
+
+// ServerTransport implements server-side SSE transport
+type ServerTransport struct {
+ *base.Transport // shared transport for handling all messages
+ baseEndpoint string // connection IDs are appended to this endpoint for SSE connections
+ mu sync.RWMutex
+ sseConns map[int64]*sseConnection // map of connection IDs to connections
+ nextConnID int64 // atomic counter for generating connection IDs
+ chunkSize int // size of chunks for writing SSE messages
+}
+
+// Option is a function that configures a ServerTransport
+type Option func(*ServerTransport)
+
+// WithChunkSize sets the chunk size for writing SSE messages
+func WithChunkSize(size int) Option {
+ return func(t *ServerTransport) {
+ if size > 0 {
+ t.chunkSize = size
+ }
+ }
+}
+
+// NewServerTransport creates a new SSE server transport
+func NewServerTransport(endpoint string, options ...Option) *ServerTransport {
+ t := &ServerTransport{
+ Transport: base.NewTransport(),
+ baseEndpoint: endpoint,
+ sseConns: make(map[int64]*sseConnection),
+ chunkSize: 1024, // default to 1KB chunks
+ }
+ for _, opt := range options {
+ opt(t)
+ }
+ return t
+}
+
+// HandleSSEConnInitialize handles a new SSE connection request
+func (t *ServerTransport) HandleSSEConnInitialize(w http.ResponseWriter) (int64, error) {
+ flusher, ok := w.(http.Flusher)
+ if !ok {
+ http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
+ return 0, fmt.Errorf("streaming is not supported: provided http.ResponseWriter does not implement http.Flusher")
+ }
+
+ // Generate a unique endpoint URI for this connection
+ connID := atomic.AddInt64(&t.nextConnID, 1)
+ endpointURI := fmt.Sprintf("%s/%d", t.baseEndpoint, connID)
+
+ sseConn := &sseConnection{
+ id: connID,
+ writer: w,
+ flusher: flusher,
+ }
+
+ t.addSSEConnection(sseConn)
+
+ // Set headers for SSE before writing any data
+ h := sseConn.writer.Header()
+ h.Set("Content-Type", "text/event-stream")
+ h.Set("Cache-Control", "no-cache")
+ h.Set("Connection", "keep-alive")
+ h.Set("Access-Control-Allow-Origin", "*")
+ sseConn.writer.WriteHeader(http.StatusOK)
+
+ // Send the initial endpoint event as required by the MCP specification
+ err := t.writeEndpointEvent(sseConn, endpointURI)
+ if err != nil {
+ // We can't change the status code now, so bubble the error up, allowing request
+ // handling to continue without sending a response to the client.
+ return 0, err
+ }
+
+ return connID, nil
+}
+
+// writeEndpointEvent writes the endpoint URI over SSE
+func (t *ServerTransport) writeEndpointEvent(sseConn *sseConnection, endpointURI string) error {
+ _, err := sseConn.writer.Write(fmt.Appendf(nil, "event: endpoint\ndata: %s\n\n", endpointURI))
+ if err != nil {
+ t.RemoveSSEConnection(sseConn.id)
+ return fmt.Errorf(
+ "failed to write endpoint event for connection with id %d: %w. %w",
+ sseConn.id,
+ err,
+ ErrConnectionRemoved,
+ )
+ }
+
+ sseConn.flusher.Flush()
+
+ return nil
+}
+
+// HandleMCPMessage handles incoming MCP (JSON-RPC) messages.
+func (t *ServerTransport) HandleMCPMessage(w http.ResponseWriter, r *http.Request, connID int64) error {
+ // First check if we have an active connection for this ID
+ sseConn, ok := t.getSSEConnection(connID)
+ if !ok {
+ errMsg := fmt.Sprintf("no active connection found for id: %d", connID)
+ http.Error(w, errMsg, http.StatusNotFound)
+ return errors.New(errMsg)
+ }
+
+ if r.Header.Get("Content-Type") != "application/json" {
+ errMsg := fmt.Sprintf("unsupported content type: %s", r.Header.Get("Content-Type"))
+ http.Error(w, errMsg, http.StatusUnsupportedMediaType)
+ return errors.New(errMsg)
+ }
+
+ body, err := io.ReadAll(r.Body)
+ if err != nil {
+ errMsg := "failed to read request body"
+ http.Error(w, errMsg, http.StatusBadRequest)
+ return fmt.Errorf("%s: %w", errMsg, err)
+ }
+ defer r.Body.Close()
+
+ var receivedMsg transport.BaseJsonRpcMessage
+ if err := json.Unmarshal(body, &receivedMsg); err != nil {
+ errMsg := "failed to parse request body as a JSON-RPC message"
+ http.Error(w, errMsg, http.StatusBadRequest)
+ return fmt.Errorf("%s: %w", errMsg, err)
+ }
+
+ // Send 200 OK response back to the POST request after validating the message.
+ w.WriteHeader(http.StatusOK)
+ if _, err := w.Write([]byte{}); err != nil {
+ return fmt.Errorf("failed to write response: %w", err)
+ }
+ if f, ok := w.(http.Flusher); ok {
+ f.Flush()
+ }
+ // Note: Errors from this point on must be sent over the SSE connection.
+
+ response, err := t.HandleMessage(r.Context(), body)
+ if err != nil {
+ switch receivedMsg.Type {
+ case transport.BaseMessageTypeJSONRPCRequestType:
+ newErr := t.writeMessageEvent(sseConn, transport.NewBaseMessageError(&transport.BaseJSONRPCError{
+ Jsonrpc: "2.0",
+ Id: receivedMsg.JsonRpcRequest.Id,
+ Error: transport.BaseJSONRPCErrorInner{
+ Code: -32603, // Internal error
+ Message: fmt.Sprintf("error handling request: %v", err),
+ },
+ }))
+ if newErr != nil {
+ return fmt.Errorf("failed to write error sse message after encountering '%w' error when handling the request: %w", err, newErr)
+ }
+
+ return fmt.Errorf("error handling request: %w", err)
+
+ case transport.BaseMessageTypeJSONRPCNotificationType:
+ return fmt.Errorf("error handling notification: %w", err)
+
+ default:
+ return fmt.Errorf("error handling unknown message type %s: %w", receivedMsg.Type, err)
+ }
+ }
+
+ // Only send response for requests (not for notifications)
+ if receivedMsg.Type == transport.BaseMessageTypeJSONRPCRequestType {
+ err := t.writeMessageEvent(sseConn, response)
+ if err != nil {
+ if errors.Is(err, ErrConnectionRemoved) {
+ return err
+ }
+
+ // Can try to write an error message event, but if that fails, we're out of options.
+ errorMsg := transport.NewBaseMessageError(&transport.BaseJSONRPCError{
+ Jsonrpc: "2.0",
+ Id: receivedMsg.JsonRpcRequest.Id,
+ Error: transport.BaseJSONRPCErrorInner{
+ Code: -32603, // Internal error
+ Message: fmt.Sprintf("failed to write message event: %v", err),
+ },
+ })
+
+ newErr := t.writeMessageEvent(sseConn, errorMsg)
+ if newErr != nil {
+ return fmt.Errorf("failed to write error message event after encountering '%w' error: %w", err, newErr)
+ }
+
+ return err
+ }
+ }
+
+ return nil
+}
+
+// writeMessageEvent writes a JSON-RPC message over SSE
+func (t *ServerTransport) writeMessageEvent(sseConn *sseConnection, message *transport.BaseJsonRpcMessage) error {
+ data, err := json.Marshal(message)
+ if err != nil {
+ return fmt.Errorf("failed to marshal message: %w", err)
+ }
+
+ // Write the event header
+ if _, err := io.WriteString(sseConn.writer, "event: message\ndata: "); err != nil {
+ t.RemoveSSEConnection(sseConn.id)
+ return fmt.Errorf(
+ "failed to write message event for connection with id %d: %w. %w",
+ sseConn.id,
+ err,
+ ErrConnectionRemoved,
+ )
+ }
+
+ // Use the configured chunk size
+ for i := 0; i < len(data); i += t.chunkSize {
+ end := i + t.chunkSize
+ if end > len(data) {
+ end = len(data)
+ }
+
+ if _, err := sseConn.writer.Write(data[i:end]); err != nil {
+ t.RemoveSSEConnection(sseConn.id)
+ return fmt.Errorf(
+ "failed to write message chunk for connection with id %d: %w. %w",
+ sseConn.id,
+ err,
+ ErrConnectionRemoved,
+ )
+ }
+ sseConn.flusher.Flush()
+ }
+
+ // Write the final newlines
+ if _, err := io.WriteString(sseConn.writer, "\n\n"); err != nil {
+ t.RemoveSSEConnection(sseConn.id)
+ return fmt.Errorf(
+ "failed to write message termination for connection with id %d: %w. %w",
+ sseConn.id,
+ err,
+ ErrConnectionRemoved,
+ )
+ }
+
+ sseConn.flusher.Flush()
+ return nil
+}
+
+// Start implements Transport.Start
+func (t *ServerTransport) Start(ctx context.Context) error {
+ // Nothing to do here as connections are established via HandleSSERequest
+ return nil
+}
+
+// Send implements Transport.Send.
+func (t *ServerTransport) Send(ctx context.Context, message *transport.BaseJsonRpcMessage) error {
+ key := message.JsonRpcResponse.Id
+ responseChannel := t.ResponseMap[int64(key)]
+ if responseChannel == nil {
+ return fmt.Errorf("no response channel found for key: %d", key)
+ }
+ responseChannel <- message
+ return nil
+}
+
+// Close implements Transport.Close
+func (t *ServerTransport) Close() error {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ // Clear all SSE connections
+ t.sseConns = make(map[int64]*sseConnection)
+
+ // Close the base transport
+ return t.Transport.Close()
+}
+
+// SetMessageHandler implements Transport.SetMessageHandler
+func (t *ServerTransport) SetMessageHandler(handler func(ctx context.Context, message *transport.BaseJsonRpcMessage)) {
+ t.Transport.SetMessageHandler(handler)
+}
+
+// SetErrorHandler implements Transport.SetErrorHandler
+func (t *ServerTransport) SetErrorHandler(handler func(error)) {
+ t.Transport.SetErrorHandler(handler)
+}
+
+// SetCloseHandler implements Transport.SetCloseHandler
+func (t *ServerTransport) SetCloseHandler(handler func()) {
+ t.Transport.SetCloseHandler(handler)
+}
+
+func (t *ServerTransport) addSSEConnection(sseConn *sseConnection) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.sseConns[sseConn.id] = sseConn
+}
+
+func (t *ServerTransport) getSSEConnection(connID int64) (*sseConnection, bool) {
+ t.mu.RLock()
+ defer t.mu.RUnlock()
+ sseConn, ok := t.sseConns[connID]
+ return sseConn, ok
+}
+
+// RemoveSSEConnection removes an SSE connection from the transport
+func (t *ServerTransport) RemoveSSEConnection(connID int64) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ delete(t.sseConns, connID)
+}
diff --git a/transport/sse/sse_server.go b/transport/sse/sse_server.go
deleted file mode 100644
index e3fa67a..0000000
--- a/transport/sse/sse_server.go
+++ /dev/null
@@ -1,82 +0,0 @@
-package sse
-
-//
-//import (
-// "context"
-// "fmt"
-// sse2 "github.com/metoro-io/mcp-golang/transport/sse/internal/sse"
-// "io"
-// "net/http"
-//)
-//
-//// SSEServerTransport implements a server-side SSE transport
-//type SSEServerTransport struct {
-// transport *sse2.SSETransport
-//}
-//
-//// NewSSEServerTransport creates a new SSE server transport
-//func NewSSEServerTransport(endpoint string, w http.ResponseWriter) (*SSEServerTransport, error) {
-// transport, err := sse2.NewSSETransport(endpoint, w)
-// if err != nil {
-// return nil, err
-// }
-//
-// return &SSEServerTransport{
-// transport: transport,
-// }, nil
-//}
-//
-//// Start initializes the SSE connection
-//func (s *SSEServerTransport) Start(ctx context.Context) error {
-// return s.transport.Start(ctx)
-//}
-//
-//// HandlePostMessage processes an incoming POST request containing a JSON-RPC message
-//func (s *SSEServerTransport) HandlePostMessage(r *http.Request) error {
-// if r.Method != http.MethodPost {
-// return fmt.Errorf("method not allowed: %s", r.Method)
-// }
-//
-// contentType := r.Header.Get("Content-Type")
-// if contentType != "application/json" {
-// return fmt.Errorf("unsupported Content type: %s", contentType)
-// }
-//
-// body, err := io.ReadAll(io.LimitReader(r.Body, sse2.maxMessageSize))
-// if err != nil {
-// return fmt.Errorf("failed to read request body: %w", err)
-// }
-// defer r.Body.Close()
-//
-// return s.transport.HandleMessage(body)
-//}
-//
-//// Send sends a message over the SSE connection
-//func (s *SSEServerTransport) Send(msg JSONRPCMessage) error {
-// return s.transport.Send(msg)
-//}
-//
-//// Close closes the SSE connection
-//func (s *SSEServerTransport) Close() error {
-// return s.transport.Close()
-//}
-//
-//// SetCloseHandler sets the callback for when the connection is closed
-//func (s *SSEServerTransport) SetCloseHandler(handler func()) {
-// s.transport.SetCloseHandler(handler)
-//}
-//
-//// SetErrorHandler sets the callback for when an error occurs
-//func (s *SSEServerTransport) SetErrorHandler(handler func(error)) {
-// s.transport.SetErrorHandler(handler)
-//}
-//
-//// SetMessageHandler sets the callback for when a message is received
-//func (s *SSEServerTransport) SetMessageHandler(handler func(JSONRPCMessage)) {
-// s.transport.SetMessageHandler(handler)
-//}
-//
-//// SessionID returns the unique session identifier for this transport
-//func (s *SSEServerTransport) SessionID() string {
-// return s.transport.SessionID()
-//}
diff --git a/transport/sse/sse_server_test.go b/transport/sse/sse_server_test.go
deleted file mode 100644
index 101dff0..0000000
--- a/transport/sse/sse_server_test.go
+++ /dev/null
@@ -1,126 +0,0 @@
-package sse
-
-//
-//import (
-// "bytes"
-// "context"
-// "encoding/json"
-// "github.com/metoro-io/mcp-golang"
-// "net/http"
-// "net/http/httptest"
-// "strings"
-// "testing"
-//
-// "github.com/stretchr/testify/assert"
-//)
-//
-//func TestSSEServerTransport(t *testing.T) {
-// t.Run("basic message handling", func(t *testing.T) {
-// w := httptest.NewRecorder()
-// transport, err := NewSSEServerTransport("/messages", w)
-// assert.NoError(t, err)
-//
-// var receivedMsg JSONRPCMessage
-// transport.SetMessageHandler(func(msg JSONRPCMessage) {
-// receivedMsg = msg
-// })
-//
-// ctx := context.Background()
-// err = transport.Start(ctx)
-// assert.NoError(t, err)
-//
-// // Verify SSE headers
-// headers := w.Header()
-// assert.Equal(t, "text/event-stream", headers.Get("Content-Type"))
-// assert.Equal(t, "no-cache", headers.Get("Cache-Control"))
-// assert.Equal(t, "keep-alive", headers.Get("Connection"))
-//
-// // Verify endpoint event was sent
-// body := w.Body.String()
-// assert.Contains(t, body, "event: endpoint")
-// assert.Contains(t, body, "/messages?sessionId=")
-//
-// // Test message handling
-// msg := JSONRPCRequest{
-// Jsonrpc: "2.0",
-// Method: "test",
-// Id: 1,
-// }
-// msgBytes, err := json.Marshal(msg)
-// assert.NoError(t, err)
-//
-// httpReq := httptest.NewRequest(http.MethodPost, "/messages", bytes.NewReader(msgBytes))
-// httpReq.Header.Set("Content-Type", "application/json")
-// err = transport.HandlePostMessage(httpReq)
-// assert.NoError(t, err)
-//
-// // Verify received message
-// rpcReq, ok := receivedMsg.(*JSONRPCRequest)
-// assert.True(t, ok)
-// assert.Equal(t, "2.0", rpcReq.Jsonrpc)
-// assert.Equal(t, mcp.RequestId(1), rpcReq.Id)
-//
-// err = transport.Close()
-// assert.NoError(t, err)
-// })
-//
-// t.Run("send message", func(t *testing.T) {
-// w := httptest.NewRecorder()
-// transport, err := NewSSEServerTransport("/messages", w)
-// assert.NoError(t, err)
-//
-// ctx := context.Background()
-// err = transport.Start(ctx)
-// assert.NoError(t, err)
-//
-// msg := JSONRPCResponse{
-// Jsonrpc: "2.0",
-// Result: Result{AdditionalProperties: map[string]interface{}{"status": "ok"}},
-// Id: 1,
-// }
-//
-// err = transport.Send(msg)
-// assert.NoError(t, err)
-//
-// // Verify output contains the message
-// body := w.Body.String()
-// assert.Contains(t, body, `event: message`)
-// assert.Contains(t, body, `"result":{"AdditionalProperties":{"status":"ok"}}`)
-// })
-//
-// t.Run("error handling", func(t *testing.T) {
-// w := httptest.NewRecorder()
-// transport, err := NewSSEServerTransport("/messages", w)
-// assert.NoError(t, err)
-//
-// var receivedErr error
-// transport.SetErrorHandler(func(err error) {
-// receivedErr = err
-// })
-//
-// ctx := context.Background()
-// err = transport.Start(ctx)
-// assert.NoError(t, err)
-//
-// // Test invalid JSON
-// req := httptest.NewRequest(http.MethodPost, "/messages", strings.NewReader("invalid json"))
-// req.Header.Set("Content-Type", "application/json")
-// err = transport.HandlePostMessage(req)
-// assert.Error(t, err)
-// assert.NotNil(t, receivedErr)
-// assert.Contains(t, receivedErr.Error(), "invalid")
-//
-// // Test invalid Content type
-// req = httptest.NewRequest(http.MethodPost, "/messages", strings.NewReader("{}"))
-// req.Header.Set("Content-Type", "text/plain")
-// err = transport.HandlePostMessage(req)
-// assert.Error(t, err)
-// assert.Contains(t, err.Error(), "unsupported Content type")
-//
-// // Test invalid method
-// req = httptest.NewRequest(http.MethodGet, "/messages", nil)
-// err = transport.HandlePostMessage(req)
-// assert.Error(t, err)
-// assert.Contains(t, err.Error(), "method not allowed")
-// })
-//}
diff --git a/transport/stdio/internal/stdio/stdio.go b/transport/stdio/internal/stdio/stdio.go
index 9b85698..3f08aa2 100644
--- a/transport/stdio/internal/stdio/stdio.go
+++ b/transport/stdio/internal/stdio/stdio.go
@@ -60,7 +60,7 @@ package stdio
import (
"encoding/json"
"errors"
- "github.com/metoro-io/mcp-golang/transport"
+ "github.com/rvoh-emccaleb/mcp-golang/transport"
"sync"
)
diff --git a/transport/stdio/internal/stdio/stdio_test.go b/transport/stdio/internal/stdio/stdio_test.go
index 7412d5b..5106b71 100644
--- a/transport/stdio/internal/stdio/stdio_test.go
+++ b/transport/stdio/internal/stdio/stdio_test.go
@@ -1,7 +1,7 @@
package stdio
import (
- "github.com/metoro-io/mcp-golang/transport"
+ "github.com/rvoh-emccaleb/mcp-golang/transport"
"testing"
"github.com/stretchr/testify/assert"
diff --git a/transport/stdio/stdio_server.go b/transport/stdio/stdio_server.go
index 36d52f2..aa839a4 100644
--- a/transport/stdio/stdio_server.go
+++ b/transport/stdio/stdio_server.go
@@ -9,8 +9,8 @@ import (
"os"
"sync"
- "github.com/metoro-io/mcp-golang/transport"
- "github.com/metoro-io/mcp-golang/transport/stdio/internal/stdio"
+ "github.com/rvoh-emccaleb/mcp-golang/transport"
+ "github.com/rvoh-emccaleb/mcp-golang/transport/stdio/internal/stdio"
)
// StdioServerTransport implements server-side transport for stdio communication
diff --git a/transport/stdio/stdio_server_test.go b/transport/stdio/stdio_server_test.go
index d8674ba..25243b3 100644
--- a/transport/stdio/stdio_server_test.go
+++ b/transport/stdio/stdio_server_test.go
@@ -7,7 +7,7 @@ import (
"testing"
"time"
- "github.com/metoro-io/mcp-golang/transport"
+ "github.com/rvoh-emccaleb/mcp-golang/transport"
"github.com/stretchr/testify/assert"
)
diff --git a/transport/types.go b/transport/types.go
index d2627f8..7f615f7 100644
--- a/transport/types.go
+++ b/transport/types.go
@@ -3,6 +3,7 @@ package transport
import (
"encoding/json"
"errors"
+ "fmt"
)
type JSONRPCMessage interface{}
@@ -193,6 +194,43 @@ func (m *BaseJsonRpcMessage) MarshalJSON() ([]byte, error) {
}
}
+// UnmarshalJSON implements json.Unmarshaler for BaseJsonRpcMessage
+func (m *BaseJsonRpcMessage) UnmarshalJSON(data []byte) error {
+ // Try as a request
+ var request BaseJSONRPCRequest
+ if err := json.Unmarshal(data, &request); err == nil {
+ m.Type = BaseMessageTypeJSONRPCRequestType
+ m.JsonRpcRequest = &request
+ return nil
+ }
+
+ // Try as a notification
+ var notification BaseJSONRPCNotification
+ if err := json.Unmarshal(data, ¬ification); err == nil {
+ m.Type = BaseMessageTypeJSONRPCNotificationType
+ m.JsonRpcNotification = ¬ification
+ return nil
+ }
+
+ // Try as a response
+ var response BaseJSONRPCResponse
+ if err := json.Unmarshal(data, &response); err == nil {
+ m.Type = BaseMessageTypeJSONRPCResponseType
+ m.JsonRpcResponse = &response
+ return nil
+ }
+
+ // Try as an error
+ var errorResponse BaseJSONRPCError
+ if err := json.Unmarshal(data, &errorResponse); err == nil {
+ m.Type = BaseMessageTypeJSONRPCErrorType
+ m.JsonRpcError = &errorResponse
+ return nil
+ }
+
+ return fmt.Errorf("failed to unmarshal JSON-RPC message, unrecognized type")
+}
+
func NewBaseMessageNotification(notification *BaseJSONRPCNotification) *BaseJsonRpcMessage {
return &BaseJsonRpcMessage{
Type: BaseMessageTypeJSONRPCNotificationType,