@@ -15,6 +15,8 @@ import (
15
15
"sync"
16
16
17
17
jsonrpc2 "golang.org/x/tools/internal/jsonrpc2_v2"
18
+ "golang.org/x/tools/internal/mcp/internal/protocol"
19
+ "golang.org/x/tools/internal/xcontext"
18
20
)
19
21
20
22
// A JSONRPC2 error is an error defined by the JSONRPC2 spec.
@@ -103,31 +105,67 @@ func connect[H handler](ctx context.Context, t Transport, opts *ConnectionOption
103
105
writer = loggingWriter (opts .Logger , writer )
104
106
}
105
107
106
- var h H
108
+ var (
109
+ h H
110
+ preempter canceller
111
+ )
107
112
bind := func (conn * jsonrpc2.Connection ) jsonrpc2.Handler {
108
113
h = b .bind (conn )
114
+ preempter .conn = conn
109
115
return jsonrpc2 .HandlerFunc (h .handle )
110
116
}
111
117
_ = jsonrpc2 .NewConnection (ctx , jsonrpc2.ConnectionConfig {
112
- Reader : reader ,
113
- Writer : writer ,
114
- Closer : stream ,
115
- Bind : bind ,
118
+ Reader : reader ,
119
+ Writer : writer ,
120
+ Closer : stream ,
121
+ Bind : bind ,
122
+ Preempter : & preempter ,
116
123
OnDone : func () {
117
124
b .disconnect (h )
118
125
},
119
126
})
127
+ assert (preempter .conn != nil , "unbound preempter" )
120
128
assert (h != zero , "unbound connection" )
121
129
return h , nil
122
130
}
123
131
132
+ // A canceller is a jsonrpc2.Preempter that cancels in-flight requests on MCP
133
+ // cancelled notifications.
134
+ type canceller struct {
135
+ conn * jsonrpc2.Connection
136
+ }
137
+
138
+ // Preempt implements jsonrpc2.Preempter.
139
+ func (c * canceller ) Preempt (ctx context.Context , req * jsonrpc2.Request ) (result any , err error ) {
140
+ if req .Method == "notifications/cancelled" {
141
+ var params protocol.CancelledParams
142
+ if err := json .Unmarshal (req .Params , & params ); err != nil {
143
+ return nil , err
144
+ }
145
+ id , err := jsonrpc2 .MakeID (params .RequestId )
146
+ if err != nil {
147
+ return nil , err
148
+ }
149
+ go c .conn .Cancel (id )
150
+ }
151
+ return nil , jsonrpc2 .ErrNotHandled
152
+ }
153
+
124
154
// call executes and awaits a jsonrpc2 call on the given connection,
125
155
// translating errors into the mcp domain.
126
156
func call (ctx context.Context , conn * jsonrpc2.Connection , method string , params , result any ) error {
127
- err := conn .Call (ctx , method , params ).Await (ctx , result )
157
+ call := conn .Call (ctx , method , params )
158
+ err := call .Await (ctx , result )
128
159
switch {
129
160
case errors .Is (err , jsonrpc2 .ErrClientClosing ), errors .Is (err , jsonrpc2 .ErrServerClosing ):
130
161
return fmt .Errorf ("calling %q: %w" , method , ErrConnectionClosed )
162
+ case ctx .Err () != nil :
163
+ // Notify the peer of cancellation.
164
+ err := conn .Notify (xcontext .Detach (ctx ), "notifications/cancelled" , & protocol.CancelledParams {
165
+ Reason : ctx .Err ().Error (),
166
+ RequestId : call .ID ().Raw (),
167
+ })
168
+ return errors .Join (ctx .Err (), err )
131
169
case err != nil :
132
170
return fmt .Errorf ("calling %q: %v" , method , err )
133
171
}
0 commit comments