Skip to content

Commit 2863098

Browse files
committed
internal/mcp: implement ping, and test request interleaving
Also validate that requests observe the lifecycle rules defined in the spec. Change-Id: I302282f81e053e76926e0bb295f7dbf8d526f02d Reviewed-on: https://go-review.googlesource.com/c/tools/+/667575 Reviewed-by: Jonathan Amsterdam <[email protected]> Commit-Queue: Robert Findley <[email protected]> Auto-Submit: Robert Findley <[email protected]> TryBot-Bypass: Robert Findley <[email protected]>
1 parent caf7cdc commit 2863098

File tree

7 files changed

+81
-18
lines changed

7 files changed

+81
-18
lines changed

internal/mcp/client.go

+14-4
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,9 @@ func (c *Client) disconnect(sc *ServerConnection) {
8181
// Connect connects the MCP client over the given transport and initializes an
8282
// MCP session.
8383
//
84-
// It returns a connection object that may be used to query the MCP server,
85-
// terminate the connection (with [Connection.Close]), or await server
86-
// termination (with [Connection.Wait]).
84+
// It returns an initialized [ServerConnection] object that may be used to
85+
// query the MCP server, terminate the connection (with [Connection.Close]), or
86+
// await server termination (with [Connection.Wait]).
8787
//
8888
// Typically, it is the responsibility of the client to close the connection
8989
// when it is no longer needed. However, if the connection is closed by the
@@ -105,7 +105,7 @@ func (c *Client) Connect(ctx context.Context, t Transport, opts *ConnectionOptio
105105
if err := call(ctx, sc.conn, "initialize", params, &sc.initializeResult); err != nil {
106106
return nil, err
107107
}
108-
if err := sc.conn.Notify(ctx, "initialized", &protocol.InitializedParams{}); err != nil {
108+
if err := sc.conn.Notify(ctx, "notifications/initialized", &protocol.InitializedParams{}); err != nil {
109109
return nil, err
110110
}
111111
return sc, nil
@@ -135,11 +135,21 @@ func (cc *ServerConnection) Wait() error {
135135
}
136136

137137
func (sc *ServerConnection) handle(ctx context.Context, req *jsonrpc2.Request) (any, error) {
138+
// No need to check that the connection is initialized, since we initialize
139+
// it in Connect.
138140
switch req.Method {
141+
case "ping":
142+
// The spec says that 'ping' expects an empty object result.
143+
return struct{}{}, nil
139144
}
140145
return nil, jsonrpc2.ErrNotHandled
141146
}
142147

148+
// Ping makes an MCP "ping" request to the server.
149+
func (sc *ServerConnection) Ping(ctx context.Context) error {
150+
return call(ctx, sc.conn, "ping", nil, nil)
151+
}
152+
143153
// ListTools lists tools that are currently available on the server.
144154
func (sc *ServerConnection) ListTools(ctx context.Context) ([]protocol.Tool, error) {
145155
var (

internal/mcp/examples/hello/main.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ type SayHiParams struct {
2020
Name string `json:"name" mcp:"the name to say hi to"`
2121
}
2222

23-
func SayHi(ctx context.Context, params *SayHiParams) ([]mcp.Content, error) {
23+
func SayHi(ctx context.Context, cc *mcp.ClientConnection, params *SayHiParams) ([]mcp.Content, error) {
2424
return []mcp.Content{
2525
mcp.TextContent{Text: "Hi " + params.Name},
2626
}, nil

internal/mcp/mcp_test.go

+11-2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package mcp
77
import (
88
"context"
99
"errors"
10+
"fmt"
1011
"slices"
1112
"strings"
1213
"sync"
@@ -22,7 +23,10 @@ type hiParams struct {
2223
Name string
2324
}
2425

25-
func sayHi(_ context.Context, v hiParams) ([]Content, error) {
26+
func sayHi(ctx context.Context, cc *ClientConnection, v hiParams) ([]Content, error) {
27+
if err := cc.Ping(ctx); err != nil {
28+
return nil, fmt.Errorf("ping failed: %v", err)
29+
}
2630
return []Content{TextContent{Text: "hi " + v.Name}}, nil
2731
}
2832

@@ -37,7 +41,7 @@ func TestEndToEnd(t *testing.T) {
3741

3842
// The 'fail' tool returns this error.
3943
failure := errors.New("mcp failure")
40-
s.AddTools(MakeTool("fail", "just fail", func(context.Context, struct{}) ([]Content, error) {
44+
s.AddTools(MakeTool("fail", "just fail", func(context.Context, *ClientConnection, struct{}) ([]Content, error) {
4145
return nil, failure
4246
}))
4347

@@ -67,10 +71,15 @@ func TestEndToEnd(t *testing.T) {
6771
if err != nil {
6872
t.Fatal(err)
6973
}
74+
7075
if got := slices.Collect(c.Servers()); len(got) != 1 {
7176
t.Errorf("after connection, Servers() has length %d, want 1", len(got))
7277
}
7378

79+
if err := sc.Ping(ctx); err != nil {
80+
t.Fatalf("ping failed: %v", err)
81+
}
82+
7483
gotTools, err := sc.ListTools(ctx)
7584
if err != nil {
7685
t.Errorf("tools/list failed: %v", err)

internal/mcp/server.go

+45-6
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ func (s *Server) listTools(_ context.Context, _ *ClientConnection, params *proto
8282
return res, nil
8383
}
8484

85-
func (s *Server) callTool(ctx context.Context, _ *ClientConnection, params *protocol.CallToolParams) (*protocol.CallToolResult, error) {
85+
func (s *Server) callTool(ctx context.Context, cc *ClientConnection, params *protocol.CallToolParams) (*protocol.CallToolResult, error) {
8686
s.mu.Lock()
8787
var tool *Tool
8888
if i := slices.IndexFunc(s.tools, func(t *Tool) bool {
@@ -95,7 +95,7 @@ func (s *Server) callTool(ctx context.Context, _ *ClientConnection, params *prot
9595
if tool == nil {
9696
return nil, fmt.Errorf("%s: unknown tool %q", jsonrpc2.ErrInvalidParams, params.Name)
9797
}
98-
return tool.Handler(ctx, params.Arguments)
98+
return tool.Handler(ctx, cc, params.Arguments)
9999
}
100100

101101
// Run runs the server over the given transport, which must be persistent.
@@ -148,10 +148,31 @@ type ClientConnection struct {
148148
conn *jsonrpc2.Connection
149149

150150
mu sync.Mutex
151-
initializeParams *protocol.InitializeParams // set once initialize has been received
151+
initializeParams *protocol.InitializeParams
152+
initialized bool
153+
}
154+
155+
// Ping makes an MCP "ping" request to the client.
156+
func (cc *ClientConnection) Ping(ctx context.Context) error {
157+
return call(ctx, cc.conn, "ping", nil, nil)
152158
}
153159

154160
func (cc *ClientConnection) handle(ctx context.Context, req *jsonrpc2.Request) (any, error) {
161+
cc.mu.Lock()
162+
initialized := cc.initialized
163+
cc.mu.Unlock()
164+
165+
// From the spec:
166+
// "The client SHOULD NOT send requests other than pings before the server
167+
// has responded to the initialize request."
168+
switch req.Method {
169+
case "initialize", "ping":
170+
default:
171+
if !initialized {
172+
return nil, fmt.Errorf("method %q is invalid during session ininitialization", req.Method)
173+
}
174+
}
175+
155176
// TODO: embed the incoming request ID in the ClientContext (or, more likely,
156177
// a wrapper around it), so that we can correlate responses and notifications
157178
// to the handler; this is required for the new session-based transport.
@@ -160,6 +181,10 @@ func (cc *ClientConnection) handle(ctx context.Context, req *jsonrpc2.Request) (
160181
case "initialize":
161182
return dispatch(ctx, cc, req, cc.initialize)
162183

184+
case "ping":
185+
// The spec says that 'ping' expects an empty object result.
186+
return struct{}{}, nil
187+
163188
case "tools/list":
164189
return dispatch(ctx, cc, req, cc.server.listTools)
165190

@@ -176,6 +201,17 @@ func (cc *ClientConnection) initialize(ctx context.Context, _ *ClientConnection,
176201
cc.initializeParams = params
177202
cc.mu.Unlock()
178203

204+
// Mark the connection as initialized when this method exits. TODO:
205+
// Technically, the server should not be considered initialized until it has
206+
// *responded*, but we don't have adequate visibility into the jsonrpc2
207+
// connection to implement that easily. In any case, once we've initialized
208+
// here, we can handle requests.
209+
defer func() {
210+
cc.mu.Lock()
211+
cc.initialized = true
212+
cc.mu.Unlock()
213+
}()
214+
179215
return &protocol.InitializeResult{
180216
// TODO(rfindley): support multiple protocol versions.
181217
ProtocolVersion: "2024-11-05",
@@ -204,11 +240,14 @@ func (cc *ClientConnection) Wait() error {
204240
return cc.conn.Wait()
205241
}
206242

207-
func dispatch[TParams, TResult any](ctx context.Context, conn *ClientConnection, req *jsonrpc2.Request, f func(context.Context, *ClientConnection, TParams) (TResult, error)) (TResult, error) {
243+
// dispatch turns a strongly type handler into a jsonrpc2 handler.
244+
//
245+
// Importantly, it returns nil if the handler returned an error, which is a
246+
// requirement of the jsonrpc2 package.
247+
func dispatch[TConn, TParams, TResult any](ctx context.Context, conn TConn, req *jsonrpc2.Request, f func(context.Context, TConn, TParams) (TResult, error)) (any, error) {
208248
var params TParams
209249
if err := json.Unmarshal(req.Params, &params); err != nil {
210-
var zero TResult
211-
return zero, err
250+
return nil, err
212251
}
213252
return f(ctx, conn, params)
214253
}

internal/mcp/sse_test.go

+3
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ func TestSSEServer(t *testing.T) {
4141
if err != nil {
4242
t.Fatal(err)
4343
}
44+
if err := sc.Ping(ctx); err != nil {
45+
t.Fatal(err)
46+
}
4447
cc := <-clients
4548
gotHi, err := sc.CallTool(ctx, "greet", hiParams{"user"})
4649
if err != nil {

internal/mcp/tool.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import (
1313
)
1414

1515
// A ToolHandler handles a call to tools/call.
16-
type ToolHandler func(context.Context, json.RawMessage) (*protocol.CallToolResult, error)
16+
type ToolHandler func(context.Context, *ClientConnection, json.RawMessage) (*protocol.CallToolResult, error)
1717

1818
// A Tool is a tool definition that is bound to a tool handler.
1919
type Tool struct {
@@ -29,17 +29,17 @@ type Tool struct {
2929
// It is the caller's responsibility that the handler request type can produce
3030
// a valid schema, as documented by [jsonschema.ForType]; otherwise, MakeTool
3131
// panics.
32-
func MakeTool[TReq any](name, description string, handler func(context.Context, TReq) ([]Content, error)) *Tool {
32+
func MakeTool[TReq any](name, description string, handler func(context.Context, *ClientConnection, TReq) ([]Content, error)) *Tool {
3333
schema, err := jsonschema.ForType[TReq]()
3434
if err != nil {
3535
panic(err)
3636
}
37-
wrapped := func(ctx context.Context, args json.RawMessage) (*protocol.CallToolResult, error) {
37+
wrapped := func(ctx context.Context, cc *ClientConnection, args json.RawMessage) (*protocol.CallToolResult, error) {
3838
var v TReq
3939
if err := unmarshalSchema(args, schema, &v); err != nil {
4040
return nil, err
4141
}
42-
content, err := handler(ctx, v)
42+
content, err := handler(ctx, cc, v)
4343
if err != nil {
4444
return &protocol.CallToolResult{
4545
Content: marshalContent([]Content{TextContent{Text: err.Error()}}),

internal/mcp/util.go

+3-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44

55
package mcp
66

7-
import "crypto/rand"
7+
import (
8+
"crypto/rand"
9+
)
810

911
func assert(cond bool, msg string) {
1012
if !cond {

0 commit comments

Comments
 (0)