-
Notifications
You must be signed in to change notification settings - Fork 158
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix and enhance cancellation operations across MCP Sessions. #179
base: main
Are you sure you want to change the base?
Changes from all commits
90536c2
eae1730
4f3a74e
46cff23
be733e2
3e3f95d
d5ffe1f
a5cd149
4e5f954
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -296,6 +296,18 @@ await _transport.SendMessageAsync(new JsonRpcResponse | |
}, cancellationToken).ConfigureAwait(false); | ||
} | ||
|
||
private void RegisterCancellation(CancellationToken cancellationToken, RequestId requestId) | ||
{ | ||
cancellationToken.Register(async () => await SendMessageAsync(new JsonRpcNotification | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's important to Dispose of the CancellationTokenRegistration. This RegisterCancellation method should return that CancellationTokenRegistration, and the call site should use it in a using around the rest of the operation. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Register takes an Action. The way this is written, it's creating an async void method. It should instead do something like |
||
{ | ||
Method = NotificationMethods.CancelledNotification, | ||
Params = JsonSerializer.SerializeToNode(new CancelledNotification | ||
{ | ||
RequestId = requestId, | ||
}, McpJsonUtilities.JsonContext.Default.CancelledNotification) | ||
})); | ||
} | ||
|
||
public IAsyncDisposable RegisterNotificationHandler(string method, Func<JsonRpcNotification, CancellationToken, Task> handler) | ||
{ | ||
Throw.IfNullOrWhiteSpace(method); | ||
|
@@ -314,6 +326,7 @@ public IAsyncDisposable RegisterNotificationHandler(string method, Func<JsonRpcN | |
/// <returns>A task containing the server's response.</returns> | ||
public async Task<JsonRpcResponse> SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken) | ||
{ | ||
RegisterCancellation(cancellationToken, request.Id); | ||
if (!_transport.IsConnected) | ||
{ | ||
_logger.EndpointNotConnected(EndpointName); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,6 +7,7 @@ | |
using ModelContextProtocol.Server; | ||
using ModelContextProtocol.Tests.Utils; | ||
using Moq; | ||
using System.Buffers; | ||
using System.IO.Pipelines; | ||
using System.Text.Json; | ||
using System.Text.Json.Serialization.Metadata; | ||
|
@@ -228,21 +229,24 @@ public async ValueTask DisposeAsync() | |
_cts.Dispose(); | ||
} | ||
|
||
private async Task<IMcpClient> CreateMcpClientForServer() | ||
private async Task<IMcpClient> CreateMcpClientForServer( | ||
McpClientOptions? options = null, | ||
CancellationToken? cancellationToken = default) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's very unusual (bordering on never) for CancellationToken parameters to be nullable. Why does it need to be here? |
||
{ | ||
return await McpClientFactory.CreateAsync( | ||
new McpServerConfig() | ||
new() | ||
{ | ||
Id = "TestServer", | ||
Name = "TestServer", | ||
TransportType = "ignored", | ||
}, | ||
clientOptions: options, | ||
createTransportFunc: (_, _) => new StreamClientTransport( | ||
serverInput: _clientToServerPipe.Writer.AsStream(), | ||
serverOutput: _serverToClientPipe.Reader.AsStream(), | ||
LoggerFactory), | ||
loggerFactory: LoggerFactory, | ||
cancellationToken: TestContext.Current.CancellationToken); | ||
cancellationToken: cancellationToken ?? TestContext.Current.CancellationToken); | ||
} | ||
|
||
[Fact] | ||
|
@@ -374,4 +378,129 @@ public async Task WithDescription_ChangesToolDescription() | |
Assert.Equal("ToolWithNewDescription", redescribedTool.Description); | ||
Assert.Equal(originalDescription, tool?.Description); | ||
} | ||
|
||
[Fact] | ||
public async Task Can_Handle_Notify_Cancel() | ||
{ | ||
// Arrange | ||
var token = TestContext.Current.CancellationToken; | ||
TaskCompletionSource<JsonRpcNotification> clientReceived = new(); | ||
await using var client = await CreateMcpClientForServer( | ||
options: CreateClientOptions([new(NotificationMethods.CancelledNotification, (notification, cancellationToken) => | ||
{ | ||
clientReceived.TrySetResult(notification); | ||
return clientReceived.Task; | ||
})]), | ||
cancellationToken: token); | ||
CancelledNotification rpcNotification = new() | ||
{ | ||
RequestId = new("abc"), | ||
Reason = "Cancelled", | ||
}; | ||
|
||
// Act | ||
await NotifyClientAsync( | ||
message: NotificationMethods.CancelledNotification, | ||
parameters: rpcNotification, | ||
token: token); | ||
var notification = await clientReceived.Task | ||
.WaitAsync(TimeSpan.FromSeconds(5), token); | ||
|
||
// Assert | ||
Assert.NotNull(notification.Params); | ||
// Parse the Params string back to a CancelledNotification | ||
var cancelled = JsonSerializer.Deserialize<CancelledNotification>(notification.Params.ToString()); | ||
Assert.NotNull(cancelled); | ||
Assert.Equal(rpcNotification.RequestId.ToString(), cancelled.RequestId.ToString()); | ||
Assert.Equal(rpcNotification.Reason, cancelled.Reason); | ||
} | ||
|
||
[Fact] | ||
public async Task Should_Not_Intercept_Sent_Notifications() | ||
{ | ||
// Arrange | ||
var token = TestContext.Current.CancellationToken; | ||
TaskCompletionSource<JsonRpcNotification> clientReceived = new(); | ||
await using var client = await CreateMcpClientForServer( | ||
options: CreateClientOptions([new(NotificationMethods.CancelledNotification, (notification, cancellationToken) => | ||
{ | ||
var exception = new InvalidOperationException("Should not intercept sent notifications"); | ||
clientReceived.TrySetException(exception); | ||
return clientReceived.Task; | ||
})]), | ||
cancellationToken: token); | ||
|
||
// Act | ||
await client.SendNotificationAsync( | ||
method: NotificationMethods.CancelledNotification, | ||
parameters: new CancelledNotification | ||
{ | ||
RequestId = new("abc"), | ||
Reason = "Cancelled", | ||
}, cancellationToken: token); | ||
await Assert.ThrowsAsync<TimeoutException>( | ||
async () => await clientReceived.Task | ||
.WaitAsync(TimeSpan.FromSeconds(5), token)); | ||
// Assert | ||
Assert.False(clientReceived.Task.IsCompleted); | ||
} | ||
|
||
[Fact] | ||
public async Task Can_Notify_Cancel() | ||
{ | ||
// Arrange | ||
var token = TestContext.Current.CancellationToken; | ||
TaskCompletionSource clientReceived = new(); | ||
await using var client = await CreateMcpClientForServer( | ||
options: CreateClientOptions(new Dictionary<string, Func<JsonRpcNotification, CancellationToken, Task>>() | ||
{ | ||
[NotificationMethods.CancelledNotification] = (notification, cancellationToken) => | ||
{ | ||
InvalidOperationException exception = new("Should not intercept sent notifications"); | ||
clientReceived.TrySetException(exception); | ||
return clientReceived.Task; | ||
} | ||
}), cancellationToken: token); | ||
RequestId expectedRequestId = new("abc"); | ||
var expectedReason = "Cancelled"; | ||
|
||
// Act | ||
await client.SendNotificationAsync( | ||
method: NotificationMethods.CancelledNotification, | ||
parameters: new CancelledNotification | ||
{ | ||
RequestId = expectedRequestId, | ||
Reason = expectedReason, | ||
}, cancellationToken: token); | ||
|
||
// Assert | ||
await Assert.ThrowsAsync<TimeoutException>( | ||
async () => await clientReceived.Task | ||
.WaitAsync(TimeSpan.FromSeconds(3), token)); | ||
} | ||
|
||
private static McpClientOptions CreateClientOptions( | ||
IEnumerable<KeyValuePair<string, Func<JsonRpcNotification, CancellationToken, Task>>>? notificationHandlers = null) | ||
=> new() | ||
{ | ||
Capabilities = new() | ||
{ | ||
NotificationHandlers = notificationHandlers ?? [], | ||
}, | ||
}; | ||
|
||
private async Task NotifyClientAsync( | ||
string message, object? parameters = null, CancellationToken token = default) | ||
=> await NotifyPipeAsync(_serverToClientPipe, message, parameters, token); | ||
private async static Task NotifyPipeAsync( | ||
Pipe pipe, string message, object? parameters = null, CancellationToken token = default) | ||
{ | ||
var bytes = JsonSerializer.SerializeToUtf8Bytes(new JsonRpcNotification | ||
{ | ||
Method = message, | ||
Params = parameters is not null ? JsonSerializer.Serialize(parameters) : null, | ||
}); | ||
await pipe.Writer.WriteAsync(bytes, token); | ||
await pipe.Writer.CompleteAsync(); // Signal the end of the message | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. CompleteAsync isn't about the end of a message, it's about saying nothing more will ever be written to the pipe. Is that intended? |
||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Revert?