From e0318dd6f3746fa434925abcb71ab1a23def932d Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Wed, 2 Apr 2025 04:36:14 -0700 Subject: [PATCH 01/11] Add route pattern parameter to MapMcp --- .../McpEndpointRouteBuilderExtensions.cs | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs index 891408d0..39c7cb06 100644 --- a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs +++ b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs @@ -1,6 +1,7 @@ using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Routing; +using Microsoft.AspNetCore.Routing.Patterns; using Microsoft.AspNetCore.WebUtilities; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; @@ -10,6 +11,7 @@ using ModelContextProtocol.Server; using ModelContextProtocol.Utils.Json; using System.Collections.Concurrent; +using System.Diagnostics.CodeAnalysis; using System.Security.Cryptography; namespace Microsoft.AspNetCore.Builder; @@ -23,16 +25,27 @@ public static class McpEndpointRouteBuilderExtensions /// Sets up endpoints for handling MCP HTTP Streaming transport. /// /// The web application to attach MCP HTTP endpoints. + /// The route pattern prefix to map to. /// Provides an optional asynchronous callback for handling new MCP sessions. /// Returns a builder for configuring additional endpoint conventions like authorization policies. - public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpoints, Func? runSession = null) + public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpoints, [StringSyntax("Route")] string pattern = "", Func? runSession = null) + => endpoints.MapMcp(RoutePatternFactory.Parse(pattern), runSession); + + /// + /// Sets up endpoints for handling MCP HTTP Streaming transport. + /// + /// The web application to attach MCP HTTP endpoints. + /// The route pattern prefix to map to. + /// Provides an optional asynchronous callback for handling new MCP sessions. + /// Returns a builder for configuring additional endpoint conventions like authorization policies. + public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpoints, RoutePattern pattern, Func? runSession = null) { ConcurrentDictionary _sessions = new(StringComparer.Ordinal); var loggerFactory = endpoints.ServiceProvider.GetRequiredService(); var mcpServerOptions = endpoints.ServiceProvider.GetRequiredService>(); - var routeGroup = endpoints.MapGroup(""); + var routeGroup = endpoints.MapGroup(pattern); routeGroup.MapGet("/sse", async context => { From 31593c141ffbebb21af81d673a846cafe44571a3 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Wed, 2 Apr 2025 04:51:07 -0700 Subject: [PATCH 02/11] Add configureOptionsAsync --- .../McpEndpointRouteBuilderExtensions.cs | 34 ++++++++++++++----- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs index 39c7cb06..995ba53d 100644 --- a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs +++ b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs @@ -26,24 +26,33 @@ public static class McpEndpointRouteBuilderExtensions /// /// The web application to attach MCP HTTP endpoints. /// The route pattern prefix to map to. - /// Provides an optional asynchronous callback for handling new MCP sessions. + /// Provides an optional asynchronous callback for handling new MCP sessions. + /// Configure per-session options. /// Returns a builder for configuring additional endpoint conventions like authorization policies. - public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpoints, [StringSyntax("Route")] string pattern = "", Func? runSession = null) - => endpoints.MapMcp(RoutePatternFactory.Parse(pattern), runSession); + public static IEndpointConventionBuilder MapMcp( + this IEndpointRouteBuilder endpoints, + [StringSyntax("Route")] string pattern = "", + Func? runSessionAsync = null, + Func? configureOptionsAsync = null) + => endpoints.MapMcp(RoutePatternFactory.Parse(pattern), runSessionAsync, configureOptionsAsync); /// /// Sets up endpoints for handling MCP HTTP Streaming transport. /// /// The web application to attach MCP HTTP endpoints. /// The route pattern prefix to map to. - /// Provides an optional asynchronous callback for handling new MCP sessions. + /// Provides an optional asynchronous callback for handling new MCP sessions. + /// Configure per-session options. /// Returns a builder for configuring additional endpoint conventions like authorization policies. - public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpoints, RoutePattern pattern, Func? runSession = null) + public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpoints, RoutePattern pattern, + Func? runSessionAsync = null, + Func? configureOptionsAsync = null) { ConcurrentDictionary _sessions = new(StringComparer.Ordinal); var loggerFactory = endpoints.ServiceProvider.GetRequiredService(); - var mcpServerOptions = endpoints.ServiceProvider.GetRequiredService>(); + var optionsSnapshot = endpoints.ServiceProvider.GetRequiredService>(); + var optionsFactory = endpoints.ServiceProvider.GetRequiredService>(); var routeGroup = endpoints.MapGroup(pattern); @@ -62,6 +71,13 @@ public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpo throw new Exception($"Unreachable given good entropy! Session with ID '{sessionId}' has already been created."); } + var options = optionsSnapshot.Value; + if (configureOptionsAsync is not null) + { + options = optionsFactory.Create(Options.DefaultName); + await configureOptionsAsync.Invoke(context, options, requestAborted); + } + try { // Make sure we disable all response buffering for SSE @@ -69,12 +85,12 @@ public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpo context.Features.GetRequiredFeature().DisableBuffering(); var transportTask = transport.RunAsync(cancellationToken: requestAborted); - await using var server = McpServerFactory.Create(transport, mcpServerOptions.Value, loggerFactory, endpoints.ServiceProvider); + await using var server = McpServerFactory.Create(transport, options, loggerFactory, endpoints.ServiceProvider); try { - runSession ??= RunSession; - await runSession(context, server, requestAborted); + runSessionAsync ??= RunSession; + await runSessionAsync(context, server, requestAborted); } finally { From f098676158aa4119ac730d25bb574eb678006d02 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Sun, 6 Apr 2025 10:35:53 -0700 Subject: [PATCH 03/11] Use Kestrel for all in-memory HTTP tests --- .../McpEndpointRouteBuilderExtensions.cs | 31 +- .../Transport/SseClientSessionTransport.cs | 17 +- .../Program.cs | 22 +- .../DockerEverythingServerTests.cs | 123 ++++++ .../Server/MapMcpTests.cs | 19 + .../SseIntegrationTests.cs | 323 +++++++-------- .../SseServerIntegrationTestFixture.cs | 52 ++- .../SseServerIntegrationTests.cs | 22 +- .../SseResponseStreamTransportTests.cs | 27 ++ .../Utils/InMemoryTestSseServer.cs | 382 ------------------ .../Utils/KestrelInMemoryConnection.cs | 113 ++++++ .../Utils/KestrelInMemoryTest.cs | 34 ++ .../Utils/KestrelInMemoryTransport.cs | 50 +++ 13 files changed, 594 insertions(+), 621 deletions(-) create mode 100644 tests/ModelContextProtocol.Tests/DockerEverythingServerTests.cs create mode 100644 tests/ModelContextProtocol.Tests/Server/MapMcpTests.cs create mode 100644 tests/ModelContextProtocol.Tests/Transport/SseResponseStreamTransportTests.cs delete mode 100644 tests/ModelContextProtocol.Tests/Utils/InMemoryTestSseServer.cs create mode 100644 tests/ModelContextProtocol.Tests/Utils/KestrelInMemoryConnection.cs create mode 100644 tests/ModelContextProtocol.Tests/Utils/KestrelInMemoryTest.cs create mode 100644 tests/ModelContextProtocol.Tests/Utils/KestrelInMemoryTransport.cs diff --git a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs index 995ba53d..f6a2ffc5 100644 --- a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs +++ b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs @@ -26,27 +26,28 @@ public static class McpEndpointRouteBuilderExtensions /// /// The web application to attach MCP HTTP endpoints. /// The route pattern prefix to map to. - /// Provides an optional asynchronous callback for handling new MCP sessions. /// Configure per-session options. + /// Provides an optional asynchronous callback for handling new MCP sessions. /// Returns a builder for configuring additional endpoint conventions like authorization policies. public static IEndpointConventionBuilder MapMcp( this IEndpointRouteBuilder endpoints, [StringSyntax("Route")] string pattern = "", - Func? runSessionAsync = null, - Func? configureOptionsAsync = null) - => endpoints.MapMcp(RoutePatternFactory.Parse(pattern), runSessionAsync, configureOptionsAsync); + Func? configureOptionsAsync = null, + Func? runSessionAsync = null) + => endpoints.MapMcp(RoutePatternFactory.Parse(pattern), configureOptionsAsync, runSessionAsync); /// /// Sets up endpoints for handling MCP HTTP Streaming transport. /// /// The web application to attach MCP HTTP endpoints. /// The route pattern prefix to map to. - /// Provides an optional asynchronous callback for handling new MCP sessions. /// Configure per-session options. + /// Provides an optional asynchronous callback for handling new MCP sessions. /// Returns a builder for configuring additional endpoint conventions like authorization policies. - public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpoints, RoutePattern pattern, - Func? runSessionAsync = null, - Func? configureOptionsAsync = null) + public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpoints, + RoutePattern pattern, + Func? configureOptionsAsync = null, + Func? runSessionAsync = null) { ConcurrentDictionary _sessions = new(StringComparer.Ordinal); @@ -64,6 +65,10 @@ public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpo response.Headers.ContentType = "text/event-stream"; response.Headers.CacheControl = "no-cache,no-store"; + // Make sure we disable all response buffering for SSE + context.Response.Headers.ContentEncoding = "identity"; + context.Features.GetRequiredFeature().DisableBuffering(); + var sessionId = MakeNewSessionId(); await using var transport = new SseResponseStreamTransport(response.Body, $"/message?sessionId={sessionId}"); if (!_sessions.TryAdd(sessionId, transport)) @@ -80,17 +85,15 @@ public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpo try { - // Make sure we disable all response buffering for SSE - context.Response.Headers.ContentEncoding = "identity"; - context.Features.GetRequiredFeature().DisableBuffering(); - var transportTask = transport.RunAsync(cancellationToken: requestAborted); - await using var server = McpServerFactory.Create(transport, options, loggerFactory, endpoints.ServiceProvider); + await using var mcpServer = McpServerFactory.Create(transport, options, loggerFactory, endpoints.ServiceProvider); + + context.Features.Set(mcpServer); try { runSessionAsync ??= RunSession; - await runSessionAsync(context, server, requestAborted); + await runSessionAsync(context, mcpServer, requestAborted); } finally { diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs index 49b2fe40..b9624874 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs @@ -151,16 +151,14 @@ private async Task CloseAsync() { try { - if (!_connectionCts.IsCancellationRequested) - { - await _connectionCts.CancelAsync().ConfigureAwait(false); - _connectionCts.Dispose(); - } + await _connectionCts.CancelAsync().ConfigureAwait(false); if (_receiveTask != null) { await _receiveTask.ConfigureAwait(false); } + + _connectionCts.Dispose(); } finally { @@ -206,8 +204,15 @@ private async Task ReceiveMessagesAsync(CancellationToken cancellationToken) using var stream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); - await foreach (SseItem sseEvent in SseParser.Create(stream).EnumerateAsync(cancellationToken).ConfigureAwait(false)) + // A lot of methods in HttpClient, including HttpConnection.FillAsync, do not currently respect any + // cancellation tokens so we have to cancel awaiting ourselves. + // I tried disposing the HttpRequestMessage and HttpResponseMessage instead, but neither of these canceled + // the ongoing body read, and we don't want to dispose an HttpClient we may not own. + var sseEventEnumerator = SseParser.Create(stream).EnumerateAsync(cancellationToken).GetAsyncEnumerator(); + while (await sseEventEnumerator.MoveNextAsync().AsTask().WaitAsync(cancellationToken).ConfigureAwait(false)) { + SseItem sseEvent = sseEventEnumerator.Current; + switch (sseEvent.EventType) { case "endpoint": diff --git a/tests/ModelContextProtocol.TestSseServer/Program.cs b/tests/ModelContextProtocol.TestSseServer/Program.cs index 43ae3207..4f704672 100644 --- a/tests/ModelContextProtocol.TestSseServer/Program.cs +++ b/tests/ModelContextProtocol.TestSseServer/Program.cs @@ -1,4 +1,5 @@ -using ModelContextProtocol.Protocol.Types; +using Microsoft.AspNetCore.Connections; +using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server; using Serilog; using System.Text; @@ -372,17 +373,24 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st }; } - public static async Task MainAsync(string[] args, ILoggerProvider? loggerProvider = null, CancellationToken cancellationToken = default) + public static async Task MainAsync(string[] args, ILoggerProvider? loggerProvider = null, IConnectionListenerFactory? kestrelTransport = null, CancellationToken cancellationToken = default) { Console.WriteLine("Starting server..."); - int port = args.Length > 0 && uint.TryParse(args[0], out var parsedPort) ? (int)parsedPort : 3001; - var builder = WebApplication.CreateSlimBuilder(args); - builder.WebHost.ConfigureKestrel(options => + + if (kestrelTransport is null) + { + int port = args.Length > 0 && uint.TryParse(args[0], out var parsedPort) ? (int)parsedPort : 3001; + builder.WebHost.ConfigureKestrel(options => + { + options.ListenLocalhost(port); + }); + } + else { - options.ListenLocalhost(port); - }); + builder.Services.AddSingleton(kestrelTransport); + } ConfigureSerilog(builder.Logging); if (loggerProvider is not null) diff --git a/tests/ModelContextProtocol.Tests/DockerEverythingServerTests.cs b/tests/ModelContextProtocol.Tests/DockerEverythingServerTests.cs new file mode 100644 index 00000000..3dbcfe7c --- /dev/null +++ b/tests/ModelContextProtocol.Tests/DockerEverythingServerTests.cs @@ -0,0 +1,123 @@ +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol.Transport; +using ModelContextProtocol.Protocol.Types; +using ModelContextProtocol.Tests.Utils; + +namespace ModelContextProtocol.Tests; + +public class DockerEverythingServerTests(ITestOutputHelper testOutputHelper) : LoggedTest(testOutputHelper) +{ + /// Port number to be grabbed by the next test. + private static int s_nextPort = 3000; + + // If the tests run concurrently against different versions of the runtime, tests can conflict with + // each other in the ports set up for interacting with containers. Ensure that such suites running + // against different TFMs use different port numbers. + private static readonly int s_portOffset = 1000 * (Environment.Version.Major switch + { + int v when v >= 8 => Environment.Version.Major - 7, + _ => 0, + }); + + private static int CreatePortNumber() => Interlocked.Increment(ref s_nextPort) + s_portOffset; + + public static bool IsDockerAvailable => EverythingSseServerFixture.IsDockerAvailable; + + [Fact(Skip = "docker is not available", SkipUnless = nameof(IsDockerAvailable))] + [Trait("Execution", "Manual")] + public async Task ConnectAndReceiveMessage_EverythingServerWithSse() + { + int port = CreatePortNumber(); + + await using var fixture = new EverythingSseServerFixture(port); + await fixture.StartAsync(); + + var defaultOptions = new McpClientOptions + { + ClientInfo = new() { Name = "IntegrationTestClient", Version = "1.0.0" } + }; + + var defaultConfig = new McpServerConfig + { + Id = "everything", + Name = "Everything", + TransportType = TransportTypes.Sse, + TransportOptions = [], + Location = $"http://localhost:{port}/sse" + }; + + // Create client and run tests + await using var client = await McpClientFactory.CreateAsync( + defaultConfig, + defaultOptions, + loggerFactory: LoggerFactory, + cancellationToken: TestContext.Current.CancellationToken); + var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + + // assert + Assert.NotEmpty(tools); + } + + [Fact(Skip = "docker is not available", SkipUnless = nameof(IsDockerAvailable))] + [Trait("Execution", "Manual")] + public async Task Sampling_Sse_EverythingServer() + { + int port = CreatePortNumber(); + + await using var fixture = new EverythingSseServerFixture(port); + await fixture.StartAsync(); + + var defaultConfig = new McpServerConfig + { + Id = "everything", + Name = "Everything", + TransportType = TransportTypes.Sse, + TransportOptions = [], + Location = $"http://localhost:{port}/sse" + }; + + int samplingHandlerCalls = 0; + var defaultOptions = new McpClientOptions + { + Capabilities = new() + { + Sampling = new() + { + SamplingHandler = (_, _, _) => + { + samplingHandlerCalls++; + return Task.FromResult(new CreateMessageResult + { + Model = "test-model", + Role = "assistant", + Content = new Content + { + Type = "text", + Text = "Test response" + } + }); + }, + }, + }, + }; + + await using var client = await McpClientFactory.CreateAsync( + defaultConfig, + defaultOptions, + loggerFactory: LoggerFactory, + cancellationToken: TestContext.Current.CancellationToken); + + // Call the server's sampleLLM tool which should trigger our sampling handler + var result = await client.CallToolAsync("sampleLLM", new Dictionary + { + ["prompt"] = "Test prompt", + ["maxTokens"] = 100 + }, cancellationToken: TestContext.Current.CancellationToken); + + // assert + Assert.NotNull(result); + var textContent = Assert.Single(result.Content); + Assert.Equal("text", textContent.Type); + Assert.False(string.IsNullOrEmpty(textContent.Text)); + } +} diff --git a/tests/ModelContextProtocol.Tests/Server/MapMcpTests.cs b/tests/ModelContextProtocol.Tests/Server/MapMcpTests.cs new file mode 100644 index 00000000..f8a80d66 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Server/MapMcpTests.cs @@ -0,0 +1,19 @@ +using Microsoft.AspNetCore.Builder; +using ModelContextProtocol.Tests.Utils; + +namespace ModelContextProtocol.Tests.Server; + +public class MapMcpTests(ITestOutputHelper testOutputHelper) : KestrelInMemoryTest(testOutputHelper) +{ + [Fact] + public async Task Allows_Customizing_Route() + { + await using var app = Builder.Build(); + app.MapMcp("/mcp"); + await app.StartAsync(TestContext.Current.CancellationToken); + + using var httpClient = CreateHttpClient(); + using var response = await httpClient.GetAsync("http://localhost/mcp/sse", HttpCompletionOption.ResponseHeadersRead, TestContext.Current.CancellationToken); + response.EnsureSuccessStatusCode(); + } +} diff --git a/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs b/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs index 91e99898..fd5eddd6 100644 --- a/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs @@ -1,238 +1,193 @@ -using Microsoft.Extensions.Logging; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Routing; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; -using ModelContextProtocol.Protocol.Types; +using ModelContextProtocol.Server; using ModelContextProtocol.Tests.Utils; +using ModelContextProtocol.Utils.Json; namespace ModelContextProtocol.Tests; -public class SseIntegrationTests(ITestOutputHelper outputHelper) : LoggedTest(outputHelper) +public class SseIntegrationTests(ITestOutputHelper outputHelper) : KestrelInMemoryTest(outputHelper) { - /// Port number to be grabbed by the next test. - private static int s_nextPort = 3000; - - // If the tests run concurrently against different versions of the runtime, tests can conflict with - // each other in the ports set up for interacting with containers. Ensure that such suites running - // against different TFMs use different port numbers. - private static readonly int s_portOffset = 1000 * (Environment.Version.Major switch + private McpServerConfig DefaultServerConfig = new() { - int v when v >= 8 => Environment.Version.Major - 7, - _ => 0, - }); + Id = "test_server", + Name = "In-memory Test Server", + TransportType = TransportTypes.Sse, + TransportOptions = [], + Location = $"http://localhost/sse" + }; + + private Task ConnectMcpClient(HttpClient httpClient, McpClientOptions? clientOptions = null) + => McpClientFactory.CreateAsync( + DefaultServerConfig, + clientOptions, + (_, _) => new SseClientTransport(new(), DefaultServerConfig, httpClient, LoggerFactory), + LoggerFactory, + TestContext.Current.CancellationToken); - private static int CreatePortNumber() => Interlocked.Increment(ref s_nextPort) + s_portOffset; [Fact] public async Task ConnectAndReceiveMessage_InMemoryServer() { - // Arrange - await using InMemoryTestSseServer server = new(CreatePortNumber(), LoggerFactory.CreateLogger()); - await server.StartAsync(); + await using var app = Builder.Build(); + app.MapMcp(); + await app.StartAsync(TestContext.Current.CancellationToken); - var defaultConfig = new McpServerConfig - { - Id = "test_server", - Name = "In-memory Test Server", - TransportType = TransportTypes.Sse, - TransportOptions = [], - Location = $"http://localhost:{server.Port}/sse" - }; - - // Act - await using var client = await McpClientFactory.CreateAsync( - defaultConfig, - loggerFactory: LoggerFactory, - cancellationToken: TestContext.Current.CancellationToken); - - // Wait for SSE connection to be established - await server.WaitForConnectionAsync(TimeSpan.FromSeconds(10)); + using var httpClient = CreateHttpClient(); + await using var mcpClient = await ConnectMcpClient(httpClient); // Send a test message through POST endpoint - await client.SendNotificationAsync("test/message", new { message = "Hello, SSE!" }, cancellationToken: TestContext.Current.CancellationToken); + await mcpClient.SendNotificationAsync("test/message", new { message = "Hello, SSE!" }, cancellationToken: TestContext.Current.CancellationToken); - // Assert Assert.True(true); } [Fact] - [Trait("Execution", "Manual")] - public async Task ConnectAndReceiveMessage_EverythingServerWithSse() + public async Task ConnectAndReceiveMessage_InMemoryServer_WithFullEndpointEventUri() { - Assert.SkipWhen(!EverythingSseServerFixture.IsDockerAvailable, "docker is not available"); - - int port = CreatePortNumber(); + await using var app = Builder.Build(); + MapAbsoluteEndpointUriMcp(app); + await app.StartAsync(TestContext.Current.CancellationToken); - await using var fixture = new EverythingSseServerFixture(port); - await fixture.StartAsync(); + using var httpClient = CreateHttpClient(); + await using var mcpClient = await ConnectMcpClient(httpClient); - var defaultOptions = new McpClientOptions - { - ClientInfo = new() { Name = "IntegrationTestClient", Version = "1.0.0" } - }; + // Send a test message through POST endpoint + await mcpClient.SendNotificationAsync("test/message", new { message = "Hello, SSE!" }, cancellationToken: TestContext.Current.CancellationToken); - var defaultConfig = new McpServerConfig - { - Id = "everything", - Name = "Everything", - TransportType = TransportTypes.Sse, - TransportOptions = [], - Location = $"http://localhost:{port}/sse" - }; - - // Create client and run tests - await using var client = await McpClientFactory.CreateAsync( - defaultConfig, - defaultOptions, - loggerFactory: LoggerFactory, - cancellationToken: TestContext.Current.CancellationToken); - var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); - - // assert - Assert.NotEmpty(tools); + Assert.True(true); } [Fact] - [Trait("Execution", "Manual")] - public async Task Sampling_Sse_EverythingServer() + public async Task ConnectAndReceiveNotification_InMemoryServer() { - Assert.SkipWhen(!EverythingSseServerFixture.IsDockerAvailable, "docker is not available"); + var receivedNotification = new TaskCompletionSource(); - int port = CreatePortNumber(); + await using var app = Builder.Build(); + app.MapMcp(configureOptionsAsync: (httpContext, mcpServerOptions, CancellationToken) => + { + mcpServerOptions.Capabilities = new() + { + NotificationHandlers = + [ + new("test/notification", async notification => + { + Assert.Equal("Hello from client!", notification.Params?["message"]?.GetValue()); - await using var fixture = new EverythingSseServerFixture(port); - await fixture.StartAsync(); + var server = httpContext.Features.GetRequiredFeature(); - var defaultConfig = new McpServerConfig - { - Id = "everything", - Name = "Everything", - TransportType = TransportTypes.Sse, - TransportOptions = [], - Location = $"http://localhost:{port}/sse" - }; - - int samplingHandlerCalls = 0; - var defaultOptions = new McpClientOptions + // REVIEW: Where is the CancellationToken for notification handlers? + // The httpContext.RequestAborted trick will not always work once we fully support the HTTP streaming spec. + await server.SendNotificationAsync("test/notification", new { message = "Hello from server!" }, cancellationToken: httpContext.RequestAborted); + }), + ], + }; + + return Task.CompletedTask; + }); + await app.StartAsync(TestContext.Current.CancellationToken); + + using var httpClient = CreateHttpClient(); + await using var mcpClient = await ConnectMcpClient(httpClient, new() { Capabilities = new() { - Sampling = new() + NotificationHandlers = [new("test/notification", args => { - SamplingHandler = (_, _, _) => - { - samplingHandlerCalls++; - return Task.FromResult(new CreateMessageResult - { - Model = "test-model", - Role = "assistant", - Content = new Content - { - Type = "text", - Text = "Test response" - } - }); - }, - }, - }, - }; + var msg = args.Params?["message"]?.GetValue(); + receivedNotification.SetResult(msg); - await using var client = await McpClientFactory.CreateAsync( - defaultConfig, - defaultOptions, - loggerFactory: LoggerFactory, - cancellationToken: TestContext.Current.CancellationToken); + return Task.CompletedTask; + })], + }, + }); - // Call the server's sampleLLM tool which should trigger our sampling handler - var result = await client.CallToolAsync("sampleLLM", new Dictionary + await using var mcpClient2 = await ConnectMcpClient(httpClient, new() + { + Capabilities = new() { - ["prompt"] = "Test prompt", - ["maxTokens"] = 100 - }, cancellationToken: TestContext.Current.CancellationToken); - - // assert - Assert.NotNull(result); - var textContent = Assert.Single(result.Content); - Assert.Equal("text", textContent.Type); - Assert.False(string.IsNullOrEmpty(textContent.Text)); - } - - [Fact] - public async Task ConnectAndReceiveMessage_InMemoryServer_WithFullEndpointEventUri() - { - // Arrange - await using InMemoryTestSseServer server = new(CreatePortNumber(), LoggerFactory.CreateLogger()); - server.UseFullUrlForEndpointEvent = true; - await server.StartAsync(); + NotificationHandlers = [new("test/notification", args => + { + var msg = args.Params?["message"]?.GetValue(); + receivedNotification.SetResult(msg); - var defaultConfig = new McpServerConfig - { - Id = "test_server", - Name = "In-memory Test Server", - TransportType = TransportTypes.Sse, - TransportOptions = [], - Location = $"http://localhost:{server.Port}/sse" - }; - - // Act - await using var client = await McpClientFactory.CreateAsync( - defaultConfig, - loggerFactory: LoggerFactory, - cancellationToken: TestContext.Current.CancellationToken); - - // Wait for SSE connection to be established - await server.WaitForConnectionAsync(TimeSpan.FromSeconds(10)); + return Task.CompletedTask; + })], + }, + }); // Send a test message through POST endpoint - await client.SendNotificationAsync("test/message", new { message = "Hello, SSE!" }, cancellationToken: TestContext.Current.CancellationToken); + await mcpClient.SendNotificationAsync("test/notification", new { message = "Hello from client!" }, cancellationToken: TestContext.Current.CancellationToken); - // Assert - Assert.True(true); + var message = await receivedNotification.Task.WaitAsync(TimeSpan.FromSeconds(10), TestContext.Current.CancellationToken); + Assert.Equal("Hello from server!", message); } - [Fact] - public async Task ConnectAndReceiveNotification_InMemoryServer() + private static void MapAbsoluteEndpointUriMcp(IEndpointRouteBuilder endpoints) { - // Arrange - await using InMemoryTestSseServer server = new(CreatePortNumber(), LoggerFactory.CreateLogger()); - await server.StartAsync(); + var loggerFactory = endpoints.ServiceProvider.GetRequiredService(); + var optionsSnapshot = endpoints.ServiceProvider.GetRequiredService>(); - var defaultConfig = new McpServerConfig + var routeGroup = endpoints.MapGroup(""); + SseResponseStreamTransport? session = null; + + routeGroup.MapGet("/sse", async context => { - Id = "test_server", - Name = "In-memory Test Server", - TransportType = TransportTypes.Sse, - TransportOptions = [], - Location = $"http://localhost:{server.Port}/sse" - }; - - // Act - var receivedNotification = new TaskCompletionSource(); - await using var client = await McpClientFactory.CreateAsync( - defaultConfig, - new() - { - Capabilities = new() - { - NotificationHandlers = [new("test/notification", args => - { - var msg = args.Params?["message"]?.GetValue(); - receivedNotification.SetResult(msg); + var response = context.Response; + var requestAborted = context.RequestAborted; - return Task.CompletedTask; - })], - }, - }, - loggerFactory: LoggerFactory, - cancellationToken: TestContext.Current.CancellationToken); + response.Headers.ContentType = "text/event-stream"; - // Wait for SSE connection to be established - await server.WaitForConnectionAsync(TimeSpan.FromSeconds(10)); + await using var transport = new SseResponseStreamTransport(response.Body, "http://localhost/message"); + session = transport; - // Act - await server.SendTestNotificationAsync("Hello from server!"); + try + { + var transportTask = transport.RunAsync(cancellationToken: requestAborted); + await using var server = McpServerFactory.Create(transport, optionsSnapshot.Value, loggerFactory, endpoints.ServiceProvider); - // Assert - var message = await receivedNotification.Task.WaitAsync(TimeSpan.FromSeconds(10), TestContext.Current.CancellationToken); - Assert.Equal("Hello from server!", message); + try + { + await server.RunAsync(requestAborted); + } + finally + { + await transport.DisposeAsync(); + await transportTask; + } + } + catch (OperationCanceledException) when (requestAborted.IsCancellationRequested) + { + // RequestAborted always triggers when the client disconnects before a complete response body is written, + // but this is how SSE connections are typically closed. + } + }); + + routeGroup.MapPost("/message", async context => + { + if (session is null) + { + await Results.BadRequest("Session not started.").ExecuteAsync(context); + return; + } + var message = (IJsonRpcMessage?)await context.Request.ReadFromJsonAsync(McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IJsonRpcMessage)), context.RequestAborted); + if (message is null) + { + await Results.BadRequest("No message in request body.").ExecuteAsync(context); + return; + } + + await session.OnMessageReceivedAsync(message, context.RequestAborted); + context.Response.StatusCode = StatusCodes.Status202Accepted; + await context.Response.WriteAsync("Accepted"); + }); } } diff --git a/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs b/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs index 1720bc69..6b7b474e 100644 --- a/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs +++ b/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs @@ -1,4 +1,6 @@ -using ModelContextProtocol.Protocol.Transport; +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Test.Utils; using ModelContextProtocol.Tests.Utils; using ModelContextProtocol.TestSseServer; @@ -7,28 +9,52 @@ namespace ModelContextProtocol.Tests; public class SseServerIntegrationTestFixture : IAsyncDisposable { + private readonly KestrelInMemoryTransport _inMemoryTransport = new(); + private readonly Task _serverTask; private readonly CancellationTokenSource _stopCts = new(); + // XUnit's ITestOutputHelper is created per test, while this fixture is used for + // multiple tests, so this dispatches the output to the current test. private readonly DelegatingTestOutputHelper _delegatingTestOutputHelper = new(); - public McpServerConfig DefaultConfig { get; } + private McpServerConfig DefaultServerConfig { get; } = new McpServerConfig + { + Id = "test_server", + Name = "TestServer", + TransportType = TransportTypes.Sse, + TransportOptions = [], + Location = $"http://localhost/sse" + }; public SseServerIntegrationTestFixture() { - // Ensure that test suites running against different TFMs and possibly concurrently use different port numbers. - int port = 3001 + Environment.Version.Major; + var socketsHttpHandler = new SocketsHttpHandler() + { + ConnectCallback = (context, token) => + { + var connection = _inMemoryTransport.CreateConnection(); + return new(connection.ClientStream); + }, + }; - DefaultConfig = new McpServerConfig + HttpClient = new HttpClient(socketsHttpHandler) { - Id = "test_server", - Name = "TestServer", - TransportType = TransportTypes.Sse, - TransportOptions = [], - Location = $"http://localhost:{port}/sse" + BaseAddress = new Uri(DefaultServerConfig.Location), }; + _serverTask = Program.MainAsync([], new XunitLoggerProvider(_delegatingTestOutputHelper), _inMemoryTransport, _stopCts.Token); + } - _serverTask = Program.MainAsync([port.ToString()], new XunitLoggerProvider(_delegatingTestOutputHelper), _stopCts.Token); + public HttpClient HttpClient { get; } + + public Task ConnectMcpClientAsync(McpClientOptions? options, ILoggerFactory loggerFactory) + { + return McpClientFactory.CreateAsync( + DefaultServerConfig, + options, + (_, _) => new SseClientTransport(new(), DefaultServerConfig, HttpClient, loggerFactory), + loggerFactory, + TestContext.Current.CancellationToken); } public void Initialize(ITestOutputHelper output) @@ -44,7 +70,10 @@ public void TestCompleted() public async ValueTask DisposeAsync() { _delegatingTestOutputHelper.CurrentTestOutputHelper = null; + + HttpClient.Dispose(); _stopCts.Cancel(); + try { await _serverTask.ConfigureAwait(false); @@ -52,6 +81,7 @@ public async ValueTask DisposeAsync() catch (OperationCanceledException) { } + _stopCts.Dispose(); } } diff --git a/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs b/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs index eaead4e0..7913bbb2 100644 --- a/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs @@ -25,19 +25,9 @@ public override void Dispose() private Task GetClientAsync(McpClientOptions? options = null) { - return McpClientFactory.CreateAsync( - _fixture.DefaultConfig, - options, - loggerFactory: LoggerFactory, - cancellationToken: TestContext.Current.CancellationToken); + return _fixture.ConnectMcpClientAsync(options, LoggerFactory); } - private HttpClient GetHttpClient() => - new() - { - BaseAddress = new(_fixture.DefaultConfig.Location!), - }; - [Fact] public async Task ConnectAndPing_Sse_TestServer() { @@ -283,8 +273,7 @@ public async Task CallTool_Sse_EchoServer_Concurrently() [Fact] public async Task EventSourceResponse_Includes_ExpectedHeaders() { - using var httpClient = GetHttpClient(); - using var sseResponse = await httpClient.GetAsync("", HttpCompletionOption.ResponseHeadersRead, TestContext.Current.CancellationToken); + using var sseResponse = await _fixture.HttpClient.GetAsync("", HttpCompletionOption.ResponseHeadersRead, TestContext.Current.CancellationToken); sseResponse.EnsureSuccessStatusCode(); @@ -300,8 +289,7 @@ public async Task EventSourceStream_Includes_MessageEventType() { // Simulate our own MCP client handshake using a plain HttpClient so we can look for "event: message" // in the raw SSE response stream which is not exposed by the real MCP client. - using var httpClient = GetHttpClient(); - await using var sseResponse = await httpClient.GetStreamAsync("", TestContext.Current.CancellationToken); + await using var sseResponse = await _fixture.HttpClient.GetStreamAsync("", TestContext.Current.CancellationToken); using var streamReader = new StreamReader(sseResponse); var endpointEvent = await streamReader.ReadLineAsync(TestContext.Current.CancellationToken); @@ -317,7 +305,7 @@ public async Task EventSourceStream_Includes_MessageEventType() """; using (var initializeRequestBody = new StringContent(initializeRequest, Encoding.UTF8, "application/json")) { - var response = await httpClient.PostAsync(messageEndpoint, initializeRequestBody, TestContext.Current.CancellationToken); + var response = await _fixture.HttpClient.PostAsync(messageEndpoint, initializeRequestBody, TestContext.Current.CancellationToken); Assert.Equal(HttpStatusCode.Accepted, response.StatusCode); } @@ -326,7 +314,7 @@ public async Task EventSourceStream_Includes_MessageEventType() """; using (var initializedNotificationBody = new StringContent(initializedNotification, Encoding.UTF8, "application/json")) { - var response = await httpClient.PostAsync(messageEndpoint, initializedNotificationBody, TestContext.Current.CancellationToken); + var response = await _fixture.HttpClient.PostAsync(messageEndpoint, initializedNotificationBody, TestContext.Current.CancellationToken); Assert.Equal(HttpStatusCode.Accepted, response.StatusCode); } diff --git a/tests/ModelContextProtocol.Tests/Transport/SseResponseStreamTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/SseResponseStreamTransportTests.cs new file mode 100644 index 00000000..220746cd --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Transport/SseResponseStreamTransportTests.cs @@ -0,0 +1,27 @@ +using ModelContextProtocol.Protocol.Transport; +using ModelContextProtocol.Tests.Utils; +using System.IO.Pipelines; + +namespace ModelContextProtocol.Tests.Transport; + +public class SseResponseStreamTransportTests(ITestOutputHelper testOutputHelper) : LoggedTest(testOutputHelper) +{ + [Fact] + public async Task Can_Customize_MessageEndpoint() + { + var responsePipe = new Pipe(); + + await using var transport = new SseResponseStreamTransport(responsePipe.Writer.AsStream(), "/my-message-endpoint"); + var transportRunTask = transport.RunAsync(TestContext.Current.CancellationToken); + + using var responseStreamReader = new StreamReader(responsePipe.Reader.AsStream()); + var firstLine = await responseStreamReader.ReadLineAsync(TestContext.Current.CancellationToken); + Assert.Equal("event: endpoint", firstLine); + + var secondLine = await responseStreamReader.ReadLineAsync(TestContext.Current.CancellationToken); + Assert.Equal("data: /my-message-endpoint", secondLine); + + responsePipe.Reader.Complete(); + responsePipe.Writer.Complete(); + } +} diff --git a/tests/ModelContextProtocol.Tests/Utils/InMemoryTestSseServer.cs b/tests/ModelContextProtocol.Tests/Utils/InMemoryTestSseServer.cs deleted file mode 100644 index 7d7122a8..00000000 --- a/tests/ModelContextProtocol.Tests/Utils/InMemoryTestSseServer.cs +++ /dev/null @@ -1,382 +0,0 @@ -using System.Collections.Concurrent; -using System.Net; -using System.Text.Json; -using ModelContextProtocol.Protocol.Messages; -using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Abstractions; -using ModelContextProtocol.Protocol.Types; - -namespace ModelContextProtocol.Tests.Utils; - -public sealed class InMemoryTestSseServer : IAsyncDisposable -{ - private readonly HttpListener _listener; - private readonly CancellationTokenSource _cts; - private readonly ILogger _logger; - private Task? _serverTask; - private readonly TaskCompletionSource _connectionEstablished = new(); - - // SSE endpoint for GET - private readonly string _endpointPath; - // POST endpoint - private readonly string _messagePath; - - // Keep track of all open SSE connections (StreamWriters). - private readonly ConcurrentBag _sseClients = []; - - public InMemoryTestSseServer(int port = 5000, ILogger? logger = null) - { - Port = port; - - _listener = new HttpListener(); - _listener.Prefixes.Add($"http://localhost:{port}/"); - _cts = new CancellationTokenSource(); - _logger = logger ?? NullLogger.Instance; - - _endpointPath = "/sse"; - _messagePath = "/message"; - } - - public int Port { get; } - - /// - /// This is to be able to use the full URL for the endpoint event. - /// - public bool UseFullUrlForEndpointEvent { get; set; } - - /// - /// Full URL for the SSE endpoint, e.g. "http://localhost:5000/sse" - /// - public string SseEndpoint - => $"http://localhost:{_listener.Prefixes.First().Split(':')[2].TrimEnd('/')}{_endpointPath}"; - - /// - /// Full URL for the message endpoint, e.g. "http://localhost:5000/message" - /// - public string MessageEndpoint - => $"http://localhost:{_listener.Prefixes.First().Split(':')[2].TrimEnd('/')}{_messagePath}"; - - /// - /// Starts the server so it can accept incoming connections and POST requests. - /// - public async Task StartAsync() - { - _listener.Start(); - _serverTask = HandleConnectionsAsync(_cts.Token); - - _logger.LogInformation("Test SSE server started on {Endpoint}", SseEndpoint); - await Task.CompletedTask; - } - - private async Task HandleConnectionsAsync(CancellationToken ct) - { - try - { - while (!ct.IsCancellationRequested) - { - var context = await _listener.GetContextAsync(); - _ = Task.Run(() => HandleRequestAsync(context, ct), ct); - } - } - catch (OperationCanceledException) - { - // Ignore, we are shutting down - } - catch (Exception ex) when (!ct.IsCancellationRequested) - { - _logger.LogError(ex, "Error in SSE server connection handling"); - } - } - - private async Task HandleRequestAsync(HttpListenerContext context, CancellationToken ct) - { - var request = context.Request; - var response = context.Response; - - // Handle SSE endpoint - if (request.HttpMethod.Equals("GET", StringComparison.OrdinalIgnoreCase) - && request.Url?.AbsolutePath.Equals(_endpointPath, StringComparison.OrdinalIgnoreCase) == true) - { - await HandleSseConnectionAsync(context, ct); - } - // Handle POST /message - else if (request.HttpMethod.Equals("POST", StringComparison.OrdinalIgnoreCase) - && request.Url?.AbsolutePath.Equals(_messagePath, StringComparison.OrdinalIgnoreCase) == true) - { - await HandlePostMessageAsync(context, ct); - } - else - { - response.StatusCode = 404; - response.Close(); - } - } - - /// - /// Handle Server-Sent Events (SSE) connection. - /// Send the initial event: endpoint with the full POST URL. - /// Keep the connection open until the server is disposed or the client disconnects. - /// - private async Task HandleSseConnectionAsync(HttpListenerContext context, CancellationToken ct) - { - var response = context.Response; - response.ContentType = "text/event-stream"; - response.Headers["Cache-Control"] = "no-cache"; - // Ensures the response is never chunked away by the framework - response.SendChunked = true; - response.StatusCode = (int)HttpStatusCode.OK; - - using var writer = new StreamWriter(response.OutputStream); - _sseClients.Add(writer); - - // Immediately send the "endpoint" event with the POST URL - await writer.WriteLineAsync("event: endpoint"); - if (UseFullUrlForEndpointEvent) - { - await writer.WriteLineAsync($"data: {MessageEndpoint}"); - } - else - { - await writer.WriteLineAsync($"data: {_messagePath}"); - } - await writer.WriteLineAsync(); // blank line to end an SSE message - await writer.FlushAsync(ct); - - _logger.LogInformation("New SSE client connected."); - _connectionEstablished.TrySetResult(); // Signal connection is ready - - try - { - // Keep the connection open by "pinging" or just waiting - // until the client disconnects or the server is canceled. - while (!ct.IsCancellationRequested && response.OutputStream.CanWrite) - { - _logger.LogDebug("SSE connection alive check"); - // Optionally do a periodic no-op to keep connection alive: - await writer.WriteLineAsync(": keep-alive"); - await writer.FlushAsync(ct); - await Task.Delay(TimeSpan.FromSeconds(5), ct); - } - } - catch (TaskCanceledException) - { - // This is expected on shutdown - _logger.LogInformation("SSE connection cancelled (expected on shutdown)"); - } - catch (IOException) - { - // Client likely disconnected - _logger.LogInformation("SSE client disconnected"); - } - finally - { - // Remove this writer from bag (we're disposing it anyway) - _sseClients.TryTake(out _); - _logger.LogInformation("SSE client disconnected."); - } - } - - // Add method to wait for connection - public Task WaitForConnectionAsync(TimeSpan timeout) => - _connectionEstablished.Task.WaitAsync(timeout); - - /// - /// Handle POST /message endpoint. - /// Echo the content back to the caller and broadcast it over SSE as well. - /// - private async Task HandlePostMessageAsync(HttpListenerContext context, CancellationToken cancellationToken) - { - var request = context.Request; - var response = context.Response; - - try - { - using var reader = new StreamReader(request.InputStream); - string content = await reader.ReadToEndAsync(cancellationToken); - - var jsonRpcNotification = JsonSerializer.Deserialize(content); - if (jsonRpcNotification != null && jsonRpcNotification.Method != RequestMethods.Initialize) - { - // Test server so just ignore notifications - - // Set status code to 202 Accepted - response.StatusCode = 202; - - // Write "accepted" as the response body - using var writer = new StreamWriter(response.OutputStream); - await writer.WriteAsync("accepted"); - await writer.FlushAsync(cancellationToken); - return; - } - - var jsonRpcRequest = JsonSerializer.Deserialize(content); - - if (jsonRpcRequest != null) - { - if (jsonRpcRequest.Method == RequestMethods.Initialize) - { - await HandleInitializationRequest(response, jsonRpcRequest); - } - else - { - // Set status code to 202 Accepted - response.StatusCode = 202; - - // Write "accepted" as the response body - using var writer = new StreamWriter(response.OutputStream); - await writer.WriteAsync("accepted"); - await writer.FlushAsync(cancellationToken); - - // Then process the message and send the response via SSE - if (jsonRpcRequest != null) - { - // Process the request and generate a response - var responseMessage = await ProcessRequestAsync(jsonRpcRequest); - - // Send the response via the SSE connection instead of HTTP response - await SendSseMessageAsync(responseMessage); - } - } - } - } - catch (Exception ex) - { - await SendJsonRpcErrorAsync(response, null, -32603, "Internal error", ex.Message); - response.StatusCode = 500; - } - finally - { - response.Close(); - } - } - - private static Task ProcessRequestAsync(JsonRpcRequest jsonRpcRequest) - { - // This is a test server so we just echo back the request - return Task.FromResult(new JsonRpcResponse - { - Id = jsonRpcRequest.Id, - Result = jsonRpcRequest.Params! - }); - } - - private async Task SendSseMessageAsync(JsonRpcResponse jsonRpcResponse) - { - // This is a test server so we just send to all connected clients - await BroadcastMessageAsync(JsonSerializer.Serialize(jsonRpcResponse)); - } - - private static async Task SendJsonRpcErrorAsync(HttpListenerResponse response, RequestId? id, int code, string message, string? data = null) - { - var errorResponse = new JsonRpcError - { - Id = id ?? new RequestId("error"), - JsonRpc = "2.0", - Error = new JsonRpcErrorDetail - { - Code = code, - Message = message, - Data = data - } - }; - - response.StatusCode = 200; // Always 200 for JSON-RPC - response.ContentType = "application/json"; - await using var writer = new StreamWriter(response.OutputStream); - await writer.WriteAsync(JsonSerializer.Serialize(errorResponse)); - } - - private static async Task HandleInitializationRequest(HttpListenerResponse response, JsonRpcRequest jsonRpcRequest) - { - // We don't need to validate the client's initialization request for the test - // Just send back a valid server initialization response - var jsonRpcResponse = new JsonRpcResponse() - { - Id = jsonRpcRequest.Id, - Result = JsonSerializer.SerializeToNode(new InitializeResult - { - ProtocolVersion = "2024-11-05", - Capabilities = new(), - ServerInfo = new() - { - Name = "ExampleServer", - Version = "1.0.0" - } - }) - }; - - // Echo back to the HTTP response - response.StatusCode = 200; - response.ContentType = "application/json"; - - await using var writer = new StreamWriter(response.OutputStream); - await writer.WriteAsync(JsonSerializer.Serialize(jsonRpcResponse)); - } - - /// - /// Broadcast a message to all currently connected SSE clients. - /// - /// The raw string to send - public async Task BroadcastMessageAsync(string message) - { - foreach (var client in _sseClients.ToArray()) // ToArray to avoid mutation issues - { - try - { - // SSE requires "event: " + "data: " + blank line - await client.WriteLineAsync("event: message"); - await client.WriteLineAsync($"data: {message}"); - await client.WriteLineAsync(); - await client.FlushAsync(); - } - catch (IOException) - { - // Client may have disconnected. We let them get cleaned up on next iteration. - } - catch (ObjectDisposedException) - { - // Stream is disposed, ignore. - } - } - } - - public async ValueTask DisposeAsync() - { - await _cts.CancelAsync(); - - if (_serverTask != null) - { - try - { - await Task.WhenAny(_serverTask, Task.Delay(2000)); - } - catch (TaskCanceledException) - { - // ignore - } - } - - _listener.Close(); - _cts.Dispose(); - - _logger.LogInformation("Test SSE server stopped"); - } - - /// - /// Send a test notification to all connected SSE clients. - /// - /// - /// - public async Task SendTestNotificationAsync(string content) - { - var notification = new JsonRpcNotification - { - JsonRpc = "2.0", - Method = "test/notification", - Params = JsonSerializer.SerializeToNode(new { message = content }), - }; - - var serialized = JsonSerializer.Serialize(notification); - await BroadcastMessageAsync(serialized); - } -} diff --git a/tests/ModelContextProtocol.Tests/Utils/KestrelInMemoryConnection.cs b/tests/ModelContextProtocol.Tests/Utils/KestrelInMemoryConnection.cs new file mode 100644 index 00000000..e22a5578 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Utils/KestrelInMemoryConnection.cs @@ -0,0 +1,113 @@ +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Http.Features; +using System.IO.Pipelines; + +namespace ModelContextProtocol.Tests.Utils; + +public sealed class KestrelInMemoryConnection : ConnectionContext +{ + private readonly Pipe _clientToServerPipe = new(); + private readonly Pipe _serverToClientPipe = new(); + private readonly CancellationTokenSource _connectionClosedCts = new CancellationTokenSource(); + private readonly IFeatureCollection _features = new FeatureCollection(); + + public KestrelInMemoryConnection() + { + ConnectionClosed = _connectionClosedCts.Token; + Transport = new DuplexPipe + { + Input = _clientToServerPipe.Reader, + Output = _serverToClientPipe.Writer, + }; + Application = new DuplexPipe + { + Input = _serverToClientPipe.Reader, + Output = _clientToServerPipe.Writer, + }; + ClientStream = new DuplexStream(Application, _connectionClosedCts); + } + + public IDuplexPipe Application { get; } + public Stream ClientStream { get; } + + public override IDuplexPipe Transport { get; set; } + public override string ConnectionId { get; set; } = Guid.NewGuid().ToString("N"); + + public override IFeatureCollection Features => _features; + + public override IDictionary Items { get; set; } = new Dictionary(); + + public override ValueTask DisposeAsync() + { + // This is called by Kestrel. The client should dispose the DuplexStream which + // completes the other half of these pipes. + _serverToClientPipe.Writer.Complete(); + _serverToClientPipe.Reader.Complete(); + // Disposing the CTS without waiting for the client would be problematic + // except we always dispose the HttpClient before Kestrel in our tests. + _connectionClosedCts.Dispose(); + return base.DisposeAsync(); + } + + private class DuplexPipe : IDuplexPipe + { + public required PipeReader Input { get; init; } + public required PipeWriter Output { get; init; } + } + + private class DuplexStream(IDuplexPipe duplexPipe, CancellationTokenSource connectionClosedCts) : Stream + { + private readonly Stream _readStream = duplexPipe.Input.AsStream(); + private readonly Stream _writeStream = duplexPipe.Output.AsStream(); + + public override bool CanRead => true; + public override bool CanWrite => true; + public override bool CanSeek => false; + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + => _readStream.ReadAsync(buffer, offset, count, cancellationToken); + public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + => _readStream.ReadAsync(buffer, cancellationToken); + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + => _writeStream.WriteAsync(buffer, offset, count, cancellationToken); + public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + => _writeStream.WriteAsync(buffer, cancellationToken); + + public override Task FlushAsync(CancellationToken cancellationToken) + => _writeStream.FlushAsync(cancellationToken); + + public override async ValueTask DisposeAsync() + { + // Signal to the server the the client has closed the connection, and dispose the client-half of the Pipes. + await connectionClosedCts.CancelAsync(); + duplexPipe.Input.Complete(); + duplexPipe.Output.Complete(); + await base.DisposeAsync(); + } + + // Unsupported stuff + public override long Length => throw new NotSupportedException(); + + public override long Position + { + get => throw new NotSupportedException(); + set => throw new NotSupportedException(); + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotSupportedException(); + } + + public override void SetLength(long value) + { + throw new NotSupportedException(); + } + + // Don't bother supporting sync or APM methods. SocketsHttpHandler shouldn't use them. + public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException(); + public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException(); + public override void Flush() => throw new NotSupportedException(); + } +} diff --git a/tests/ModelContextProtocol.Tests/Utils/KestrelInMemoryTest.cs b/tests/ModelContextProtocol.Tests/Utils/KestrelInMemoryTest.cs new file mode 100644 index 00000000..eb094576 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Utils/KestrelInMemoryTest.cs @@ -0,0 +1,34 @@ +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Connections; +using Microsoft.Extensions.DependencyInjection; + +namespace ModelContextProtocol.Tests.Utils; + +public class KestrelInMemoryTest : LoggedTest +{ + private readonly KestrelInMemoryTransport _inMemoryTransport = new(); + + public KestrelInMemoryTest(ITestOutputHelper testOutputHelper) + : base(testOutputHelper) + { + Builder = WebApplication.CreateSlimBuilder(); + Builder.Services.AddSingleton(_inMemoryTransport); + Builder.Services.AddSingleton(LoggerProvider); + } + + public WebApplicationBuilder Builder { get; } + + public HttpClient CreateHttpClient() + { + var socketsHttpHandler = new SocketsHttpHandler() + { + ConnectCallback = (context, token) => + { + var connection = _inMemoryTransport.CreateConnection(); + return new(connection.ClientStream); + }, + }; + + return new HttpClient(socketsHttpHandler); + } +} diff --git a/tests/ModelContextProtocol.Tests/Utils/KestrelInMemoryTransport.cs b/tests/ModelContextProtocol.Tests/Utils/KestrelInMemoryTransport.cs new file mode 100644 index 00000000..586b1650 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Utils/KestrelInMemoryTransport.cs @@ -0,0 +1,50 @@ +using Microsoft.AspNetCore.Connections; +using System.Net; +using System.Threading.Channels; + +namespace ModelContextProtocol.Tests.Utils; + +public sealed class KestrelInMemoryTransport : IConnectionListenerFactory, IConnectionListener +{ + private readonly Channel _acceptQueue = Channel.CreateUnbounded(); + private EndPoint? _endPoint; + + public EndPoint EndPoint => _endPoint ?? throw new InvalidOperationException("EndPoint is not set. Call BindAsync first."); + + public KestrelInMemoryConnection CreateConnection() + { + var connection = new KestrelInMemoryConnection(); + _acceptQueue.Writer.TryWrite(connection); + return connection; + } + + public async ValueTask AcceptAsync(CancellationToken cancellationToken = default) + { + if (await _acceptQueue.Reader.WaitToReadAsync(cancellationToken)) + { + while (_acceptQueue.Reader.TryRead(out var item)) + { + return item; + } + } + + return null; + } + + public ValueTask BindAsync(EndPoint endpoint, CancellationToken cancellationToken = default) + { + _endPoint = endpoint; + return new ValueTask(this); + } + + public ValueTask DisposeAsync() + { + return UnbindAsync(default); + } + + public ValueTask UnbindAsync(CancellationToken cancellationToken = default) + { + _acceptQueue.Writer.TryComplete(); + return default; + } +} From 62587fd985f5d4d2ed2640603f34a3fbedb8363b Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Sun, 6 Apr 2025 15:52:42 -0700 Subject: [PATCH 04/11] Use ApplicationStopping token for SSE responses --- .../McpEndpointRouteBuilderExtensions.cs | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs index f6a2ffc5..09f28358 100644 --- a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs +++ b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs @@ -4,6 +4,7 @@ using Microsoft.AspNetCore.Routing.Patterns; using Microsoft.AspNetCore.WebUtilities; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; using ModelContextProtocol.Protocol.Messages; @@ -54,14 +55,18 @@ public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpo var loggerFactory = endpoints.ServiceProvider.GetRequiredService(); var optionsSnapshot = endpoints.ServiceProvider.GetRequiredService>(); var optionsFactory = endpoints.ServiceProvider.GetRequiredService>(); + var hostApplicationLifetime = endpoints.ServiceProvider.GetRequiredService(); var routeGroup = endpoints.MapGroup(pattern); routeGroup.MapGet("/sse", async context => { - var response = context.Response; - var requestAborted = context.RequestAborted; + // If the server is shutting down, we need to cancel all SSE connections immediately without waiting for HostOptions.ShutdownTimeout + // which defaults to 30 seconds. + using var sseCts = CancellationTokenSource.CreateLinkedTokenSource(context.RequestAborted, hostApplicationLifetime.ApplicationStopping); + var cancellationToken = sseCts.Token; + var response = context.Response; response.Headers.ContentType = "text/event-stream"; response.Headers.CacheControl = "no-cache,no-store"; @@ -80,12 +85,12 @@ public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpo if (configureOptionsAsync is not null) { options = optionsFactory.Create(Options.DefaultName); - await configureOptionsAsync.Invoke(context, options, requestAborted); + await configureOptionsAsync.Invoke(context, options, cancellationToken); } try { - var transportTask = transport.RunAsync(cancellationToken: requestAborted); + var transportTask = transport.RunAsync(cancellationToken); await using var mcpServer = McpServerFactory.Create(transport, options, loggerFactory, endpoints.ServiceProvider); context.Features.Set(mcpServer); @@ -93,7 +98,7 @@ public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpo try { runSessionAsync ??= RunSession; - await runSessionAsync(context, mcpServer, requestAborted); + await runSessionAsync(context, mcpServer, cancellationToken); } finally { @@ -101,7 +106,7 @@ public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpo await transportTask; } } - catch (OperationCanceledException) when (requestAborted.IsCancellationRequested) + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) { // RequestAborted always triggers when the client disconnects before a complete response body is written, // but this is how SSE connections are typically closed. From 339db549c759e699d216ecfcaf4bec44d5bb6058 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Sun, 6 Apr 2025 16:56:31 -0700 Subject: [PATCH 05/11] Avoid listening with both socket and in-memory transport --- .../ModelContextProtocol.TestSseServer/Program.cs | 15 +++++++++++++-- .../Utils/KestrelInMemoryTest.cs | 4 ++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/tests/ModelContextProtocol.TestSseServer/Program.cs b/tests/ModelContextProtocol.TestSseServer/Program.cs index 4f704672..bc339899 100644 --- a/tests/ModelContextProtocol.TestSseServer/Program.cs +++ b/tests/ModelContextProtocol.TestSseServer/Program.cs @@ -377,11 +377,14 @@ public static async Task MainAsync(string[] args, ILoggerProvider? loggerProvide { Console.WriteLine("Starting server..."); - var builder = WebApplication.CreateSlimBuilder(args); + var builder = WebApplication.CreateEmptyBuilder(new() + { + Args = args, + }); if (kestrelTransport is null) { - int port = args.Length > 0 && uint.TryParse(args[0], out var parsedPort) ? (int)parsedPort : 3001; + int port = args.Length > 0 && uint.TryParse(args[0], out var parsedPort) ? (int)parsedPort : 3001; builder.WebHost.ConfigureKestrel(options => { options.ListenLocalhost(port); @@ -389,9 +392,15 @@ public static async Task MainAsync(string[] args, ILoggerProvider? loggerProvide } else { + // Add passed-in transport before calling UseKestrelCore() to avoid the SocketsHttpHandler getting added. builder.Services.AddSingleton(kestrelTransport); } + builder.WebHost.UseKestrelCore(); + builder.Services.AddLogging(); + builder.Services.AddRoutingCore(); + + builder.Logging.AddConsole(); ConfigureSerilog(builder.Logging); if (loggerProvider is not null) { @@ -401,6 +410,8 @@ public static async Task MainAsync(string[] args, ILoggerProvider? loggerProvide builder.Services.AddMcpServer(ConfigureOptions); var app = builder.Build(); + app.UseRouting(); + app.UseEndpoints(_ => { }); app.MapMcp(); diff --git a/tests/ModelContextProtocol.Tests/Utils/KestrelInMemoryTest.cs b/tests/ModelContextProtocol.Tests/Utils/KestrelInMemoryTest.cs index eb094576..5e440dcd 100644 --- a/tests/ModelContextProtocol.Tests/Utils/KestrelInMemoryTest.cs +++ b/tests/ModelContextProtocol.Tests/Utils/KestrelInMemoryTest.cs @@ -1,6 +1,7 @@ using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Connections; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; namespace ModelContextProtocol.Tests.Utils; @@ -11,7 +12,10 @@ public class KestrelInMemoryTest : LoggedTest public KestrelInMemoryTest(ITestOutputHelper testOutputHelper) : base(testOutputHelper) { + // Use SlimBuilder instead of EmptyBuilder to avoid having to call UseRouting() and UseEndpoints(_ => { }) + // or a helper that does the same every test. But clear out the existing socket transport to avoid potential port conflicts. Builder = WebApplication.CreateSlimBuilder(); + Builder.Services.RemoveAll(); Builder.Services.AddSingleton(_inMemoryTransport); Builder.Services.AddSingleton(LoggerProvider); } From 575c2bc83b38f70a033c56296a0b8b68fa8aba68 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Sun, 6 Apr 2025 18:31:09 -0700 Subject: [PATCH 06/11] Use RegisterNotificationHandler in MapMcp tests --- .../SseIntegrationTests.cs | 54 +++++-------------- 1 file changed, 12 insertions(+), 42 deletions(-) diff --git a/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs b/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs index fd5eddd6..044a8000 100644 --- a/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs @@ -72,56 +72,26 @@ public async Task ConnectAndReceiveNotification_InMemoryServer() var receivedNotification = new TaskCompletionSource(); await using var app = Builder.Build(); - app.MapMcp(configureOptionsAsync: (httpContext, mcpServerOptions, CancellationToken) => + app.MapMcp(runSessionAsync: (httpContext, mcpServer, cancellationToken) => { - mcpServerOptions.Capabilities = new() + mcpServer.RegisterNotificationHandler("test/notification", async (notification, cancellationToken) => { - NotificationHandlers = - [ - new("test/notification", async notification => - { - Assert.Equal("Hello from client!", notification.Params?["message"]?.GetValue()); - - var server = httpContext.Features.GetRequiredFeature(); - - // REVIEW: Where is the CancellationToken for notification handlers? - // The httpContext.RequestAborted trick will not always work once we fully support the HTTP streaming spec. - await server.SendNotificationAsync("test/notification", new { message = "Hello from server!" }, cancellationToken: httpContext.RequestAborted); - }), - ], - }; - - return Task.CompletedTask; + Assert.Equal("Hello from client!", notification.Params?["message"]?.GetValue()); + var server = httpContext.Features.GetRequiredFeature(); + await server.SendNotificationAsync("test/notification", new { message = "Hello from server!" }, cancellationToken: cancellationToken); + }); + return mcpServer.RunAsync(cancellationToken); }); await app.StartAsync(TestContext.Current.CancellationToken); using var httpClient = CreateHttpClient(); - await using var mcpClient = await ConnectMcpClient(httpClient, new() - { - Capabilities = new() - { - NotificationHandlers = [new("test/notification", args => - { - var msg = args.Params?["message"]?.GetValue(); - receivedNotification.SetResult(msg); - - return Task.CompletedTask; - })], - }, - }); + await using var mcpClient = await ConnectMcpClient(httpClient); - await using var mcpClient2 = await ConnectMcpClient(httpClient, new() + mcpClient.RegisterNotificationHandler("test/notification", (args, ca) => { - Capabilities = new() - { - NotificationHandlers = [new("test/notification", args => - { - var msg = args.Params?["message"]?.GetValue(); - receivedNotification.SetResult(msg); - - return Task.CompletedTask; - })], - }, + var msg = args.Params?["message"]?.GetValue(); + receivedNotification.SetResult(msg); + return Task.CompletedTask; }); // Send a test message through POST endpoint From b19ab58deaa879b82e6c151622c3ff0f29229bb7 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Sun, 6 Apr 2025 18:45:30 -0700 Subject: [PATCH 07/11] Work around SocketsHttpHandler bug where it doesn't call DispseAsync on the stream returned by the ConnectCallback - Move workaround to KestrelInMemoryConnection instead of SseClientSessionTransport which is product code --- .../Protocol/Transport/SseClientSessionTransport.cs | 9 +-------- .../Utils/KestrelInMemoryConnection.cs | 5 ++--- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs index 078ff452..c7542da8 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs @@ -204,15 +204,8 @@ private async Task ReceiveMessagesAsync(CancellationToken cancellationToken) using var stream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); - // A lot of methods in HttpClient, including HttpConnection.FillAsync, do not currently respect any - // cancellation tokens so we have to cancel awaiting ourselves. - // I tried disposing the HttpRequestMessage and HttpResponseMessage instead, but neither of these canceled - // the ongoing body read, and we don't want to dispose an HttpClient we may not own. - var sseEventEnumerator = SseParser.Create(stream).EnumerateAsync(cancellationToken).GetAsyncEnumerator(); - while (await sseEventEnumerator.MoveNextAsync().AsTask().WaitAsync(cancellationToken).ConfigureAwait(false)) + await foreach (SseItem sseEvent in SseParser.Create(stream).EnumerateAsync(cancellationToken).ConfigureAwait(false)) { - SseItem sseEvent = sseEventEnumerator.Current; - switch (sseEvent.EventType) { case "endpoint": diff --git a/tests/ModelContextProtocol.Tests/Utils/KestrelInMemoryConnection.cs b/tests/ModelContextProtocol.Tests/Utils/KestrelInMemoryConnection.cs index e22a5578..d641d876 100644 --- a/tests/ModelContextProtocol.Tests/Utils/KestrelInMemoryConnection.cs +++ b/tests/ModelContextProtocol.Tests/Utils/KestrelInMemoryConnection.cs @@ -77,13 +77,12 @@ public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationTo public override Task FlushAsync(CancellationToken cancellationToken) => _writeStream.FlushAsync(cancellationToken); - public override async ValueTask DisposeAsync() + protected override void Dispose(bool disposing) { // Signal to the server the the client has closed the connection, and dispose the client-half of the Pipes. - await connectionClosedCts.CancelAsync(); + connectionClosedCts.Cancel(); duplexPipe.Input.Complete(); duplexPipe.Output.Complete(); - await base.DisposeAsync(); } // Unsupported stuff From e5f152d95bd2e0f74161dad7933762ed3b27c3c9 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Sun, 6 Apr 2025 18:57:30 -0700 Subject: [PATCH 08/11] Remove bogus assert --- .../Protocol/NotificationHandlerTests.cs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs b/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs index ed1f4b20..27ad6037 100644 --- a/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs +++ b/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs @@ -225,7 +225,7 @@ public async Task DisposeAsyncCompletesImmediatelyWhenInvokedFromHandler(int num var releaseHandler = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); IAsyncDisposable? registration = null; - registration = client.RegisterNotificationHandler(NotificationName, async (notification, cancellationToken) => + await using var _ = registration = client.RegisterNotificationHandler(NotificationName, async (notification, cancellationToken) => { for (int i = 0; i < numberOfDisposals; i++) { @@ -240,9 +240,5 @@ public async Task DisposeAsyncCompletesImmediatelyWhenInvokedFromHandler(int num await _server.SendNotificationAsync(NotificationName, TestContext.Current.CancellationToken); await handlerRunning.Task; - - ValueTask disposal = registration.DisposeAsync(); - Assert.True(disposal.IsCompletedSuccessfully); - await disposal; } } From a9a4b8e6e5b5aeb1e62424ffb0a5e261cdae1106 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Sun, 6 Apr 2025 18:59:02 -0700 Subject: [PATCH 09/11] Move more lines into try --- .../McpEndpointRouteBuilderExtensions.cs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs index 09f28358..31149b5f 100644 --- a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs +++ b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs @@ -91,12 +91,12 @@ public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpo try { var transportTask = transport.RunAsync(cancellationToken); - await using var mcpServer = McpServerFactory.Create(transport, options, loggerFactory, endpoints.ServiceProvider); - - context.Features.Set(mcpServer); try { + await using var mcpServer = McpServerFactory.Create(transport, options, loggerFactory, endpoints.ServiceProvider); + context.Features.Set(mcpServer); + runSessionAsync ??= RunSession; await runSessionAsync(context, mcpServer, cancellationToken); } From 467a0862bff6c91d8ae67db127c0b7f9740dab76 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Sun, 6 Apr 2025 19:18:54 -0700 Subject: [PATCH 10/11] Fix race in ConnectAndReceiveNotification_InMemoryServer test - Go back to not processing messages until IMcpServer.RunAsync is called - style: consistently use mcpServer from parameter rather than features Failed test run: https://github.com/modelcontextprotocol/csharp-sdk/actions/runs/14299175970/job/40070612363?pr=225 --- tests/ModelContextProtocol.Tests/SseIntegrationTests.cs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs b/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs index 044a8000..41c3d343 100644 --- a/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs @@ -77,8 +77,7 @@ public async Task ConnectAndReceiveNotification_InMemoryServer() mcpServer.RegisterNotificationHandler("test/notification", async (notification, cancellationToken) => { Assert.Equal("Hello from client!", notification.Params?["message"]?.GetValue()); - var server = httpContext.Features.GetRequiredFeature(); - await server.SendNotificationAsync("test/notification", new { message = "Hello from server!" }, cancellationToken: cancellationToken); + await mcpServer.SendNotificationAsync("test/notification", new { message = "Hello from server!" }, cancellationToken: cancellationToken); }); return mcpServer.RunAsync(cancellationToken); }); From 407201316f7ce566dfcb317d4fedde5231af3ef6 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Sun, 6 Apr 2025 19:20:01 -0700 Subject: [PATCH 11/11] Add the rest of the changes that were supposed to be in the last commit - Get most of the changes that were supposed to be in the last commit - Go back to not processing messages until IMcpServer.RunAsync is called Failed test run: https://github.com/modelcontextprotocol/csharp-sdk/actions/runs/14299175970/job/40070612363?pr=225 --- src/ModelContextProtocol/Client/McpClient.cs | 5 ++++- src/ModelContextProtocol/Server/McpServer.cs | 12 +++++++----- src/ModelContextProtocol/Shared/McpEndpoint.cs | 15 ++++++++++----- 3 files changed, 21 insertions(+), 11 deletions(-) diff --git a/src/ModelContextProtocol/Client/McpClient.cs b/src/ModelContextProtocol/Client/McpClient.cs index 9ab22b54..cc43cfa9 100644 --- a/src/ModelContextProtocol/Client/McpClient.cs +++ b/src/ModelContextProtocol/Client/McpClient.cs @@ -106,7 +106,10 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default) { // Connect transport _sessionTransport = await _clientTransport.ConnectAsync(cancellationToken).ConfigureAwait(false); - StartSession(_sessionTransport); + InitializeSession(_sessionTransport); + // We don't want the ConnectAsync token to cancel the session after we've successfully connected. + // The base class handles cleaning up the session in DisposeAsync without our help. + StartSession(_sessionTransport, fullSessionCancellationToken: CancellationToken.None); // Perform initialization sequence using var initializationCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); diff --git a/src/ModelContextProtocol/Server/McpServer.cs b/src/ModelContextProtocol/Server/McpServer.cs index 22e1584c..47df1c9a 100644 --- a/src/ModelContextProtocol/Server/McpServer.cs +++ b/src/ModelContextProtocol/Server/McpServer.cs @@ -18,6 +18,8 @@ internal sealed class McpServer : McpEndpoint, IMcpServer Version = DefaultAssemblyName.Version?.ToString() ?? "1.0.0", }; + private readonly ITransport _sessionTransport; + private readonly EventHandler? _toolsChangedDelegate; private readonly EventHandler? _promptsChangedDelegate; @@ -41,6 +43,7 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? options ??= new(); + _sessionTransport = transport; ServerOptions = options; Services = serviceProvider; _endpointName = $"Server ({options.ServerInfo?.Name ?? DefaultImplementation.Name} {options.ServerInfo?.Version ?? DefaultImplementation.Version})"; @@ -81,8 +84,8 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? prompts.Changed += _promptsChangedDelegate; } - // And start the session. - StartSession(transport); + // And initialize the session. + InitializeSession(transport); } public ServerCapabilities? ServerCapabilities { get; set; } @@ -112,9 +115,8 @@ public async Task RunAsync(CancellationToken cancellationToken = default) try { - using var _ = cancellationToken.Register(static s => ((McpServer)s!).CancelSession(), this); - // The McpServer ctor always calls StartSession, so MessageProcessingTask is always set. - await MessageProcessingTask!.ConfigureAwait(false); + StartSession(_sessionTransport, fullSessionCancellationToken: cancellationToken); + await MessageProcessingTask.ConfigureAwait(false); } finally { diff --git a/src/ModelContextProtocol/Shared/McpEndpoint.cs b/src/ModelContextProtocol/Shared/McpEndpoint.cs index f24e9134..cc227778 100644 --- a/src/ModelContextProtocol/Shared/McpEndpoint.cs +++ b/src/ModelContextProtocol/Shared/McpEndpoint.cs @@ -5,6 +5,7 @@ using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Server; using ModelContextProtocol.Utils; +using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Reflection; @@ -62,12 +63,16 @@ public IAsyncDisposable RegisterNotificationHandler(string method, Func protected Task? MessageProcessingTask { get; private set; } - [MemberNotNull(nameof(MessageProcessingTask))] - protected void StartSession(ITransport sessionTransport) + protected void InitializeSession(ITransport sessionTransport) { - _sessionCts = new CancellationTokenSource(); _session = new McpSession(this is IMcpServer, sessionTransport, EndpointName, RequestHandlers, NotificationHandlers, _logger); - MessageProcessingTask = _session.ProcessMessagesAsync(_sessionCts.Token); + } + + [MemberNotNull(nameof(MessageProcessingTask))] + protected void StartSession(ITransport sessionTransport, CancellationToken fullSessionCancellationToken) + { + _sessionCts = CancellationTokenSource.CreateLinkedTokenSource(fullSessionCancellationToken); + MessageProcessingTask = GetSessionOrThrow().ProcessMessagesAsync(_sessionCts.Token); } protected void CancelSession() => _sessionCts?.Cancel(); @@ -122,5 +127,5 @@ public virtual async ValueTask DisposeUnsynchronizedAsync() } protected McpSession GetSessionOrThrow() - => _session ?? throw new InvalidOperationException($"This should be unreachable from public API! Call {nameof(StartSession)} before sending messages."); + => _session ?? throw new InvalidOperationException($"This should be unreachable from public API! Call {nameof(InitializeSession)} before sending messages."); }