diff --git a/samples/EverythingServer/LoggingUpdateMessageSender.cs b/samples/EverythingServer/LoggingUpdateMessageSender.cs index 7b64aa2c..efd71149 100644 --- a/samples/EverythingServer/LoggingUpdateMessageSender.cs +++ b/samples/EverythingServer/LoggingUpdateMessageSender.cs @@ -1,5 +1,4 @@ -using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Hosting; using ModelContextProtocol; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server; diff --git a/samples/EverythingServer/ResourceGenerator.cs b/samples/EverythingServer/ResourceGenerator.cs index 54764b8c..3ae4c17f 100644 --- a/samples/EverythingServer/ResourceGenerator.cs +++ b/samples/EverythingServer/ResourceGenerator.cs @@ -1,7 +1,4 @@ using ModelContextProtocol.Protocol.Types; -using System; -using System.Collections.Generic; -using System.Linq; namespace EverythingServer; diff --git a/samples/EverythingServer/Tools/LongRunningTool.cs b/samples/EverythingServer/Tools/LongRunningTool.cs index 86acc84d..378cef6c 100644 --- a/samples/EverythingServer/Tools/LongRunningTool.cs +++ b/samples/EverythingServer/Tools/LongRunningTool.cs @@ -1,5 +1,4 @@ using ModelContextProtocol; -using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server; using System.ComponentModel; diff --git a/src/ModelContextProtocol/Client/McpClientFactory.cs b/src/ModelContextProtocol/Client/McpClientFactory.cs index ca9fcc4d..4eef40bb 100644 --- a/src/ModelContextProtocol/Client/McpClientFactory.cs +++ b/src/ModelContextProtocol/Client/McpClientFactory.cs @@ -1,6 +1,4 @@ -using System.Globalization; -using System.Runtime.InteropServices; -using ModelContextProtocol.Logging; +using ModelContextProtocol.Logging; using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Utils; using Microsoft.Extensions.Logging; diff --git a/src/ModelContextProtocol/McpEndpointExtensions.cs b/src/ModelContextProtocol/McpEndpointExtensions.cs index 3a7a7721..11c697d4 100644 --- a/src/ModelContextProtocol/McpEndpointExtensions.cs +++ b/src/ModelContextProtocol/McpEndpointExtensions.cs @@ -155,14 +155,14 @@ public static Task NotifyProgressAsync( { Throw.IfNull(endpoint); - return endpoint.SendMessageAsync(new JsonRpcNotification() - { - Method = NotificationMethods.ProgressNotification, - Params = JsonSerializer.SerializeToNode(new ProgressNotification + return endpoint.SendNotificationAsync( + NotificationMethods.ProgressNotification, + new ProgressNotification { ProgressToken = progressToken, Progress = progress, - }, McpJsonUtilities.JsonContext.Default.ProgressNotification), - }, cancellationToken); + }, + McpJsonUtilities.JsonContext.Default.ProgressNotification, + cancellationToken); } } diff --git a/src/ModelContextProtocol/Protocol/Types/ClientCapabilities.cs b/src/ModelContextProtocol/Protocol/Types/ClientCapabilities.cs index cc700e14..0edab54a 100644 --- a/src/ModelContextProtocol/Protocol/Types/ClientCapabilities.cs +++ b/src/ModelContextProtocol/Protocol/Types/ClientCapabilities.cs @@ -1,5 +1,4 @@ using ModelContextProtocol.Protocol.Messages; -using ModelContextProtocol.Server; using System.Text.Json.Serialization; namespace ModelContextProtocol.Protocol.Types; diff --git a/src/ModelContextProtocol/Protocol/Types/LoggingCapability.cs b/src/ModelContextProtocol/Protocol/Types/LoggingCapability.cs index 425abca7..7b2531ab 100644 --- a/src/ModelContextProtocol/Protocol/Types/LoggingCapability.cs +++ b/src/ModelContextProtocol/Protocol/Types/LoggingCapability.cs @@ -1,5 +1,4 @@ -using ModelContextProtocol.Protocol.Messages; -using ModelContextProtocol.Server; +using ModelContextProtocol.Server; using System.Text.Json.Serialization; namespace ModelContextProtocol.Protocol.Types; diff --git a/src/ModelContextProtocol/Protocol/Types/PromptsCapability.cs b/src/ModelContextProtocol/Protocol/Types/PromptsCapability.cs index 0c5ff3f8..8173f002 100644 --- a/src/ModelContextProtocol/Protocol/Types/PromptsCapability.cs +++ b/src/ModelContextProtocol/Protocol/Types/PromptsCapability.cs @@ -1,5 +1,4 @@ -using ModelContextProtocol.Protocol.Messages; -using ModelContextProtocol.Server; +using ModelContextProtocol.Server; using System.Text.Json.Serialization; namespace ModelContextProtocol.Protocol.Types; diff --git a/src/ModelContextProtocol/Protocol/Types/ResourcesCapability.cs b/src/ModelContextProtocol/Protocol/Types/ResourcesCapability.cs index 94e3ec57..98afc35e 100644 --- a/src/ModelContextProtocol/Protocol/Types/ResourcesCapability.cs +++ b/src/ModelContextProtocol/Protocol/Types/ResourcesCapability.cs @@ -1,5 +1,4 @@ -using ModelContextProtocol.Protocol.Messages; -using ModelContextProtocol.Server; +using ModelContextProtocol.Server; using System.Text.Json.Serialization; namespace ModelContextProtocol.Protocol.Types; diff --git a/src/ModelContextProtocol/Protocol/Types/RootsCapability.cs b/src/ModelContextProtocol/Protocol/Types/RootsCapability.cs index df9716ca..6c059549 100644 --- a/src/ModelContextProtocol/Protocol/Types/RootsCapability.cs +++ b/src/ModelContextProtocol/Protocol/Types/RootsCapability.cs @@ -1,6 +1,4 @@ -using ModelContextProtocol.Protocol.Messages; -using ModelContextProtocol.Server; -using System.Text.Json.Serialization; +using System.Text.Json.Serialization; namespace ModelContextProtocol.Protocol.Types; diff --git a/src/ModelContextProtocol/Protocol/Types/SamplingCapability.cs b/src/ModelContextProtocol/Protocol/Types/SamplingCapability.cs index 5d9ff94f..82111377 100644 --- a/src/ModelContextProtocol/Protocol/Types/SamplingCapability.cs +++ b/src/ModelContextProtocol/Protocol/Types/SamplingCapability.cs @@ -1,6 +1,4 @@ -using ModelContextProtocol.Protocol.Messages; -using ModelContextProtocol.Server; -using System.Text.Json.Serialization; +using System.Text.Json.Serialization; namespace ModelContextProtocol.Protocol.Types; diff --git a/src/ModelContextProtocol/Protocol/Types/ToolsCapability.cs b/src/ModelContextProtocol/Protocol/Types/ToolsCapability.cs index 002ade0d..dd85fec7 100644 --- a/src/ModelContextProtocol/Protocol/Types/ToolsCapability.cs +++ b/src/ModelContextProtocol/Protocol/Types/ToolsCapability.cs @@ -1,5 +1,4 @@ -using ModelContextProtocol.Protocol.Messages; -using ModelContextProtocol.Server; +using ModelContextProtocol.Server; using System.Text.Json.Serialization; namespace ModelContextProtocol.Protocol.Types; diff --git a/src/ModelContextProtocol/Server/McpServerOptions.cs b/src/ModelContextProtocol/Server/McpServerOptions.cs index fd96ca1f..87a01235 100644 --- a/src/ModelContextProtocol/Server/McpServerOptions.cs +++ b/src/ModelContextProtocol/Server/McpServerOptions.cs @@ -1,6 +1,5 @@  using ModelContextProtocol.Protocol.Types; -using System.Text.Json.Serialization; namespace ModelContextProtocol.Server; diff --git a/src/ModelContextProtocol/Shared/McpSession.cs b/src/ModelContextProtocol/Shared/McpSession.cs index 0e94c056..f1a8b7a0 100644 --- a/src/ModelContextProtocol/Shared/McpSession.cs +++ b/src/ModelContextProtocol/Shared/McpSession.cs @@ -296,6 +296,24 @@ await _transport.SendMessageAsync(new JsonRpcResponse }, cancellationToken).ConfigureAwait(false); } + private CancellationTokenRegistration RegisterCancellation(CancellationToken cancellationToken, RequestId requestId) + { + if (!cancellationToken.CanBeCanceled) + { + return default; + } + + return cancellationToken.Register(static objState => + { + var state = (Tuple)objState!; + _ = state.Item1.SendMessageAsync(new JsonRpcNotification + { + Method = NotificationMethods.CancelledNotification, + Params = JsonSerializer.SerializeToNode(new CancelledNotification { RequestId = state.Item2 }, McpJsonUtilities.JsonContext.Default.CancelledNotification) + }); + }, Tuple.Create(this, requestId)); + } + public IAsyncDisposable RegisterNotificationHandler(string method, Func handler) { Throw.IfNullOrWhiteSpace(method); @@ -320,6 +338,8 @@ public async Task SendRequestAsync(JsonRpcRequest request, Canc throw new McpException("Transport is not connected"); } + cancellationToken.ThrowIfCancellationRequested(); + Histogram durationMetric = _isServer ? s_serverRequestDuration : s_clientRequestDuration; string method = request.Method; @@ -357,9 +377,16 @@ public async Task SendRequestAsync(JsonRpcRequest request, Canc _logger.SendingRequest(EndpointName, request.Method); await _transport.SendMessageAsync(request, cancellationToken).ConfigureAwait(false); - _logger.RequestSentAwaitingResponse(EndpointName, request.Method, request.Id.ToString()); - var response = await tcs.Task.WaitAsync(cancellationToken).ConfigureAwait(false); + + // Now that the request has been sent, register for cancellation. If we registered before, + // a cancellation request could arrive before the server knew about that request ID, in which + // case the server could ignore it. + IJsonRpcMessage? response; + using (var registration = RegisterCancellation(cancellationToken, request.Id)) + { + response = await tcs.Task.WaitAsync(cancellationToken).ConfigureAwait(false); + } if (response is JsonRpcError error) { @@ -400,6 +427,8 @@ public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken ca throw new McpException("Transport is not connected"); } + cancellationToken.ThrowIfCancellationRequested(); + Histogram durationMetric = _isServer ? s_serverRequestDuration : s_clientRequestDuration; string method = GetMethodName(message); diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs index be60c90f..ec4dd04c 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs @@ -3,45 +3,32 @@ using Microsoft.Extensions.Logging; using ModelContextProtocol.Client; using ModelContextProtocol.Protocol.Messages; -using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server; -using ModelContextProtocol.Tests.Utils; using Moq; -using System.IO.Pipelines; using System.Text.Json; using System.Text.Json.Serialization.Metadata; using System.Threading.Channels; namespace ModelContextProtocol.Tests.Client; -public class McpClientExtensionsTests : LoggedTest +public class McpClientExtensionsTests : ClientServerTestBase { - private readonly Pipe _clientToServerPipe = new(); - private readonly Pipe _serverToClientPipe = new(); - private readonly ServiceProvider _serviceProvider; - private readonly CancellationTokenSource _cts; - private readonly IMcpServer _server; - private readonly Task _serverTask; - public McpClientExtensionsTests(ITestOutputHelper outputHelper) : base(outputHelper) { - ServiceCollection sc = new(); - sc.AddSingleton(LoggerFactory); - sc.AddMcpServer().WithStreamServerTransport(_clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream()); + } + + protected override void ConfigureServices(ServiceCollection services, IMcpServerBuilder mcpServerBuilder) + { for (int f = 0; f < 10; f++) { string name = $"Method{f}"; - sc.AddSingleton(McpServerTool.Create((int i) => $"{name} Result {i}", new() { Name = name })); + services.AddSingleton(McpServerTool.Create((int i) => $"{name} Result {i}", new() { Name = name })); } - sc.AddSingleton(McpServerTool.Create([McpServerTool(Destructive = false, OpenWorld = true)](string i) => $"{i} Result", new() { Name = "ValuesSetViaAttr" })); - sc.AddSingleton(McpServerTool.Create([McpServerTool(Destructive = false, OpenWorld = true)](string i) => $"{i} Result", new() { Name = "ValuesSetViaOptions", Destructive = true, OpenWorld = false, ReadOnly = true })); - _serviceProvider = sc.BuildServiceProvider(); + services.AddSingleton(McpServerTool.Create([McpServerTool(Destructive = false, OpenWorld = true)] (string i) => $"{i} Result", new() { Name = "ValuesSetViaAttr" })); + services.AddSingleton(McpServerTool.Create([McpServerTool(Destructive = false, OpenWorld = true)] (string i) => $"{i} Result", new() { Name = "ValuesSetViaOptions", Destructive = true, OpenWorld = false, ReadOnly = true })); - _server = _serviceProvider.GetRequiredService(); - _cts = CancellationTokenSource.CreateLinkedTokenSource(TestContext.Current.CancellationToken); - _serverTask = _server.RunAsync(cancellationToken: _cts.Token); } [Theory] @@ -218,30 +205,6 @@ public async Task CreateSamplingHandler_ShouldHandleResourceMessages() Assert.Equal("endTurn", result.StopReason); } - public async ValueTask DisposeAsync() - { - await _cts.CancelAsync(); - - _clientToServerPipe.Writer.Complete(); - _serverToClientPipe.Writer.Complete(); - - await _serverTask; - - await _serviceProvider.DisposeAsync(); - _cts.Dispose(); - } - - private async Task CreateMcpClientForServer() - { - return await McpClientFactory.CreateAsync( - new StreamClientTransport( - serverInput: _clientToServerPipe.Writer.AsStream(), - serverOutput: _serverToClientPipe.Reader.AsStream(), - LoggerFactory), - loggerFactory: LoggerFactory, - cancellationToken: TestContext.Current.CancellationToken); - } - [Fact] public async Task ListToolsAsync_AllToolsReturned() { @@ -377,7 +340,7 @@ public async Task AsClientLoggerProvider_MessagesSentToClient() { IMcpClient client = await CreateMcpClientForServer(); - ILoggerProvider loggerProvider = _server.AsClientLoggerProvider(); + ILoggerProvider loggerProvider = Server.AsClientLoggerProvider(); Assert.Throws("categoryName", () => loggerProvider.CreateLogger(null!)); ILogger logger = loggerProvider.CreateLogger("TestLogger"); @@ -385,7 +348,7 @@ public async Task AsClientLoggerProvider_MessagesSentToClient() Assert.Null(logger.BeginScope("")); - Assert.Null(_server.LoggingLevel); + Assert.Null(Server.LoggingLevel); Assert.False(logger.IsEnabled(LogLevel.Trace)); Assert.False(logger.IsEnabled(LogLevel.Debug)); Assert.False(logger.IsEnabled(LogLevel.Information)); @@ -396,13 +359,13 @@ public async Task AsClientLoggerProvider_MessagesSentToClient() await client.SetLoggingLevel(LoggingLevel.Info, TestContext.Current.CancellationToken); DateTime start = DateTime.UtcNow; - while (_server.LoggingLevel is null) + while (Server.LoggingLevel is null) { await Task.Delay(1, TestContext.Current.CancellationToken); Assert.True(DateTime.UtcNow - start < TimeSpan.FromSeconds(10), "Timed out waiting for logging level to be set"); } - Assert.Equal(LoggingLevel.Info, _server.LoggingLevel); + Assert.Equal(LoggingLevel.Info, Server.LoggingLevel); Assert.False(logger.IsEnabled(LogLevel.Trace)); Assert.False(logger.IsEnabled(LogLevel.Debug)); Assert.True(logger.IsEnabled(LogLevel.Information)); diff --git a/tests/ModelContextProtocol.Tests/ClientServerTestBase.cs b/tests/ModelContextProtocol.Tests/ClientServerTestBase.cs new file mode 100644 index 00000000..187d05ea --- /dev/null +++ b/tests/ModelContextProtocol.Tests/ClientServerTestBase.cs @@ -0,0 +1,75 @@ +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol.Transport; +using ModelContextProtocol.Server; +using ModelContextProtocol.Tests.Utils; +using System.IO.Pipelines; + +namespace ModelContextProtocol.Tests; + +public abstract class ClientServerTestBase : LoggedTest, IAsyncDisposable +{ + private readonly Pipe _clientToServerPipe = new(); + private readonly Pipe _serverToClientPipe = new(); + private readonly IMcpServerBuilder _builder; + private readonly CancellationTokenSource _cts; + private readonly Task _serverTask; + + public ClientServerTestBase(ITestOutputHelper testOutputHelper) + : base(testOutputHelper) + { + ServiceCollection sc = new(); + sc.AddSingleton(LoggerFactory); + _builder = sc + .AddMcpServer() + .WithStreamServerTransport(_clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream()); + ConfigureServices(sc, _builder); + ServiceProvider = sc.BuildServiceProvider(); + + _cts = CancellationTokenSource.CreateLinkedTokenSource(TestContext.Current.CancellationToken); + Server = ServiceProvider.GetRequiredService(); + _serverTask = Server.RunAsync(_cts.Token); + } + + protected IMcpServer Server { get; } + + protected IServiceProvider ServiceProvider { get; } + + protected virtual void ConfigureServices(ServiceCollection services, IMcpServerBuilder mcpServerBuilder) + { + } + + public async ValueTask DisposeAsync() + { + await _cts.CancelAsync(); + + _clientToServerPipe.Writer.Complete(); + _serverToClientPipe.Writer.Complete(); + + await _serverTask; + + if (ServiceProvider is IAsyncDisposable asyncDisposable) + { + await asyncDisposable.DisposeAsync(); + } + else if (ServiceProvider is IDisposable disposable) + { + disposable.Dispose(); + } + + _cts.Dispose(); + Dispose(); + } + + protected async Task CreateMcpClientForServer(McpClientOptions? options = null) + { + return await McpClientFactory.CreateAsync( + new StreamClientTransport( + serverInput: _clientToServerPipe.Writer.AsStream(), + _serverToClientPipe.Reader.AsStream(), + LoggerFactory), + loggerFactory: LoggerFactory, + cancellationToken: TestContext.Current.CancellationToken); + } +} diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs index a6069f59..aeef9d87 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs @@ -3,133 +3,92 @@ using Microsoft.Extensions.Options; using ModelContextProtocol.Client; using ModelContextProtocol.Protocol.Messages; -using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server; -using ModelContextProtocol.Tests.Utils; using System.ComponentModel; -using System.IO.Pipelines; using System.Threading.Channels; #pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously namespace ModelContextProtocol.Tests.Configuration; -public class McpServerBuilderExtensionsPromptsTests : LoggedTest, IAsyncDisposable +public class McpServerBuilderExtensionsPromptsTests : ClientServerTestBase { - private readonly Pipe _clientToServerPipe = new(); - private readonly Pipe _serverToClientPipe = new(); - private readonly ServiceProvider _serviceProvider; - private readonly IMcpServerBuilder _builder; - private readonly CancellationTokenSource _cts; - private readonly Task _serverTask; - public McpServerBuilderExtensionsPromptsTests(ITestOutputHelper testOutputHelper) : base(testOutputHelper) { - ServiceCollection sc = new(); - sc.AddSingleton(LoggerFactory); - _builder = sc - .AddMcpServer() - .WithStreamServerTransport(_clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream()) - .WithListPromptsHandler(async (request, cancellationToken) => - { - var cursor = request.Params?.Cursor; - switch (cursor) - { - case null: - return new() + } + + protected override void ConfigureServices(ServiceCollection services, IMcpServerBuilder mcpServerBuilder) + { + mcpServerBuilder + .WithListPromptsHandler(async (request, cancellationToken) => + { + var cursor = request.Params?.Cursor; + switch (cursor) { - NextCursor = "abc", - Prompts = [new() - { - Name = "FirstCustomPrompt", - Description = "First prompt returned by custom handler", - }], - }; - - case "abc": - return new() + case null: + return new() + { + NextCursor = "abc", + Prompts = [new() { - NextCursor = "def", - Prompts = [new() - { - Name = "SecondCustomPrompt", - Description = "Second prompt returned by custom handler", - }], - }; - - case "def": - return new() + Name = "FirstCustomPrompt", + Description = "First prompt returned by custom handler", + }], + }; + + case "abc": + return new() + { + NextCursor = "def", + Prompts = [new() { - NextCursor = null, - Prompts = [new() - { - Name = "FinalCustomPrompt", - Description = "Final prompt returned by custom handler", - }], - }; - - default: - throw new Exception("Unexpected cursor"); - } - }) - .WithGetPromptHandler(async (request, cancellationToken) => - { - switch (request.Params?.Name) - { - case "FirstCustomPrompt": - case "SecondCustomPrompt": - case "FinalCustomPrompt": - return new GetPromptResult() + Name = "SecondCustomPrompt", + Description = "Second prompt returned by custom handler", + }], + }; + + case "def": + return new() + { + NextCursor = null, + Prompts = [new() { - Messages = [new() { Role = Role.User, Content = new() { Text = $"hello from {request.Params.Name}", Type = "text" } }], - }; - - default: - throw new Exception($"Unknown prompt '{request.Params?.Name}'"); - } - }) - .WithPrompts(); - - - sc.AddSingleton(new ObjectWithId()); - _serviceProvider = sc.BuildServiceProvider(); - - var server = _serviceProvider.GetRequiredService(); - _cts = CancellationTokenSource.CreateLinkedTokenSource(TestContext.Current.CancellationToken); - _serverTask = server.RunAsync(cancellationToken: _cts.Token); - } - - public async ValueTask DisposeAsync() - { - await _cts.CancelAsync(); - - _clientToServerPipe.Writer.Complete(); - _serverToClientPipe.Writer.Complete(); - - await _serverTask; - - await _serviceProvider.DisposeAsync(); - _cts.Dispose(); - Dispose(); - } - - private async Task CreateMcpClientForServer(McpClientOptions? options = null) - { - return await McpClientFactory.CreateAsync( - new StreamClientTransport( - serverInput: _clientToServerPipe.Writer.AsStream(), - serverOutput: _serverToClientPipe.Reader.AsStream(), - LoggerFactory), - loggerFactory: LoggerFactory, - cancellationToken: TestContext.Current.CancellationToken); + Name = "FinalCustomPrompt", + Description = "Final prompt returned by custom handler", + }], + }; + + default: + throw new Exception("Unexpected cursor"); + } + }) + .WithGetPromptHandler(async (request, cancellationToken) => + { + switch (request.Params?.Name) + { + case "FirstCustomPrompt": + case "SecondCustomPrompt": + case "FinalCustomPrompt": + return new GetPromptResult() + { + Messages = [new() { Role = Role.User, Content = new() { Text = $"hello from {request.Params.Name}", Type = "text" } }], + }; + + default: + throw new Exception($"Unknown prompt '{request.Params?.Name}'"); + } + }) + .WithPrompts(); + + services.AddSingleton(new ObjectWithId()); } [Fact] public void Adds_Prompts_To_Server() { - var serverOptions = _serviceProvider.GetRequiredService>().Value; + var serverOptions = ServiceProvider.GetRequiredService>().Value; var prompts = serverOptions?.Capabilities?.Prompts?.PromptCollection; Assert.NotNull(prompts); Assert.NotEmpty(prompts); @@ -176,7 +135,7 @@ public async Task Can_Be_Notified_Of_Prompt_Changes() var notificationRead = listChanged.Reader.ReadAsync(TestContext.Current.CancellationToken); Assert.False(notificationRead.IsCompleted); - var serverOptions = _serviceProvider.GetRequiredService>().Value; + var serverOptions = ServiceProvider.GetRequiredService>().Value; var serverPrompts = serverOptions.Capabilities?.Prompts?.PromptCollection; Assert.NotNull(serverPrompts); @@ -242,7 +201,9 @@ public async Task Throws_Exception_Missing_Parameter() [Fact] public void WithPrompts_InvalidArgs_Throws() { - Assert.Throws("promptTypes", () => _builder.WithPrompts((IEnumerable)null!)); + IMcpServerBuilder builder = new ServiceCollection().AddMcpServer(); + + Assert.Throws("promptTypes", () => builder.WithPrompts((IEnumerable)null!)); IMcpServerBuilder nullBuilder = null!; Assert.Throws("builder", () => nullBuilder.WithPrompts()); @@ -253,9 +214,11 @@ public void WithPrompts_InvalidArgs_Throws() [Fact] public void Empty_Enumerables_Is_Allowed() { - _builder.WithPrompts(promptTypes: []); // no exception - _builder.WithPrompts(); // no exception even though no prompts exposed - _builder.WithPromptsFromAssembly(typeof(AIFunction).Assembly); // no exception even though no prompts exposed + IMcpServerBuilder builder = new ServiceCollection().AddMcpServer(); + + builder.WithPrompts(promptTypes: []); // no exception + builder.WithPrompts(); // no exception even though no prompts exposed + builder.WithPromptsFromAssembly(typeof(AIFunction).Assembly); // no exception even though no prompts exposed } [Fact] diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index 10d79c7f..a93a3090 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -7,7 +7,6 @@ using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server; -using ModelContextProtocol.Tests.Utils; using System.Collections.Concurrent; using System.ComponentModel; using System.IO.Pipelines; @@ -19,23 +18,16 @@ namespace ModelContextProtocol.Tests.Configuration; -public class McpServerBuilderExtensionsToolsTests : LoggedTest, IAsyncDisposable +public class McpServerBuilderExtensionsToolsTests : ClientServerTestBase { - private readonly Pipe _clientToServerPipe = new(); - private readonly Pipe _serverToClientPipe = new(); - private readonly ServiceProvider _serviceProvider; - private readonly IMcpServerBuilder _builder; - private readonly CancellationTokenSource _cts; - private readonly Task _serverTask; - public McpServerBuilderExtensionsToolsTests(ITestOutputHelper testOutputHelper) : base(testOutputHelper) { - ServiceCollection sc = new(); - sc.AddSingleton(LoggerFactory); - _builder = sc - .AddMcpServer() - .WithStreamServerTransport(_clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream()) + } + + protected override void ConfigureServices(ServiceCollection services, IMcpServerBuilder mcpServerBuilder) + { + mcpServerBuilder .WithListToolsHandler(async (request, cancellationToken) => { var cursor = request.Params?.Cursor; @@ -46,17 +38,17 @@ public McpServerBuilderExtensionsToolsTests(ITestOutputHelper testOutputHelper) { NextCursor = "abc", Tools = [new() - { - Name = "FirstCustomTool", - Description = "First tool returned by custom handler", - InputSchema = JsonSerializer.Deserialize(""" - { - "type": "object", - "properties": {}, - "required": [] - } - """), - }], + { + Name = "FirstCustomTool", + Description = "First tool returned by custom handler", + InputSchema = JsonSerializer.Deserialize(""" + { + "type": "object", + "properties": {}, + "required": [] + } + """), + }], }; case "abc": @@ -64,17 +56,17 @@ public McpServerBuilderExtensionsToolsTests(ITestOutputHelper testOutputHelper) { NextCursor = "def", Tools = [new() - { - Name = "SecondCustomTool", - Description = "Second tool returned by custom handler", - InputSchema = JsonSerializer.Deserialize(""" - { - "type": "object", - "properties": {}, - "required": [] - } - """), - }], + { + Name = "SecondCustomTool", + Description = "Second tool returned by custom handler", + InputSchema = JsonSerializer.Deserialize(""" + { + "type": "object", + "properties": {}, + "required": [] + } + """), + }], }; case "def": @@ -82,17 +74,17 @@ public McpServerBuilderExtensionsToolsTests(ITestOutputHelper testOutputHelper) { NextCursor = null, Tools = [new() - { - Name = "FinalCustomTool", - Description = "Third tool returned by custom handler", - InputSchema = JsonSerializer.Deserialize(""" - { - "type": "object", - "properties": {}, - "required": [] - } - """), - }], + { + Name = "FinalCustomTool", + Description = "Third tool returned by custom handler", + InputSchema = JsonSerializer.Deserialize(""" + { + "type": "object", + "properties": {}, + "required": [] + } + """), + }], }; default: @@ -117,43 +109,13 @@ public McpServerBuilderExtensionsToolsTests(ITestOutputHelper testOutputHelper) }) .WithTools(); - sc.AddSingleton(new ObjectWithId()); - _serviceProvider = sc.BuildServiceProvider(); - - var server = _serviceProvider.GetRequiredService(); - _cts = CancellationTokenSource.CreateLinkedTokenSource(TestContext.Current.CancellationToken); - _serverTask = server.RunAsync(cancellationToken: _cts.Token); - } - - public async ValueTask DisposeAsync() - { - await _cts.CancelAsync(); - - _clientToServerPipe.Writer.Complete(); - _serverToClientPipe.Writer.Complete(); - - await _serverTask; - - await _serviceProvider.DisposeAsync(); - _cts.Dispose(); - Dispose(); - } - - private async Task CreateMcpClientForServer(McpClientOptions? options = null) - { - return await McpClientFactory.CreateAsync( - new StreamClientTransport( - serverInput: _clientToServerPipe.Writer.AsStream(), - _serverToClientPipe.Reader.AsStream(), - LoggerFactory), - loggerFactory: LoggerFactory, - cancellationToken: TestContext.Current.CancellationToken); + services.AddSingleton(new ObjectWithId()); } [Fact] public void Adds_Tools_To_Server() { - var serverOptions = _serviceProvider.GetRequiredService>().Value; + var serverOptions = ServiceProvider.GetRequiredService>().Value; var tools = serverOptions.Capabilities?.Tools?.ToolCollection; Assert.NotNull(tools); Assert.NotEmpty(tools); @@ -183,8 +145,8 @@ public async Task Can_List_Registered_Tools() [Fact] public async Task Can_Create_Multiple_Servers_From_Options_And_List_Registered_Tools() { - var options = _serviceProvider.GetRequiredService>().Value; - var loggerFactory = _serviceProvider.GetRequiredService(); + var options = ServiceProvider.GetRequiredService>().Value; + var loggerFactory = ServiceProvider.GetRequiredService(); for (int i = 0; i < 2; i++) { @@ -192,7 +154,7 @@ public async Task Can_Create_Multiple_Servers_From_Options_And_List_Registered_T var stdoutPipe = new Pipe(); await using var transport = new StreamServerTransport(stdinPipe.Reader.AsStream(), stdoutPipe.Writer.AsStream()); - await using var server = McpServerFactory.Create(transport, options, loggerFactory, _serviceProvider); + await using var server = McpServerFactory.Create(transport, options, loggerFactory, ServiceProvider); var serverRunTask = server.RunAsync(TestContext.Current.CancellationToken); await using (var client = await McpClientFactory.CreateAsync( @@ -237,7 +199,7 @@ public async Task Can_Be_Notified_Of_Tool_Changes() var notificationRead = listChanged.Reader.ReadAsync(TestContext.Current.CancellationToken); Assert.False(notificationRead.IsCompleted); - var serverOptions = _serviceProvider.GetRequiredService>().Value; + var serverOptions = ServiceProvider.GetRequiredService>().Value; var serverTools = serverOptions.Capabilities?.Tools?.ToolCollection; Assert.NotNull(serverTools); @@ -444,7 +406,9 @@ public async Task Returns_IsError_Missing_Parameter() [Fact] public void WithTools_InvalidArgs_Throws() { - Assert.Throws("toolTypes", () => _builder.WithTools((IEnumerable)null!)); + IMcpServerBuilder builder = new ServiceCollection().AddMcpServer(); + + Assert.Throws("toolTypes", () => builder.WithTools((IEnumerable)null!)); IMcpServerBuilder nullBuilder = null!; Assert.Throws("builder", () => nullBuilder.WithTools()); @@ -455,9 +419,11 @@ public void WithTools_InvalidArgs_Throws() [Fact] public void Empty_Enumerables_Is_Allowed() { - _builder.WithTools(toolTypes: []); // no exception - _builder.WithTools(); // no exception even though no tools exposed - _builder.WithToolsFromAssembly(typeof(AIFunction).Assembly); // no exception even though no tools exposed + IMcpServerBuilder builder = new ServiceCollection().AddMcpServer(); + + builder.WithTools(toolTypes: []); // no exception + builder.WithTools(); // no exception even though no tools exposed + builder.WithToolsFromAssembly(typeof(AIFunction).Assembly); // no exception even though no tools exposed } [Fact] diff --git a/tests/ModelContextProtocol.Tests/Protocol/CancellationTests.cs b/tests/ModelContextProtocol.Tests/Protocol/CancellationTests.cs new file mode 100644 index 00000000..c90e78da --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Protocol/CancellationTests.cs @@ -0,0 +1,69 @@ +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Server; + +namespace ModelContextProtocol.Tests; + +public class CancellationTests : ClientServerTestBase +{ + public CancellationTests(ITestOutputHelper testOutputHelper) + : base(testOutputHelper) + { + } + + protected override void ConfigureServices(ServiceCollection services, IMcpServerBuilder mcpServerBuilder) + { + services.AddSingleton(McpServerTool.Create(WaitForCancellation)); + } + + private static async Task WaitForCancellation(CancellationToken cancellationToken) + { + try + { + await Task.Delay(-1, cancellationToken); + throw new InvalidOperationException("Unexpected completion without exception"); + } + catch (OperationCanceledException) + { + return; + } + } + + [Fact] + public async Task PrecancelRequest_CancelsBeforeSending() + { + var client = await CreateMcpClientForServer(); + + bool gotCancellation = false; + await using (Server.RegisterNotificationHandler(NotificationMethods.CancelledNotification, (notification, cancellationToken) => + { + gotCancellation = true; + return Task.CompletedTask; + })) + { + await Assert.ThrowsAsync(() => client.ListToolsAsync(cancellationToken: new CancellationToken(true))); + } + + Assert.False(gotCancellation); + } + + [Fact] + public async Task CancellationPropagation_RequestingCancellationCancelsPendingRequest() + { + var client = await CreateMcpClientForServer(); + + var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + var waitTool = tools.First(t => t.Name == nameof(WaitForCancellation)); + + CancellationTokenSource cts = new(); + var waitTask = waitTool.InvokeAsync(cancellationToken: cts.Token); + Assert.False(waitTask.IsCompleted); + + await Task.Delay(1, TestContext.Current.CancellationToken); + Assert.False(waitTask.IsCompleted); + + cts.Cancel(); + await Assert.ThrowsAnyAsync(async () => await waitTask); + } +} diff --git a/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs b/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs index 5e2500ac..9ebd1e38 100644 --- a/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs +++ b/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs @@ -1,61 +1,12 @@ -using Microsoft.Extensions.AI; -using Microsoft.Extensions.DependencyInjection; -using ModelContextProtocol.Client; -using ModelContextProtocol.Protocol.Transport; -using ModelContextProtocol.Server; -using ModelContextProtocol.Tests.Utils; -using System.IO.Pipelines; +using ModelContextProtocol.Client; namespace ModelContextProtocol.Tests; -public class NotificationHandlerTests : LoggedTest, IAsyncDisposable +public class NotificationHandlerTests : ClientServerTestBase { - private readonly Pipe _clientToServerPipe = new(); - private readonly Pipe _serverToClientPipe = new(); - private readonly ServiceProvider _serviceProvider; - private readonly IMcpServerBuilder _builder; - private readonly CancellationTokenSource _cts; - private readonly Task _serverTask; - private readonly IMcpServer _server; - public NotificationHandlerTests(ITestOutputHelper testOutputHelper) : base(testOutputHelper) { - ServiceCollection sc = new(); - sc.AddSingleton(LoggerFactory); - _builder = sc - .AddMcpServer() - .WithStreamServerTransport(_clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream()); - _serviceProvider = sc.BuildServiceProvider(); - - _cts = CancellationTokenSource.CreateLinkedTokenSource(TestContext.Current.CancellationToken); - _server = _serviceProvider.GetRequiredService(); - _serverTask = _server.RunAsync(_cts.Token); - } - - public async ValueTask DisposeAsync() - { - await _cts.CancelAsync(); - - _clientToServerPipe.Writer.Complete(); - _serverToClientPipe.Writer.Complete(); - - await _serverTask; - - await _serviceProvider.DisposeAsync(); - _cts.Dispose(); - Dispose(); - } - - private async Task CreateMcpClientForServer(McpClientOptions? options = null) - { - return await McpClientFactory.CreateAsync( - new StreamClientTransport( - serverInput: _clientToServerPipe.Writer.AsStream(), - _serverToClientPipe.Reader.AsStream(), - LoggerFactory), - loggerFactory: LoggerFactory, - cancellationToken: TestContext.Current.CancellationToken); } [Fact] @@ -77,7 +28,7 @@ public async Task RegistrationsAreRemovedWhenDisposed() return Task.CompletedTask; })) { - await _server.SendNotificationAsync(NotificationName, TestContext.Current.CancellationToken); + await Server.SendNotificationAsync(NotificationName, TestContext.Current.CancellationToken); await tcs.Task; } } @@ -113,7 +64,7 @@ public async Task MultipleRegistrationsResultInMultipleCallbacks() try { - await _server.SendNotificationAsync(NotificationName, TestContext.Current.CancellationToken); + await Server.SendNotificationAsync(NotificationName, TestContext.Current.CancellationToken); await tcs.Task; } finally @@ -153,7 +104,7 @@ public async Task MultipleHandlersRunEvenIfOneThrows() try { - await _server.SendNotificationAsync(NotificationName, TestContext.Current.CancellationToken); + await Server.SendNotificationAsync(NotificationName, TestContext.Current.CancellationToken); await tcs.Task; } finally @@ -182,7 +133,7 @@ public async Task DisposeAsyncDoesNotCompleteWhileNotificationHandlerRuns(int nu await releaseHandler.Task; }); - await _server.SendNotificationAsync(NotificationName, TestContext.Current.CancellationToken); + await Server.SendNotificationAsync(NotificationName, TestContext.Current.CancellationToken); await handlerRunning.Task; var disposals = new ValueTask[numberOfDisposals]; @@ -231,7 +182,7 @@ public async Task DisposeAsyncCompletesImmediatelyWhenInvokedFromHandler(int num handlerRunning.SetResult(true); }); - await _server.SendNotificationAsync(NotificationName, TestContext.Current.CancellationToken); + await Server.SendNotificationAsync(NotificationName, TestContext.Current.CancellationToken); await handlerRunning.Task; } } diff --git a/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs b/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs index b769d2c0..b7baa8bd 100644 --- a/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs @@ -1,6 +1,5 @@ using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Http; -using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Routing; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging;