Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix and enhance cancellation operations across MCP Sessions. #179

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
3 changes: 2 additions & 1 deletion src/ModelContextProtocol/Client/IMcpClient.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using ModelContextProtocol.Protocol.Types;
using ModelContextProtocol.Protocol.Types;
using ModelContextProtocol.Shared;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Revert?


namespace ModelContextProtocol.Client;

Expand Down
12 changes: 6 additions & 6 deletions src/ModelContextProtocol/McpEndpointExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,14 @@ public static Task NotifyProgressAsync(
{
Throw.IfNull(endpoint);

return endpoint.SendMessageAsync(new JsonRpcNotification()
{
Method = NotificationMethods.ProgressNotification,
Params = JsonSerializer.SerializeToNode(new ProgressNotification
return endpoint.SendNotificationAsync(
method: NotificationMethods.ProgressNotification,
parameters: new ProgressNotification
{
ProgressToken = progressToken,
Progress = progress,
}, McpJsonUtilities.JsonContext.Default.ProgressNotification),
}, cancellationToken);
},
McpJsonUtilities.JsonContext.Default.ProgressNotification,
cancellationToken: cancellationToken);
}
}
13 changes: 13 additions & 0 deletions src/ModelContextProtocol/Shared/McpSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,18 @@ await _transport.SendMessageAsync(new JsonRpcResponse
}, cancellationToken).ConfigureAwait(false);
}

private void RegisterCancellation(CancellationToken cancellationToken, RequestId requestId)
{
cancellationToken.Register(async () => await SendMessageAsync(new JsonRpcNotification
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's important to Dispose of the CancellationTokenRegistration. This RegisterCancellation method should return that CancellationTokenRegistration, and the call site should use it in a using around the rest of the operation.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Register takes an Action. The way this is written, it's creating an async void method. It should instead do something like () => _ = SendMessageAsync(...).

{
Method = NotificationMethods.CancelledNotification,
Params = JsonSerializer.SerializeToNode(new CancelledNotification
{
RequestId = requestId,
}, McpJsonUtilities.JsonContext.Default.CancelledNotification)
}));
}

public IAsyncDisposable RegisterNotificationHandler(string method, Func<JsonRpcNotification, CancellationToken, Task> handler)
{
Throw.IfNullOrWhiteSpace(method);
Expand All @@ -314,6 +326,7 @@ public IAsyncDisposable RegisterNotificationHandler(string method, Func<JsonRpcN
/// <returns>A task containing the server's response.</returns>
public async Task<JsonRpcResponse> SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken)
{
RegisterCancellation(cancellationToken, request.Id);
if (!_transport.IsConnected)
{
_logger.EndpointNotConnected(EndpointName);
Expand Down
135 changes: 132 additions & 3 deletions tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using ModelContextProtocol.Server;
using ModelContextProtocol.Tests.Utils;
using Moq;
using System.Buffers;
using System.IO.Pipelines;
using System.Text.Json;
using System.Text.Json.Serialization.Metadata;
Expand Down Expand Up @@ -228,21 +229,24 @@ public async ValueTask DisposeAsync()
_cts.Dispose();
}

private async Task<IMcpClient> CreateMcpClientForServer()
private async Task<IMcpClient> CreateMcpClientForServer(
McpClientOptions? options = null,
CancellationToken? cancellationToken = default)
Copy link
Contributor

@stephentoub stephentoub Apr 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's very unusual (bordering on never) for CancellationToken parameters to be nullable. Why does it need to be here?

{
return await McpClientFactory.CreateAsync(
new McpServerConfig()
new()
{
Id = "TestServer",
Name = "TestServer",
TransportType = "ignored",
},
clientOptions: options,
createTransportFunc: (_, _) => new StreamClientTransport(
serverInput: _clientToServerPipe.Writer.AsStream(),
serverOutput: _serverToClientPipe.Reader.AsStream(),
LoggerFactory),
loggerFactory: LoggerFactory,
cancellationToken: TestContext.Current.CancellationToken);
cancellationToken: cancellationToken ?? TestContext.Current.CancellationToken);
}

[Fact]
Expand Down Expand Up @@ -374,4 +378,129 @@ public async Task WithDescription_ChangesToolDescription()
Assert.Equal("ToolWithNewDescription", redescribedTool.Description);
Assert.Equal(originalDescription, tool?.Description);
}

[Fact]
public async Task Can_Handle_Notify_Cancel()
{
// Arrange
var token = TestContext.Current.CancellationToken;
TaskCompletionSource<JsonRpcNotification> clientReceived = new();
await using var client = await CreateMcpClientForServer(
options: CreateClientOptions([new(NotificationMethods.CancelledNotification, (notification, cancellationToken) =>
{
clientReceived.TrySetResult(notification);
return clientReceived.Task;
})]),
cancellationToken: token);
CancelledNotification rpcNotification = new()
{
RequestId = new("abc"),
Reason = "Cancelled",
};

// Act
await NotifyClientAsync(
message: NotificationMethods.CancelledNotification,
parameters: rpcNotification,
token: token);
var notification = await clientReceived.Task
.WaitAsync(TimeSpan.FromSeconds(5), token);

// Assert
Assert.NotNull(notification.Params);
// Parse the Params string back to a CancelledNotification
var cancelled = JsonSerializer.Deserialize<CancelledNotification>(notification.Params.ToString());
Assert.NotNull(cancelled);
Assert.Equal(rpcNotification.RequestId.ToString(), cancelled.RequestId.ToString());
Assert.Equal(rpcNotification.Reason, cancelled.Reason);
}

[Fact]
public async Task Should_Not_Intercept_Sent_Notifications()
{
// Arrange
var token = TestContext.Current.CancellationToken;
TaskCompletionSource<JsonRpcNotification> clientReceived = new();
await using var client = await CreateMcpClientForServer(
options: CreateClientOptions([new(NotificationMethods.CancelledNotification, (notification, cancellationToken) =>
{
var exception = new InvalidOperationException("Should not intercept sent notifications");
clientReceived.TrySetException(exception);
return clientReceived.Task;
})]),
cancellationToken: token);

// Act
await client.SendNotificationAsync(
method: NotificationMethods.CancelledNotification,
parameters: new CancelledNotification
{
RequestId = new("abc"),
Reason = "Cancelled",
}, cancellationToken: token);
await Assert.ThrowsAsync<TimeoutException>(
async () => await clientReceived.Task
.WaitAsync(TimeSpan.FromSeconds(5), token));
// Assert
Assert.False(clientReceived.Task.IsCompleted);
}

[Fact]
public async Task Can_Notify_Cancel()
{
// Arrange
var token = TestContext.Current.CancellationToken;
TaskCompletionSource clientReceived = new();
await using var client = await CreateMcpClientForServer(
options: CreateClientOptions(new Dictionary<string, Func<JsonRpcNotification, CancellationToken, Task>>()
{
[NotificationMethods.CancelledNotification] = (notification, cancellationToken) =>
{
InvalidOperationException exception = new("Should not intercept sent notifications");
clientReceived.TrySetException(exception);
return clientReceived.Task;
}
}), cancellationToken: token);
RequestId expectedRequestId = new("abc");
var expectedReason = "Cancelled";

// Act
await client.SendNotificationAsync(
method: NotificationMethods.CancelledNotification,
parameters: new CancelledNotification
{
RequestId = expectedRequestId,
Reason = expectedReason,
}, cancellationToken: token);

// Assert
await Assert.ThrowsAsync<TimeoutException>(
async () => await clientReceived.Task
.WaitAsync(TimeSpan.FromSeconds(3), token));
}

private static McpClientOptions CreateClientOptions(
IEnumerable<KeyValuePair<string, Func<JsonRpcNotification, CancellationToken, Task>>>? notificationHandlers = null)
=> new()
{
Capabilities = new()
{
NotificationHandlers = notificationHandlers ?? [],
},
};

private async Task NotifyClientAsync(
string message, object? parameters = null, CancellationToken token = default)
=> await NotifyPipeAsync(_serverToClientPipe, message, parameters, token);
private async static Task NotifyPipeAsync(
Pipe pipe, string message, object? parameters = null, CancellationToken token = default)
{
var bytes = JsonSerializer.SerializeToUtf8Bytes(new JsonRpcNotification
{
Method = message,
Params = parameters is not null ? JsonSerializer.Serialize(parameters) : null,
});
await pipe.Writer.WriteAsync(bytes, token);
await pipe.Writer.CompleteAsync(); // Signal the end of the message
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CompleteAsync isn't about the end of a message, it's about saying nothing more will ever be written to the pipe. Is that intended?

}
}
78 changes: 78 additions & 0 deletions tests/ModelContextProtocol.Tests/Server/McpServerTests.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Microsoft.Extensions.AI;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using ModelContextProtocol.Protocol.Messages;
using ModelContextProtocol.Protocol.Types;
using ModelContextProtocol.Server;
Expand Down Expand Up @@ -687,4 +688,81 @@ await transport.SendMessageAsync(new JsonRpcNotification
await server.DisposeAsync();
await serverTask;
}

[Fact]
public async Task NotifyCancel_Should_Be_Handled()
{
// Arrange
TaskCompletionSource<JsonRpcNotification> notificationReceived = new();
await using TestServerTransport transport = new(LoggerFactory);
transport.OnMessageSent = (message) =>
{
if (message is JsonRpcNotification notification
&& notification.Method == NotificationMethods.CancelledNotification)
notificationReceived.TrySetResult(notification);
};
var options = CreateOptions();
await using var server = McpServerFactory.Create(transport, options, LoggerFactory);

// Act
var token = TestContext.Current.CancellationToken;
Task serverTask = server.RunAsync(token);
await server.SendNotificationAsync(
NotificationMethods.CancelledNotification,
new CancelledNotification
{
RequestId = new("abc"),
Reason = "Cancelled",
}, cancellationToken: token);
await server.DisposeAsync();
await serverTask.WaitAsync(TimeSpan.FromSeconds(1), token);
var notification = await notificationReceived.Task.WaitAsync(TimeSpan.FromSeconds(1), token);

// Assert
var cancelled = JsonSerializer.Deserialize<CancelledNotification>(notification.Params);
Assert.NotNull(cancelled);
Assert.Equal("abc", cancelled.RequestId.ToString());
Assert.Equal("Cancelled", cancelled.Reason);
}

[Fact]
public async Task SendRequest_Should_Notify_When_Cancelled()
{
// Arrange
TaskCompletionSource<JsonRpcNotification> notificationReceived = new();
await using TestServerTransport transport = new(LoggerFactory);
transport.OnMessageSent = (message) =>
{
if (message is JsonRpcNotification notification
&& notification.Method == NotificationMethods.CancelledNotification)
notificationReceived.TrySetResult(notification);
};
var options = CreateOptions();
await using var server = McpServerFactory.Create(transport, options, LoggerFactory);

// Act
var token = TestContext.Current.CancellationToken;
Task serverTask = server.RunAsync(token);
using CancellationTokenSource cts = new();
await cts.CancelAsync();

await Assert.ThrowsAsync<TaskCanceledException>(async () =>
{
await server.SendRequestAsync(new JsonRpcRequest
{
Method = RequestMethods.Ping,
Id = new("abc"),
}, cts.Token);
});

await server.DisposeAsync();
await serverTask.WaitAsync(TimeSpan.FromSeconds(1), token);
var notification = await notificationReceived.Task.WaitAsync(TimeSpan.FromSeconds(1), token);

// Assert
var cancelled = JsonSerializer.Deserialize<CancelledNotification>(notification.Params);
Assert.NotNull(cancelled);
Assert.Equal("abc", cancelled.RequestId.ToString());
Assert.Null(cancelled.Reason);
}
}
Loading