Skip to content

Commit d3a3775

Browse files
committed
internal/mcp: implement cancellation
Use the existing jsonrpc2 preemption mechanism to implement MCP cancellation. This was mostly straightforward, except where I got confused about ID unmarshaling. Leave a note about using omitempty when it is available. Change-Id: I13e073f08c5d5c2cc78d882da4e6ff47f09fb340 Reviewed-on: https://go-review.googlesource.com/c/tools/+/667578 LUCI-TryBot-Result: Go LUCI <[email protected]> Reviewed-by: Jonathan Amsterdam <[email protected]>
1 parent 2863098 commit d3a3775

File tree

8 files changed

+134
-31
lines changed

8 files changed

+134
-31
lines changed

gopls/internal/lsprpc/export_test.go

+2-7
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,8 @@ func (c *Canceler) Preempt(ctx context.Context, req *jsonrpc2_v2.Request) (any,
3434
if err := json.Unmarshal(req.Params, &params); err != nil {
3535
return nil, fmt.Errorf("%w: %v", jsonrpc2_v2.ErrParse, err)
3636
}
37-
var id jsonrpc2_v2.ID
38-
switch raw := params.ID.(type) {
39-
case float64:
40-
id = jsonrpc2_v2.Int64ID(int64(raw))
41-
case string:
42-
id = jsonrpc2_v2.StringID(raw)
43-
default:
37+
id, err := jsonrpc2_v2.MakeID(params.ID)
38+
if err != nil {
4439
return nil, fmt.Errorf("%w: invalid ID type %T", jsonrpc2_v2.ErrParse, params.ID)
4540
}
4641
c.Conn.Cancel(id)

internal/jsonrpc2_v2/messages.go

+25-13
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,32 @@ import (
1010
"fmt"
1111
)
1212

13-
// ID is a Request identifier.
13+
// ID is a Request identifier, which is defined by the spec to be a string, integer, or null.
14+
// https://www.jsonrpc.org/specification#request_object
1415
type ID struct {
1516
value any
1617
}
1718

19+
// MakeID coerces the given Go value to an ID. The value is assumed to be the
20+
// default JSON marshaling of a Request identifier -- nil, float64, or string.
21+
//
22+
// Returns an error if the value type was a valid Request ID type.
23+
//
24+
// TODO: ID can't be a json.Marshaler/Unmarshaler, because we want to omitzero.
25+
// Simplify this package by making ID json serializable once we can rely on
26+
// omitzero.
27+
func MakeID(v any) (ID, error) {
28+
switch v := v.(type) {
29+
case nil:
30+
return ID{}, nil
31+
case float64:
32+
return Int64ID(int64(v)), nil
33+
case string:
34+
return StringID(v), nil
35+
}
36+
return ID{}, fmt.Errorf("%w: invalid ID type %T", ErrParse, v)
37+
}
38+
1839
// Message is the interface to all jsonrpc2 message types.
1940
// They share no common functionality, but are a closed set of concrete types
2041
// that are allowed to implement this interface. The message types are *Request
@@ -133,18 +154,9 @@ func DecodeMessage(data []byte) (Message, error) {
133154
if msg.VersionTag != wireVersion {
134155
return nil, fmt.Errorf("invalid message version tag %s expected %s", msg.VersionTag, wireVersion)
135156
}
136-
id := ID{}
137-
switch v := msg.ID.(type) {
138-
case nil:
139-
case float64:
140-
// coerce the id type to int64 if it is float64, the spec does not allow fractional parts
141-
id = Int64ID(int64(v))
142-
case int64:
143-
id = Int64ID(v)
144-
case string:
145-
id = StringID(v)
146-
default:
147-
return nil, fmt.Errorf("invalid message id type <%T>%v", v, v)
157+
id, err := MakeID(msg.ID)
158+
if err != nil {
159+
return nil, err
148160
}
149161
if msg.Method != "" {
150162
// has a method, must be a call

internal/mcp/internal/protocol/generate.go

+5-1
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ var declarations = config{
5656
"CallToolResult": {
5757
Name: "CallToolResult",
5858
},
59+
"CancelledNotification": {
60+
Fields: config{"Params": {Name: "CancelledParams"}},
61+
},
5962
"ClientCapabilities": {Name: "ClientCapabilities"},
6063
"Implementation": {Name: "Implementation"},
6164
"InitializeRequest": {
@@ -73,7 +76,8 @@ var declarations = config{
7376
"ListToolsResult": {
7477
Name: "ListToolsResult",
7578
},
76-
"Role": {Name: "Role"},
79+
"RequestId": {Substitute: "any"}, // null|number|string
80+
"Role": {Name: "Role"},
7781
"ServerCapabilities": {
7882
Name: "ServerCapabilities",
7983
Fields: config{

internal/mcp/internal/protocol/protocol.go

+11
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/mcp/mcp.go

-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
// [Client.Connect] or [Server.Connect].
1313
//
1414
// TODO:
15-
// - Support cancellation.
1615
// - Support pagination.
1716
// - Support all client/server operations.
1817
// - Support Streamable HTTP transport.

internal/mcp/mcp_test.go

+46-2
Original file line numberDiff line numberDiff line change
@@ -133,14 +133,21 @@ func TestEndToEnd(t *testing.T) {
133133
}
134134
}
135135

136-
func TestServerClosing(t *testing.T) {
136+
// basicConnection returns a new basic client-server connection configured with
137+
// the provided tools.
138+
//
139+
// The caller should cancel either the client connection or server connection
140+
// when the connections are no longer needed.
141+
func basicConnection(t *testing.T, tools ...*Tool) (*ClientConnection, *ServerConnection) {
142+
t.Helper()
143+
137144
ctx := context.Background()
138145
ct, st := NewLocalTransport()
139146

140147
s := NewServer("testServer", "v1.0.0", nil)
141148

142149
// The 'greet' tool says hi.
143-
s.AddTools(MakeTool("greet", "say hi", sayHi))
150+
s.AddTools(tools...)
144151
cc, err := s.Connect(ctx, st, nil)
145152
if err != nil {
146153
t.Fatal(err)
@@ -151,7 +158,14 @@ func TestServerClosing(t *testing.T) {
151158
if err != nil {
152159
t.Fatal(err)
153160
}
161+
return cc, sc
162+
}
154163

164+
func TestServerClosing(t *testing.T) {
165+
cc, sc := basicConnection(t, MakeTool("greet", "say hi", sayHi))
166+
defer sc.Close()
167+
168+
ctx := context.Background()
155169
var wg sync.WaitGroup
156170
wg.Add(1)
157171
go func() {
@@ -209,3 +223,33 @@ func TestBatching(t *testing.T) {
209223
}
210224

211225
}
226+
227+
func TestCancellation(t *testing.T) {
228+
var (
229+
start = make(chan struct{})
230+
cancelled = make(chan struct{}, 1) // don't block the request
231+
)
232+
233+
slowRequest := func(ctx context.Context, cc *ClientConnection, v struct{}) ([]Content, error) {
234+
start <- struct{}{}
235+
select {
236+
case <-ctx.Done():
237+
cancelled <- struct{}{}
238+
case <-time.After(5 * time.Second):
239+
return nil, nil
240+
}
241+
return nil, nil
242+
}
243+
_, sc := basicConnection(t, MakeTool("slow", "a slow request", slowRequest))
244+
defer sc.Close()
245+
246+
ctx, cancel := context.WithCancel(context.Background())
247+
go sc.CallTool(ctx, "slow", struct{}{})
248+
<-start
249+
cancel()
250+
select {
251+
case <-cancelled:
252+
case <-time.After(5 * time.Second):
253+
t.Fatal("timeout waiting for cancellation")
254+
}
255+
}

internal/mcp/server.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ func (cc *ClientConnection) Wait() error {
240240
return cc.conn.Wait()
241241
}
242242

243-
// dispatch turns a strongly type handler into a jsonrpc2 handler.
243+
// dispatch turns a strongly type request handler into a jsonrpc2 handler.
244244
//
245245
// Importantly, it returns nil if the handler returned an error, which is a
246246
// requirement of the jsonrpc2 package.

internal/mcp/transport.go

+44-6
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ import (
1515
"sync"
1616

1717
jsonrpc2 "golang.org/x/tools/internal/jsonrpc2_v2"
18+
"golang.org/x/tools/internal/mcp/internal/protocol"
19+
"golang.org/x/tools/internal/xcontext"
1820
)
1921

2022
// A JSONRPC2 error is an error defined by the JSONRPC2 spec.
@@ -103,31 +105,67 @@ func connect[H handler](ctx context.Context, t Transport, opts *ConnectionOption
103105
writer = loggingWriter(opts.Logger, writer)
104106
}
105107

106-
var h H
108+
var (
109+
h H
110+
preempter canceller
111+
)
107112
bind := func(conn *jsonrpc2.Connection) jsonrpc2.Handler {
108113
h = b.bind(conn)
114+
preempter.conn = conn
109115
return jsonrpc2.HandlerFunc(h.handle)
110116
}
111117
_ = jsonrpc2.NewConnection(ctx, jsonrpc2.ConnectionConfig{
112-
Reader: reader,
113-
Writer: writer,
114-
Closer: stream,
115-
Bind: bind,
118+
Reader: reader,
119+
Writer: writer,
120+
Closer: stream,
121+
Bind: bind,
122+
Preempter: &preempter,
116123
OnDone: func() {
117124
b.disconnect(h)
118125
},
119126
})
127+
assert(preempter.conn != nil, "unbound preempter")
120128
assert(h != zero, "unbound connection")
121129
return h, nil
122130
}
123131

132+
// A canceller is a jsonrpc2.Preempter that cancels in-flight requests on MCP
133+
// cancelled notifications.
134+
type canceller struct {
135+
conn *jsonrpc2.Connection
136+
}
137+
138+
// Preempt implements jsonrpc2.Preempter.
139+
func (c *canceller) Preempt(ctx context.Context, req *jsonrpc2.Request) (result any, err error) {
140+
if req.Method == "notifications/cancelled" {
141+
var params protocol.CancelledParams
142+
if err := json.Unmarshal(req.Params, &params); err != nil {
143+
return nil, err
144+
}
145+
id, err := jsonrpc2.MakeID(params.RequestId)
146+
if err != nil {
147+
return nil, err
148+
}
149+
go c.conn.Cancel(id)
150+
}
151+
return nil, jsonrpc2.ErrNotHandled
152+
}
153+
124154
// call executes and awaits a jsonrpc2 call on the given connection,
125155
// translating errors into the mcp domain.
126156
func call(ctx context.Context, conn *jsonrpc2.Connection, method string, params, result any) error {
127-
err := conn.Call(ctx, method, params).Await(ctx, result)
157+
call := conn.Call(ctx, method, params)
158+
err := call.Await(ctx, result)
128159
switch {
129160
case errors.Is(err, jsonrpc2.ErrClientClosing), errors.Is(err, jsonrpc2.ErrServerClosing):
130161
return fmt.Errorf("calling %q: %w", method, ErrConnectionClosed)
162+
case ctx.Err() != nil:
163+
// Notify the peer of cancellation.
164+
err := conn.Notify(xcontext.Detach(ctx), "notifications/cancelled", &protocol.CancelledParams{
165+
Reason: ctx.Err().Error(),
166+
RequestId: call.ID().Raw(),
167+
})
168+
return errors.Join(ctx.Err(), err)
131169
case err != nil:
132170
return fmt.Errorf("calling %q: %v", method, err)
133171
}

0 commit comments

Comments
 (0)