diff --git a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs
index 891408d0..31149b5f 100644
--- a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs
+++ b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs
@@ -1,8 +1,10 @@
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Routing;
+using Microsoft.AspNetCore.Routing.Patterns;
using Microsoft.AspNetCore.WebUtilities;
using Microsoft.Extensions.DependencyInjection;
+using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using ModelContextProtocol.Protocol.Messages;
@@ -10,6 +12,7 @@
using ModelContextProtocol.Server;
using ModelContextProtocol.Utils.Json;
using System.Collections.Concurrent;
+using System.Diagnostics.CodeAnalysis;
using System.Security.Cryptography;
namespace Microsoft.AspNetCore.Builder;
@@ -23,25 +26,54 @@ public static class McpEndpointRouteBuilderExtensions
/// Sets up endpoints for handling MCP HTTP Streaming transport.
///
/// The web application to attach MCP HTTP endpoints.
- /// Provides an optional asynchronous callback for handling new MCP sessions.
+ /// The route pattern prefix to map to.
+ /// Configure per-session options.
+ /// Provides an optional asynchronous callback for handling new MCP sessions.
/// Returns a builder for configuring additional endpoint conventions like authorization policies.
- public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpoints, Func? runSession = null)
+ public static IEndpointConventionBuilder MapMcp(
+ this IEndpointRouteBuilder endpoints,
+ [StringSyntax("Route")] string pattern = "",
+ Func? configureOptionsAsync = null,
+ Func? runSessionAsync = null)
+ => endpoints.MapMcp(RoutePatternFactory.Parse(pattern), configureOptionsAsync, runSessionAsync);
+
+ ///
+ /// Sets up endpoints for handling MCP HTTP Streaming transport.
+ ///
+ /// The web application to attach MCP HTTP endpoints.
+ /// The route pattern prefix to map to.
+ /// Configure per-session options.
+ /// Provides an optional asynchronous callback for handling new MCP sessions.
+ /// Returns a builder for configuring additional endpoint conventions like authorization policies.
+ public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpoints,
+ RoutePattern pattern,
+ Func? configureOptionsAsync = null,
+ Func? runSessionAsync = null)
{
ConcurrentDictionary _sessions = new(StringComparer.Ordinal);
var loggerFactory = endpoints.ServiceProvider.GetRequiredService();
- var mcpServerOptions = endpoints.ServiceProvider.GetRequiredService>();
+ var optionsSnapshot = endpoints.ServiceProvider.GetRequiredService>();
+ var optionsFactory = endpoints.ServiceProvider.GetRequiredService>();
+ var hostApplicationLifetime = endpoints.ServiceProvider.GetRequiredService();
- var routeGroup = endpoints.MapGroup("");
+ var routeGroup = endpoints.MapGroup(pattern);
routeGroup.MapGet("/sse", async context =>
{
- var response = context.Response;
- var requestAborted = context.RequestAborted;
+ // If the server is shutting down, we need to cancel all SSE connections immediately without waiting for HostOptions.ShutdownTimeout
+ // which defaults to 30 seconds.
+ using var sseCts = CancellationTokenSource.CreateLinkedTokenSource(context.RequestAborted, hostApplicationLifetime.ApplicationStopping);
+ var cancellationToken = sseCts.Token;
+ var response = context.Response;
response.Headers.ContentType = "text/event-stream";
response.Headers.CacheControl = "no-cache,no-store";
+ // Make sure we disable all response buffering for SSE
+ context.Response.Headers.ContentEncoding = "identity";
+ context.Features.GetRequiredFeature().DisableBuffering();
+
var sessionId = MakeNewSessionId();
await using var transport = new SseResponseStreamTransport(response.Body, $"/message?sessionId={sessionId}");
if (!_sessions.TryAdd(sessionId, transport))
@@ -49,19 +81,24 @@ public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpo
throw new Exception($"Unreachable given good entropy! Session with ID '{sessionId}' has already been created.");
}
- try
+ var options = optionsSnapshot.Value;
+ if (configureOptionsAsync is not null)
{
- // Make sure we disable all response buffering for SSE
- context.Response.Headers.ContentEncoding = "identity";
- context.Features.GetRequiredFeature().DisableBuffering();
+ options = optionsFactory.Create(Options.DefaultName);
+ await configureOptionsAsync.Invoke(context, options, cancellationToken);
+ }
- var transportTask = transport.RunAsync(cancellationToken: requestAborted);
- await using var server = McpServerFactory.Create(transport, mcpServerOptions.Value, loggerFactory, endpoints.ServiceProvider);
+ try
+ {
+ var transportTask = transport.RunAsync(cancellationToken);
try
{
- runSession ??= RunSession;
- await runSession(context, server, requestAborted);
+ await using var mcpServer = McpServerFactory.Create(transport, options, loggerFactory, endpoints.ServiceProvider);
+ context.Features.Set(mcpServer);
+
+ runSessionAsync ??= RunSession;
+ await runSessionAsync(context, mcpServer, cancellationToken);
}
finally
{
@@ -69,7 +106,7 @@ public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpo
await transportTask;
}
}
- catch (OperationCanceledException) when (requestAborted.IsCancellationRequested)
+ catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
{
// RequestAborted always triggers when the client disconnects before a complete response body is written,
// but this is how SSE connections are typically closed.
diff --git a/src/ModelContextProtocol/Client/McpClient.cs b/src/ModelContextProtocol/Client/McpClient.cs
index 9ab22b54..cc43cfa9 100644
--- a/src/ModelContextProtocol/Client/McpClient.cs
+++ b/src/ModelContextProtocol/Client/McpClient.cs
@@ -106,7 +106,10 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default)
{
// Connect transport
_sessionTransport = await _clientTransport.ConnectAsync(cancellationToken).ConfigureAwait(false);
- StartSession(_sessionTransport);
+ InitializeSession(_sessionTransport);
+ // We don't want the ConnectAsync token to cancel the session after we've successfully connected.
+ // The base class handles cleaning up the session in DisposeAsync without our help.
+ StartSession(_sessionTransport, fullSessionCancellationToken: CancellationToken.None);
// Perform initialization sequence
using var initializationCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs
index 854b0402..c7542da8 100644
--- a/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs
+++ b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs
@@ -151,16 +151,14 @@ private async Task CloseAsync()
{
try
{
- if (!_connectionCts.IsCancellationRequested)
- {
- await _connectionCts.CancelAsync().ConfigureAwait(false);
- _connectionCts.Dispose();
- }
+ await _connectionCts.CancelAsync().ConfigureAwait(false);
if (_receiveTask != null)
{
await _receiveTask.ConfigureAwait(false);
}
+
+ _connectionCts.Dispose();
}
finally
{
diff --git a/src/ModelContextProtocol/Server/McpServer.cs b/src/ModelContextProtocol/Server/McpServer.cs
index 22e1584c..47df1c9a 100644
--- a/src/ModelContextProtocol/Server/McpServer.cs
+++ b/src/ModelContextProtocol/Server/McpServer.cs
@@ -18,6 +18,8 @@ internal sealed class McpServer : McpEndpoint, IMcpServer
Version = DefaultAssemblyName.Version?.ToString() ?? "1.0.0",
};
+ private readonly ITransport _sessionTransport;
+
private readonly EventHandler? _toolsChangedDelegate;
private readonly EventHandler? _promptsChangedDelegate;
@@ -41,6 +43,7 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory?
options ??= new();
+ _sessionTransport = transport;
ServerOptions = options;
Services = serviceProvider;
_endpointName = $"Server ({options.ServerInfo?.Name ?? DefaultImplementation.Name} {options.ServerInfo?.Version ?? DefaultImplementation.Version})";
@@ -81,8 +84,8 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory?
prompts.Changed += _promptsChangedDelegate;
}
- // And start the session.
- StartSession(transport);
+ // And initialize the session.
+ InitializeSession(transport);
}
public ServerCapabilities? ServerCapabilities { get; set; }
@@ -112,9 +115,8 @@ public async Task RunAsync(CancellationToken cancellationToken = default)
try
{
- using var _ = cancellationToken.Register(static s => ((McpServer)s!).CancelSession(), this);
- // The McpServer ctor always calls StartSession, so MessageProcessingTask is always set.
- await MessageProcessingTask!.ConfigureAwait(false);
+ StartSession(_sessionTransport, fullSessionCancellationToken: cancellationToken);
+ await MessageProcessingTask.ConfigureAwait(false);
}
finally
{
diff --git a/src/ModelContextProtocol/Shared/McpEndpoint.cs b/src/ModelContextProtocol/Shared/McpEndpoint.cs
index f24e9134..cc227778 100644
--- a/src/ModelContextProtocol/Shared/McpEndpoint.cs
+++ b/src/ModelContextProtocol/Shared/McpEndpoint.cs
@@ -5,6 +5,7 @@
using ModelContextProtocol.Protocol.Transport;
using ModelContextProtocol.Server;
using ModelContextProtocol.Utils;
+using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Reflection;
@@ -62,12 +63,16 @@ public IAsyncDisposable RegisterNotificationHandler(string method, Func
protected Task? MessageProcessingTask { get; private set; }
- [MemberNotNull(nameof(MessageProcessingTask))]
- protected void StartSession(ITransport sessionTransport)
+ protected void InitializeSession(ITransport sessionTransport)
{
- _sessionCts = new CancellationTokenSource();
_session = new McpSession(this is IMcpServer, sessionTransport, EndpointName, RequestHandlers, NotificationHandlers, _logger);
- MessageProcessingTask = _session.ProcessMessagesAsync(_sessionCts.Token);
+ }
+
+ [MemberNotNull(nameof(MessageProcessingTask))]
+ protected void StartSession(ITransport sessionTransport, CancellationToken fullSessionCancellationToken)
+ {
+ _sessionCts = CancellationTokenSource.CreateLinkedTokenSource(fullSessionCancellationToken);
+ MessageProcessingTask = GetSessionOrThrow().ProcessMessagesAsync(_sessionCts.Token);
}
protected void CancelSession() => _sessionCts?.Cancel();
@@ -122,5 +127,5 @@ public virtual async ValueTask DisposeUnsynchronizedAsync()
}
protected McpSession GetSessionOrThrow()
- => _session ?? throw new InvalidOperationException($"This should be unreachable from public API! Call {nameof(StartSession)} before sending messages.");
+ => _session ?? throw new InvalidOperationException($"This should be unreachable from public API! Call {nameof(InitializeSession)} before sending messages.");
}
diff --git a/tests/ModelContextProtocol.TestSseServer/Program.cs b/tests/ModelContextProtocol.TestSseServer/Program.cs
index 43ae3207..bc339899 100644
--- a/tests/ModelContextProtocol.TestSseServer/Program.cs
+++ b/tests/ModelContextProtocol.TestSseServer/Program.cs
@@ -1,4 +1,5 @@
-using ModelContextProtocol.Protocol.Types;
+using Microsoft.AspNetCore.Connections;
+using ModelContextProtocol.Protocol.Types;
using ModelContextProtocol.Server;
using Serilog;
using System.Text;
@@ -372,18 +373,34 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st
};
}
- public static async Task MainAsync(string[] args, ILoggerProvider? loggerProvider = null, CancellationToken cancellationToken = default)
+ public static async Task MainAsync(string[] args, ILoggerProvider? loggerProvider = null, IConnectionListenerFactory? kestrelTransport = null, CancellationToken cancellationToken = default)
{
Console.WriteLine("Starting server...");
- int port = args.Length > 0 && uint.TryParse(args[0], out var parsedPort) ? (int)parsedPort : 3001;
-
- var builder = WebApplication.CreateSlimBuilder(args);
- builder.WebHost.ConfigureKestrel(options =>
+ var builder = WebApplication.CreateEmptyBuilder(new()
{
- options.ListenLocalhost(port);
+ Args = args,
});
+ if (kestrelTransport is null)
+ {
+ int port = args.Length > 0 && uint.TryParse(args[0], out var parsedPort) ? (int)parsedPort : 3001;
+ builder.WebHost.ConfigureKestrel(options =>
+ {
+ options.ListenLocalhost(port);
+ });
+ }
+ else
+ {
+ // Add passed-in transport before calling UseKestrelCore() to avoid the SocketsHttpHandler getting added.
+ builder.Services.AddSingleton(kestrelTransport);
+ }
+
+ builder.WebHost.UseKestrelCore();
+ builder.Services.AddLogging();
+ builder.Services.AddRoutingCore();
+
+ builder.Logging.AddConsole();
ConfigureSerilog(builder.Logging);
if (loggerProvider is not null)
{
@@ -393,6 +410,8 @@ public static async Task MainAsync(string[] args, ILoggerProvider? loggerProvide
builder.Services.AddMcpServer(ConfigureOptions);
var app = builder.Build();
+ app.UseRouting();
+ app.UseEndpoints(_ => { });
app.MapMcp();
diff --git a/tests/ModelContextProtocol.Tests/DockerEverythingServerTests.cs b/tests/ModelContextProtocol.Tests/DockerEverythingServerTests.cs
new file mode 100644
index 00000000..3dbcfe7c
--- /dev/null
+++ b/tests/ModelContextProtocol.Tests/DockerEverythingServerTests.cs
@@ -0,0 +1,123 @@
+using ModelContextProtocol.Client;
+using ModelContextProtocol.Protocol.Transport;
+using ModelContextProtocol.Protocol.Types;
+using ModelContextProtocol.Tests.Utils;
+
+namespace ModelContextProtocol.Tests;
+
+public class DockerEverythingServerTests(ITestOutputHelper testOutputHelper) : LoggedTest(testOutputHelper)
+{
+ /// Port number to be grabbed by the next test.
+ private static int s_nextPort = 3000;
+
+ // If the tests run concurrently against different versions of the runtime, tests can conflict with
+ // each other in the ports set up for interacting with containers. Ensure that such suites running
+ // against different TFMs use different port numbers.
+ private static readonly int s_portOffset = 1000 * (Environment.Version.Major switch
+ {
+ int v when v >= 8 => Environment.Version.Major - 7,
+ _ => 0,
+ });
+
+ private static int CreatePortNumber() => Interlocked.Increment(ref s_nextPort) + s_portOffset;
+
+ public static bool IsDockerAvailable => EverythingSseServerFixture.IsDockerAvailable;
+
+ [Fact(Skip = "docker is not available", SkipUnless = nameof(IsDockerAvailable))]
+ [Trait("Execution", "Manual")]
+ public async Task ConnectAndReceiveMessage_EverythingServerWithSse()
+ {
+ int port = CreatePortNumber();
+
+ await using var fixture = new EverythingSseServerFixture(port);
+ await fixture.StartAsync();
+
+ var defaultOptions = new McpClientOptions
+ {
+ ClientInfo = new() { Name = "IntegrationTestClient", Version = "1.0.0" }
+ };
+
+ var defaultConfig = new McpServerConfig
+ {
+ Id = "everything",
+ Name = "Everything",
+ TransportType = TransportTypes.Sse,
+ TransportOptions = [],
+ Location = $"http://localhost:{port}/sse"
+ };
+
+ // Create client and run tests
+ await using var client = await McpClientFactory.CreateAsync(
+ defaultConfig,
+ defaultOptions,
+ loggerFactory: LoggerFactory,
+ cancellationToken: TestContext.Current.CancellationToken);
+ var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken);
+
+ // assert
+ Assert.NotEmpty(tools);
+ }
+
+ [Fact(Skip = "docker is not available", SkipUnless = nameof(IsDockerAvailable))]
+ [Trait("Execution", "Manual")]
+ public async Task Sampling_Sse_EverythingServer()
+ {
+ int port = CreatePortNumber();
+
+ await using var fixture = new EverythingSseServerFixture(port);
+ await fixture.StartAsync();
+
+ var defaultConfig = new McpServerConfig
+ {
+ Id = "everything",
+ Name = "Everything",
+ TransportType = TransportTypes.Sse,
+ TransportOptions = [],
+ Location = $"http://localhost:{port}/sse"
+ };
+
+ int samplingHandlerCalls = 0;
+ var defaultOptions = new McpClientOptions
+ {
+ Capabilities = new()
+ {
+ Sampling = new()
+ {
+ SamplingHandler = (_, _, _) =>
+ {
+ samplingHandlerCalls++;
+ return Task.FromResult(new CreateMessageResult
+ {
+ Model = "test-model",
+ Role = "assistant",
+ Content = new Content
+ {
+ Type = "text",
+ Text = "Test response"
+ }
+ });
+ },
+ },
+ },
+ };
+
+ await using var client = await McpClientFactory.CreateAsync(
+ defaultConfig,
+ defaultOptions,
+ loggerFactory: LoggerFactory,
+ cancellationToken: TestContext.Current.CancellationToken);
+
+ // Call the server's sampleLLM tool which should trigger our sampling handler
+ var result = await client.CallToolAsync("sampleLLM", new Dictionary
+ {
+ ["prompt"] = "Test prompt",
+ ["maxTokens"] = 100
+ }, cancellationToken: TestContext.Current.CancellationToken);
+
+ // assert
+ Assert.NotNull(result);
+ var textContent = Assert.Single(result.Content);
+ Assert.Equal("text", textContent.Type);
+ Assert.False(string.IsNullOrEmpty(textContent.Text));
+ }
+}
diff --git a/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs b/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs
index ed1f4b20..27ad6037 100644
--- a/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs
+++ b/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs
@@ -225,7 +225,7 @@ public async Task DisposeAsyncCompletesImmediatelyWhenInvokedFromHandler(int num
var releaseHandler = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
IAsyncDisposable? registration = null;
- registration = client.RegisterNotificationHandler(NotificationName, async (notification, cancellationToken) =>
+ await using var _ = registration = client.RegisterNotificationHandler(NotificationName, async (notification, cancellationToken) =>
{
for (int i = 0; i < numberOfDisposals; i++)
{
@@ -240,9 +240,5 @@ public async Task DisposeAsyncCompletesImmediatelyWhenInvokedFromHandler(int num
await _server.SendNotificationAsync(NotificationName, TestContext.Current.CancellationToken);
await handlerRunning.Task;
-
- ValueTask disposal = registration.DisposeAsync();
- Assert.True(disposal.IsCompletedSuccessfully);
- await disposal;
}
}
diff --git a/tests/ModelContextProtocol.Tests/Server/MapMcpTests.cs b/tests/ModelContextProtocol.Tests/Server/MapMcpTests.cs
new file mode 100644
index 00000000..f8a80d66
--- /dev/null
+++ b/tests/ModelContextProtocol.Tests/Server/MapMcpTests.cs
@@ -0,0 +1,19 @@
+using Microsoft.AspNetCore.Builder;
+using ModelContextProtocol.Tests.Utils;
+
+namespace ModelContextProtocol.Tests.Server;
+
+public class MapMcpTests(ITestOutputHelper testOutputHelper) : KestrelInMemoryTest(testOutputHelper)
+{
+ [Fact]
+ public async Task Allows_Customizing_Route()
+ {
+ await using var app = Builder.Build();
+ app.MapMcp("/mcp");
+ await app.StartAsync(TestContext.Current.CancellationToken);
+
+ using var httpClient = CreateHttpClient();
+ using var response = await httpClient.GetAsync("http://localhost/mcp/sse", HttpCompletionOption.ResponseHeadersRead, TestContext.Current.CancellationToken);
+ response.EnsureSuccessStatusCode();
+ }
+}
diff --git a/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs b/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs
index 94cc786e..41c3d343 100644
--- a/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs
+++ b/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs
@@ -1,238 +1,162 @@
-using Microsoft.Extensions.Logging;
+using Microsoft.AspNetCore.Builder;
+using Microsoft.AspNetCore.Http;
+using Microsoft.AspNetCore.Http.Features;
+using Microsoft.AspNetCore.Routing;
+using Microsoft.Extensions.DependencyInjection;
+using Microsoft.Extensions.Logging;
+using Microsoft.Extensions.Options;
using ModelContextProtocol.Client;
+using ModelContextProtocol.Protocol.Messages;
using ModelContextProtocol.Protocol.Transport;
-using ModelContextProtocol.Protocol.Types;
+using ModelContextProtocol.Server;
using ModelContextProtocol.Tests.Utils;
+using ModelContextProtocol.Utils.Json;
namespace ModelContextProtocol.Tests;
-public class SseIntegrationTests(ITestOutputHelper outputHelper) : LoggedTest(outputHelper)
+public class SseIntegrationTests(ITestOutputHelper outputHelper) : KestrelInMemoryTest(outputHelper)
{
- /// Port number to be grabbed by the next test.
- private static int s_nextPort = 3000;
-
- // If the tests run concurrently against different versions of the runtime, tests can conflict with
- // each other in the ports set up for interacting with containers. Ensure that such suites running
- // against different TFMs use different port numbers.
- private static readonly int s_portOffset = 1000 * (Environment.Version.Major switch
+ private McpServerConfig DefaultServerConfig = new()
{
- int v when v >= 8 => Environment.Version.Major - 7,
- _ => 0,
- });
+ Id = "test_server",
+ Name = "In-memory Test Server",
+ TransportType = TransportTypes.Sse,
+ TransportOptions = [],
+ Location = $"http://localhost/sse"
+ };
+
+ private Task ConnectMcpClient(HttpClient httpClient, McpClientOptions? clientOptions = null)
+ => McpClientFactory.CreateAsync(
+ DefaultServerConfig,
+ clientOptions,
+ (_, _) => new SseClientTransport(new(), DefaultServerConfig, httpClient, LoggerFactory),
+ LoggerFactory,
+ TestContext.Current.CancellationToken);
- private static int CreatePortNumber() => Interlocked.Increment(ref s_nextPort) + s_portOffset;
[Fact]
public async Task ConnectAndReceiveMessage_InMemoryServer()
{
- // Arrange
- await using InMemoryTestSseServer server = new(CreatePortNumber(), LoggerFactory.CreateLogger());
- await server.StartAsync();
+ await using var app = Builder.Build();
+ app.MapMcp();
+ await app.StartAsync(TestContext.Current.CancellationToken);
- var defaultConfig = new McpServerConfig
- {
- Id = "test_server",
- Name = "In-memory Test Server",
- TransportType = TransportTypes.Sse,
- TransportOptions = [],
- Location = $"http://localhost:{server.Port}/sse"
- };
-
- // Act
- await using var client = await McpClientFactory.CreateAsync(
- defaultConfig,
- loggerFactory: LoggerFactory,
- cancellationToken: TestContext.Current.CancellationToken);
-
- // Wait for SSE connection to be established
- await server.WaitForConnectionAsync(TimeSpan.FromSeconds(10));
+ using var httpClient = CreateHttpClient();
+ await using var mcpClient = await ConnectMcpClient(httpClient);
// Send a test message through POST endpoint
- await client.SendNotificationAsync("test/message", new { message = "Hello, SSE!" }, cancellationToken: TestContext.Current.CancellationToken);
+ await mcpClient.SendNotificationAsync("test/message", new { message = "Hello, SSE!" }, cancellationToken: TestContext.Current.CancellationToken);
- // Assert
Assert.True(true);
}
[Fact]
- [Trait("Execution", "Manual")]
- public async Task ConnectAndReceiveMessage_EverythingServerWithSse()
+ public async Task ConnectAndReceiveMessage_InMemoryServer_WithFullEndpointEventUri()
{
- Assert.SkipWhen(!EverythingSseServerFixture.IsDockerAvailable, "docker is not available");
+ await using var app = Builder.Build();
+ MapAbsoluteEndpointUriMcp(app);
+ await app.StartAsync(TestContext.Current.CancellationToken);
- int port = CreatePortNumber();
+ using var httpClient = CreateHttpClient();
+ await using var mcpClient = await ConnectMcpClient(httpClient);
- await using var fixture = new EverythingSseServerFixture(port);
- await fixture.StartAsync();
-
- var defaultOptions = new McpClientOptions
- {
- ClientInfo = new() { Name = "IntegrationTestClient", Version = "1.0.0" }
- };
+ // Send a test message through POST endpoint
+ await mcpClient.SendNotificationAsync("test/message", new { message = "Hello, SSE!" }, cancellationToken: TestContext.Current.CancellationToken);
- var defaultConfig = new McpServerConfig
- {
- Id = "everything",
- Name = "Everything",
- TransportType = TransportTypes.Sse,
- TransportOptions = [],
- Location = $"http://localhost:{port}/sse"
- };
-
- // Create client and run tests
- await using var client = await McpClientFactory.CreateAsync(
- defaultConfig,
- defaultOptions,
- loggerFactory: LoggerFactory,
- cancellationToken: TestContext.Current.CancellationToken);
- var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken);
-
- // assert
- Assert.NotEmpty(tools);
+ Assert.True(true);
}
[Fact]
- [Trait("Execution", "Manual")]
- public async Task Sampling_Sse_EverythingServer()
+ public async Task ConnectAndReceiveNotification_InMemoryServer()
{
- Assert.SkipWhen(!EverythingSseServerFixture.IsDockerAvailable, "docker is not available");
-
- int port = CreatePortNumber();
-
- await using var fixture = new EverythingSseServerFixture(port);
- await fixture.StartAsync();
+ var receivedNotification = new TaskCompletionSource();
- var defaultConfig = new McpServerConfig
- {
- Id = "everything",
- Name = "Everything",
- TransportType = TransportTypes.Sse,
- TransportOptions = [],
- Location = $"http://localhost:{port}/sse"
- };
-
- int samplingHandlerCalls = 0;
- var defaultOptions = new McpClientOptions
+ await using var app = Builder.Build();
+ app.MapMcp(runSessionAsync: (httpContext, mcpServer, cancellationToken) =>
{
- Capabilities = new()
- {
- Sampling = new()
- {
- SamplingHandler = (_, _, _) =>
- {
- samplingHandlerCalls++;
- return Task.FromResult(new CreateMessageResult
- {
- Model = "test-model",
- Role = "assistant",
- Content = new Content
- {
- Type = "text",
- Text = "Test response"
- }
- });
- },
- },
- },
- };
-
- await using var client = await McpClientFactory.CreateAsync(
- defaultConfig,
- defaultOptions,
- loggerFactory: LoggerFactory,
- cancellationToken: TestContext.Current.CancellationToken);
-
- // Call the server's sampleLLM tool which should trigger our sampling handler
- var result = await client.CallToolAsync("sampleLLM", new Dictionary
+ mcpServer.RegisterNotificationHandler("test/notification", async (notification, cancellationToken) =>
{
- ["prompt"] = "Test prompt",
- ["maxTokens"] = 100
- }, cancellationToken: TestContext.Current.CancellationToken);
-
- // assert
- Assert.NotNull(result);
- var textContent = Assert.Single(result.Content);
- Assert.Equal("text", textContent.Type);
- Assert.False(string.IsNullOrEmpty(textContent.Text));
- }
+ Assert.Equal("Hello from client!", notification.Params?["message"]?.GetValue());
+ await mcpServer.SendNotificationAsync("test/notification", new { message = "Hello from server!" }, cancellationToken: cancellationToken);
+ });
+ return mcpServer.RunAsync(cancellationToken);
+ });
+ await app.StartAsync(TestContext.Current.CancellationToken);
- [Fact]
- public async Task ConnectAndReceiveMessage_InMemoryServer_WithFullEndpointEventUri()
- {
- // Arrange
- await using InMemoryTestSseServer server = new(CreatePortNumber(), LoggerFactory.CreateLogger());
- server.UseFullUrlForEndpointEvent = true;
- await server.StartAsync();
+ using var httpClient = CreateHttpClient();
+ await using var mcpClient = await ConnectMcpClient(httpClient);
- var defaultConfig = new McpServerConfig
+ mcpClient.RegisterNotificationHandler("test/notification", (args, ca) =>
{
- Id = "test_server",
- Name = "In-memory Test Server",
- TransportType = TransportTypes.Sse,
- TransportOptions = [],
- Location = $"http://localhost:{server.Port}/sse"
- };
-
- // Act
- await using var client = await McpClientFactory.CreateAsync(
- defaultConfig,
- loggerFactory: LoggerFactory,
- cancellationToken: TestContext.Current.CancellationToken);
-
- // Wait for SSE connection to be established
- await server.WaitForConnectionAsync(TimeSpan.FromSeconds(10));
+ var msg = args.Params?["message"]?.GetValue();
+ receivedNotification.SetResult(msg);
+ return Task.CompletedTask;
+ });
// Send a test message through POST endpoint
- await client.SendNotificationAsync("test/message", new { message = "Hello, SSE!" }, cancellationToken: TestContext.Current.CancellationToken);
+ await mcpClient.SendNotificationAsync("test/notification", new { message = "Hello from client!" }, cancellationToken: TestContext.Current.CancellationToken);
- // Assert
- Assert.True(true);
+ var message = await receivedNotification.Task.WaitAsync(TimeSpan.FromSeconds(10), TestContext.Current.CancellationToken);
+ Assert.Equal("Hello from server!", message);
}
- [Fact]
- public async Task ConnectAndReceiveNotification_InMemoryServer()
+ private static void MapAbsoluteEndpointUriMcp(IEndpointRouteBuilder endpoints)
{
- // Arrange
- await using InMemoryTestSseServer server = new(CreatePortNumber(), LoggerFactory.CreateLogger());
- await server.StartAsync();
+ var loggerFactory = endpoints.ServiceProvider.GetRequiredService();
+ var optionsSnapshot = endpoints.ServiceProvider.GetRequiredService>();
- var defaultConfig = new McpServerConfig
+ var routeGroup = endpoints.MapGroup("");
+ SseResponseStreamTransport? session = null;
+
+ routeGroup.MapGet("/sse", async context =>
{
- Id = "test_server",
- Name = "In-memory Test Server",
- TransportType = TransportTypes.Sse,
- TransportOptions = [],
- Location = $"http://localhost:{server.Port}/sse"
- };
-
- // Act
- var receivedNotification = new TaskCompletionSource();
- await using var client = await McpClientFactory.CreateAsync(
- defaultConfig,
- new()
- {
- Capabilities = new()
- {
- NotificationHandlers = [new("test/notification", (notification, cancellationToken) =>
- {
- var msg = notification.Params?["message"]?.GetValue();
- receivedNotification.SetResult(msg);
+ var response = context.Response;
+ var requestAborted = context.RequestAborted;
- return Task.CompletedTask;
- })],
- },
- },
- loggerFactory: LoggerFactory,
- cancellationToken: TestContext.Current.CancellationToken);
+ response.Headers.ContentType = "text/event-stream";
- // Wait for SSE connection to be established
- await server.WaitForConnectionAsync(TimeSpan.FromSeconds(10));
+ await using var transport = new SseResponseStreamTransport(response.Body, "http://localhost/message");
+ session = transport;
- // Act
- await server.SendTestNotificationAsync("Hello from server!");
+ try
+ {
+ var transportTask = transport.RunAsync(cancellationToken: requestAborted);
+ await using var server = McpServerFactory.Create(transport, optionsSnapshot.Value, loggerFactory, endpoints.ServiceProvider);
- // Assert
- var message = await receivedNotification.Task.WaitAsync(TimeSpan.FromSeconds(10), TestContext.Current.CancellationToken);
- Assert.Equal("Hello from server!", message);
+ try
+ {
+ await server.RunAsync(requestAborted);
+ }
+ finally
+ {
+ await transport.DisposeAsync();
+ await transportTask;
+ }
+ }
+ catch (OperationCanceledException) when (requestAborted.IsCancellationRequested)
+ {
+ // RequestAborted always triggers when the client disconnects before a complete response body is written,
+ // but this is how SSE connections are typically closed.
+ }
+ });
+
+ routeGroup.MapPost("/message", async context =>
+ {
+ if (session is null)
+ {
+ await Results.BadRequest("Session not started.").ExecuteAsync(context);
+ return;
+ }
+ var message = (IJsonRpcMessage?)await context.Request.ReadFromJsonAsync(McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IJsonRpcMessage)), context.RequestAborted);
+ if (message is null)
+ {
+ await Results.BadRequest("No message in request body.").ExecuteAsync(context);
+ return;
+ }
+
+ await session.OnMessageReceivedAsync(message, context.RequestAborted);
+ context.Response.StatusCode = StatusCodes.Status202Accepted;
+ await context.Response.WriteAsync("Accepted");
+ });
}
}
diff --git a/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs b/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs
index 1720bc69..6b7b474e 100644
--- a/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs
+++ b/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs
@@ -1,4 +1,6 @@
-using ModelContextProtocol.Protocol.Transport;
+using Microsoft.Extensions.Logging;
+using ModelContextProtocol.Client;
+using ModelContextProtocol.Protocol.Transport;
using ModelContextProtocol.Test.Utils;
using ModelContextProtocol.Tests.Utils;
using ModelContextProtocol.TestSseServer;
@@ -7,28 +9,52 @@ namespace ModelContextProtocol.Tests;
public class SseServerIntegrationTestFixture : IAsyncDisposable
{
+ private readonly KestrelInMemoryTransport _inMemoryTransport = new();
+
private readonly Task _serverTask;
private readonly CancellationTokenSource _stopCts = new();
+ // XUnit's ITestOutputHelper is created per test, while this fixture is used for
+ // multiple tests, so this dispatches the output to the current test.
private readonly DelegatingTestOutputHelper _delegatingTestOutputHelper = new();
- public McpServerConfig DefaultConfig { get; }
+ private McpServerConfig DefaultServerConfig { get; } = new McpServerConfig
+ {
+ Id = "test_server",
+ Name = "TestServer",
+ TransportType = TransportTypes.Sse,
+ TransportOptions = [],
+ Location = $"http://localhost/sse"
+ };
public SseServerIntegrationTestFixture()
{
- // Ensure that test suites running against different TFMs and possibly concurrently use different port numbers.
- int port = 3001 + Environment.Version.Major;
+ var socketsHttpHandler = new SocketsHttpHandler()
+ {
+ ConnectCallback = (context, token) =>
+ {
+ var connection = _inMemoryTransport.CreateConnection();
+ return new(connection.ClientStream);
+ },
+ };
- DefaultConfig = new McpServerConfig
+ HttpClient = new HttpClient(socketsHttpHandler)
{
- Id = "test_server",
- Name = "TestServer",
- TransportType = TransportTypes.Sse,
- TransportOptions = [],
- Location = $"http://localhost:{port}/sse"
+ BaseAddress = new Uri(DefaultServerConfig.Location),
};
+ _serverTask = Program.MainAsync([], new XunitLoggerProvider(_delegatingTestOutputHelper), _inMemoryTransport, _stopCts.Token);
+ }
- _serverTask = Program.MainAsync([port.ToString()], new XunitLoggerProvider(_delegatingTestOutputHelper), _stopCts.Token);
+ public HttpClient HttpClient { get; }
+
+ public Task ConnectMcpClientAsync(McpClientOptions? options, ILoggerFactory loggerFactory)
+ {
+ return McpClientFactory.CreateAsync(
+ DefaultServerConfig,
+ options,
+ (_, _) => new SseClientTransport(new(), DefaultServerConfig, HttpClient, loggerFactory),
+ loggerFactory,
+ TestContext.Current.CancellationToken);
}
public void Initialize(ITestOutputHelper output)
@@ -44,7 +70,10 @@ public void TestCompleted()
public async ValueTask DisposeAsync()
{
_delegatingTestOutputHelper.CurrentTestOutputHelper = null;
+
+ HttpClient.Dispose();
_stopCts.Cancel();
+
try
{
await _serverTask.ConfigureAwait(false);
@@ -52,6 +81,7 @@ public async ValueTask DisposeAsync()
catch (OperationCanceledException)
{
}
+
_stopCts.Dispose();
}
}
diff --git a/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs b/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs
index eaead4e0..7913bbb2 100644
--- a/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs
+++ b/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs
@@ -25,19 +25,9 @@ public override void Dispose()
private Task GetClientAsync(McpClientOptions? options = null)
{
- return McpClientFactory.CreateAsync(
- _fixture.DefaultConfig,
- options,
- loggerFactory: LoggerFactory,
- cancellationToken: TestContext.Current.CancellationToken);
+ return _fixture.ConnectMcpClientAsync(options, LoggerFactory);
}
- private HttpClient GetHttpClient() =>
- new()
- {
- BaseAddress = new(_fixture.DefaultConfig.Location!),
- };
-
[Fact]
public async Task ConnectAndPing_Sse_TestServer()
{
@@ -283,8 +273,7 @@ public async Task CallTool_Sse_EchoServer_Concurrently()
[Fact]
public async Task EventSourceResponse_Includes_ExpectedHeaders()
{
- using var httpClient = GetHttpClient();
- using var sseResponse = await httpClient.GetAsync("", HttpCompletionOption.ResponseHeadersRead, TestContext.Current.CancellationToken);
+ using var sseResponse = await _fixture.HttpClient.GetAsync("", HttpCompletionOption.ResponseHeadersRead, TestContext.Current.CancellationToken);
sseResponse.EnsureSuccessStatusCode();
@@ -300,8 +289,7 @@ public async Task EventSourceStream_Includes_MessageEventType()
{
// Simulate our own MCP client handshake using a plain HttpClient so we can look for "event: message"
// in the raw SSE response stream which is not exposed by the real MCP client.
- using var httpClient = GetHttpClient();
- await using var sseResponse = await httpClient.GetStreamAsync("", TestContext.Current.CancellationToken);
+ await using var sseResponse = await _fixture.HttpClient.GetStreamAsync("", TestContext.Current.CancellationToken);
using var streamReader = new StreamReader(sseResponse);
var endpointEvent = await streamReader.ReadLineAsync(TestContext.Current.CancellationToken);
@@ -317,7 +305,7 @@ public async Task EventSourceStream_Includes_MessageEventType()
""";
using (var initializeRequestBody = new StringContent(initializeRequest, Encoding.UTF8, "application/json"))
{
- var response = await httpClient.PostAsync(messageEndpoint, initializeRequestBody, TestContext.Current.CancellationToken);
+ var response = await _fixture.HttpClient.PostAsync(messageEndpoint, initializeRequestBody, TestContext.Current.CancellationToken);
Assert.Equal(HttpStatusCode.Accepted, response.StatusCode);
}
@@ -326,7 +314,7 @@ public async Task EventSourceStream_Includes_MessageEventType()
""";
using (var initializedNotificationBody = new StringContent(initializedNotification, Encoding.UTF8, "application/json"))
{
- var response = await httpClient.PostAsync(messageEndpoint, initializedNotificationBody, TestContext.Current.CancellationToken);
+ var response = await _fixture.HttpClient.PostAsync(messageEndpoint, initializedNotificationBody, TestContext.Current.CancellationToken);
Assert.Equal(HttpStatusCode.Accepted, response.StatusCode);
}
diff --git a/tests/ModelContextProtocol.Tests/Transport/SseResponseStreamTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/SseResponseStreamTransportTests.cs
new file mode 100644
index 00000000..220746cd
--- /dev/null
+++ b/tests/ModelContextProtocol.Tests/Transport/SseResponseStreamTransportTests.cs
@@ -0,0 +1,27 @@
+using ModelContextProtocol.Protocol.Transport;
+using ModelContextProtocol.Tests.Utils;
+using System.IO.Pipelines;
+
+namespace ModelContextProtocol.Tests.Transport;
+
+public class SseResponseStreamTransportTests(ITestOutputHelper testOutputHelper) : LoggedTest(testOutputHelper)
+{
+ [Fact]
+ public async Task Can_Customize_MessageEndpoint()
+ {
+ var responsePipe = new Pipe();
+
+ await using var transport = new SseResponseStreamTransport(responsePipe.Writer.AsStream(), "/my-message-endpoint");
+ var transportRunTask = transport.RunAsync(TestContext.Current.CancellationToken);
+
+ using var responseStreamReader = new StreamReader(responsePipe.Reader.AsStream());
+ var firstLine = await responseStreamReader.ReadLineAsync(TestContext.Current.CancellationToken);
+ Assert.Equal("event: endpoint", firstLine);
+
+ var secondLine = await responseStreamReader.ReadLineAsync(TestContext.Current.CancellationToken);
+ Assert.Equal("data: /my-message-endpoint", secondLine);
+
+ responsePipe.Reader.Complete();
+ responsePipe.Writer.Complete();
+ }
+}
diff --git a/tests/ModelContextProtocol.Tests/Utils/InMemoryTestSseServer.cs b/tests/ModelContextProtocol.Tests/Utils/InMemoryTestSseServer.cs
deleted file mode 100644
index 7d7122a8..00000000
--- a/tests/ModelContextProtocol.Tests/Utils/InMemoryTestSseServer.cs
+++ /dev/null
@@ -1,382 +0,0 @@
-using System.Collections.Concurrent;
-using System.Net;
-using System.Text.Json;
-using ModelContextProtocol.Protocol.Messages;
-using Microsoft.Extensions.Logging;
-using Microsoft.Extensions.Logging.Abstractions;
-using ModelContextProtocol.Protocol.Types;
-
-namespace ModelContextProtocol.Tests.Utils;
-
-public sealed class InMemoryTestSseServer : IAsyncDisposable
-{
- private readonly HttpListener _listener;
- private readonly CancellationTokenSource _cts;
- private readonly ILogger _logger;
- private Task? _serverTask;
- private readonly TaskCompletionSource _connectionEstablished = new();
-
- // SSE endpoint for GET
- private readonly string _endpointPath;
- // POST endpoint
- private readonly string _messagePath;
-
- // Keep track of all open SSE connections (StreamWriters).
- private readonly ConcurrentBag _sseClients = [];
-
- public InMemoryTestSseServer(int port = 5000, ILogger? logger = null)
- {
- Port = port;
-
- _listener = new HttpListener();
- _listener.Prefixes.Add($"http://localhost:{port}/");
- _cts = new CancellationTokenSource();
- _logger = logger ?? NullLogger.Instance;
-
- _endpointPath = "/sse";
- _messagePath = "/message";
- }
-
- public int Port { get; }
-
- ///
- /// This is to be able to use the full URL for the endpoint event.
- ///
- public bool UseFullUrlForEndpointEvent { get; set; }
-
- ///
- /// Full URL for the SSE endpoint, e.g. "http://localhost:5000/sse"
- ///
- public string SseEndpoint
- => $"http://localhost:{_listener.Prefixes.First().Split(':')[2].TrimEnd('/')}{_endpointPath}";
-
- ///
- /// Full URL for the message endpoint, e.g. "http://localhost:5000/message"
- ///
- public string MessageEndpoint
- => $"http://localhost:{_listener.Prefixes.First().Split(':')[2].TrimEnd('/')}{_messagePath}";
-
- ///
- /// Starts the server so it can accept incoming connections and POST requests.
- ///
- public async Task StartAsync()
- {
- _listener.Start();
- _serverTask = HandleConnectionsAsync(_cts.Token);
-
- _logger.LogInformation("Test SSE server started on {Endpoint}", SseEndpoint);
- await Task.CompletedTask;
- }
-
- private async Task HandleConnectionsAsync(CancellationToken ct)
- {
- try
- {
- while (!ct.IsCancellationRequested)
- {
- var context = await _listener.GetContextAsync();
- _ = Task.Run(() => HandleRequestAsync(context, ct), ct);
- }
- }
- catch (OperationCanceledException)
- {
- // Ignore, we are shutting down
- }
- catch (Exception ex) when (!ct.IsCancellationRequested)
- {
- _logger.LogError(ex, "Error in SSE server connection handling");
- }
- }
-
- private async Task HandleRequestAsync(HttpListenerContext context, CancellationToken ct)
- {
- var request = context.Request;
- var response = context.Response;
-
- // Handle SSE endpoint
- if (request.HttpMethod.Equals("GET", StringComparison.OrdinalIgnoreCase)
- && request.Url?.AbsolutePath.Equals(_endpointPath, StringComparison.OrdinalIgnoreCase) == true)
- {
- await HandleSseConnectionAsync(context, ct);
- }
- // Handle POST /message
- else if (request.HttpMethod.Equals("POST", StringComparison.OrdinalIgnoreCase)
- && request.Url?.AbsolutePath.Equals(_messagePath, StringComparison.OrdinalIgnoreCase) == true)
- {
- await HandlePostMessageAsync(context, ct);
- }
- else
- {
- response.StatusCode = 404;
- response.Close();
- }
- }
-
- ///
- /// Handle Server-Sent Events (SSE) connection.
- /// Send the initial event: endpoint with the full POST URL.
- /// Keep the connection open until the server is disposed or the client disconnects.
- ///
- private async Task HandleSseConnectionAsync(HttpListenerContext context, CancellationToken ct)
- {
- var response = context.Response;
- response.ContentType = "text/event-stream";
- response.Headers["Cache-Control"] = "no-cache";
- // Ensures the response is never chunked away by the framework
- response.SendChunked = true;
- response.StatusCode = (int)HttpStatusCode.OK;
-
- using var writer = new StreamWriter(response.OutputStream);
- _sseClients.Add(writer);
-
- // Immediately send the "endpoint" event with the POST URL
- await writer.WriteLineAsync("event: endpoint");
- if (UseFullUrlForEndpointEvent)
- {
- await writer.WriteLineAsync($"data: {MessageEndpoint}");
- }
- else
- {
- await writer.WriteLineAsync($"data: {_messagePath}");
- }
- await writer.WriteLineAsync(); // blank line to end an SSE message
- await writer.FlushAsync(ct);
-
- _logger.LogInformation("New SSE client connected.");
- _connectionEstablished.TrySetResult(); // Signal connection is ready
-
- try
- {
- // Keep the connection open by "pinging" or just waiting
- // until the client disconnects or the server is canceled.
- while (!ct.IsCancellationRequested && response.OutputStream.CanWrite)
- {
- _logger.LogDebug("SSE connection alive check");
- // Optionally do a periodic no-op to keep connection alive:
- await writer.WriteLineAsync(": keep-alive");
- await writer.FlushAsync(ct);
- await Task.Delay(TimeSpan.FromSeconds(5), ct);
- }
- }
- catch (TaskCanceledException)
- {
- // This is expected on shutdown
- _logger.LogInformation("SSE connection cancelled (expected on shutdown)");
- }
- catch (IOException)
- {
- // Client likely disconnected
- _logger.LogInformation("SSE client disconnected");
- }
- finally
- {
- // Remove this writer from bag (we're disposing it anyway)
- _sseClients.TryTake(out _);
- _logger.LogInformation("SSE client disconnected.");
- }
- }
-
- // Add method to wait for connection
- public Task WaitForConnectionAsync(TimeSpan timeout) =>
- _connectionEstablished.Task.WaitAsync(timeout);
-
- ///
- /// Handle POST /message endpoint.
- /// Echo the content back to the caller and broadcast it over SSE as well.
- ///
- private async Task HandlePostMessageAsync(HttpListenerContext context, CancellationToken cancellationToken)
- {
- var request = context.Request;
- var response = context.Response;
-
- try
- {
- using var reader = new StreamReader(request.InputStream);
- string content = await reader.ReadToEndAsync(cancellationToken);
-
- var jsonRpcNotification = JsonSerializer.Deserialize(content);
- if (jsonRpcNotification != null && jsonRpcNotification.Method != RequestMethods.Initialize)
- {
- // Test server so just ignore notifications
-
- // Set status code to 202 Accepted
- response.StatusCode = 202;
-
- // Write "accepted" as the response body
- using var writer = new StreamWriter(response.OutputStream);
- await writer.WriteAsync("accepted");
- await writer.FlushAsync(cancellationToken);
- return;
- }
-
- var jsonRpcRequest = JsonSerializer.Deserialize(content);
-
- if (jsonRpcRequest != null)
- {
- if (jsonRpcRequest.Method == RequestMethods.Initialize)
- {
- await HandleInitializationRequest(response, jsonRpcRequest);
- }
- else
- {
- // Set status code to 202 Accepted
- response.StatusCode = 202;
-
- // Write "accepted" as the response body
- using var writer = new StreamWriter(response.OutputStream);
- await writer.WriteAsync("accepted");
- await writer.FlushAsync(cancellationToken);
-
- // Then process the message and send the response via SSE
- if (jsonRpcRequest != null)
- {
- // Process the request and generate a response
- var responseMessage = await ProcessRequestAsync(jsonRpcRequest);
-
- // Send the response via the SSE connection instead of HTTP response
- await SendSseMessageAsync(responseMessage);
- }
- }
- }
- }
- catch (Exception ex)
- {
- await SendJsonRpcErrorAsync(response, null, -32603, "Internal error", ex.Message);
- response.StatusCode = 500;
- }
- finally
- {
- response.Close();
- }
- }
-
- private static Task ProcessRequestAsync(JsonRpcRequest jsonRpcRequest)
- {
- // This is a test server so we just echo back the request
- return Task.FromResult(new JsonRpcResponse
- {
- Id = jsonRpcRequest.Id,
- Result = jsonRpcRequest.Params!
- });
- }
-
- private async Task SendSseMessageAsync(JsonRpcResponse jsonRpcResponse)
- {
- // This is a test server so we just send to all connected clients
- await BroadcastMessageAsync(JsonSerializer.Serialize(jsonRpcResponse));
- }
-
- private static async Task SendJsonRpcErrorAsync(HttpListenerResponse response, RequestId? id, int code, string message, string? data = null)
- {
- var errorResponse = new JsonRpcError
- {
- Id = id ?? new RequestId("error"),
- JsonRpc = "2.0",
- Error = new JsonRpcErrorDetail
- {
- Code = code,
- Message = message,
- Data = data
- }
- };
-
- response.StatusCode = 200; // Always 200 for JSON-RPC
- response.ContentType = "application/json";
- await using var writer = new StreamWriter(response.OutputStream);
- await writer.WriteAsync(JsonSerializer.Serialize(errorResponse));
- }
-
- private static async Task HandleInitializationRequest(HttpListenerResponse response, JsonRpcRequest jsonRpcRequest)
- {
- // We don't need to validate the client's initialization request for the test
- // Just send back a valid server initialization response
- var jsonRpcResponse = new JsonRpcResponse()
- {
- Id = jsonRpcRequest.Id,
- Result = JsonSerializer.SerializeToNode(new InitializeResult
- {
- ProtocolVersion = "2024-11-05",
- Capabilities = new(),
- ServerInfo = new()
- {
- Name = "ExampleServer",
- Version = "1.0.0"
- }
- })
- };
-
- // Echo back to the HTTP response
- response.StatusCode = 200;
- response.ContentType = "application/json";
-
- await using var writer = new StreamWriter(response.OutputStream);
- await writer.WriteAsync(JsonSerializer.Serialize(jsonRpcResponse));
- }
-
- ///
- /// Broadcast a message to all currently connected SSE clients.
- ///
- /// The raw string to send
- public async Task BroadcastMessageAsync(string message)
- {
- foreach (var client in _sseClients.ToArray()) // ToArray to avoid mutation issues
- {
- try
- {
- // SSE requires "event: " + "data: " + blank line
- await client.WriteLineAsync("event: message");
- await client.WriteLineAsync($"data: {message}");
- await client.WriteLineAsync();
- await client.FlushAsync();
- }
- catch (IOException)
- {
- // Client may have disconnected. We let them get cleaned up on next iteration.
- }
- catch (ObjectDisposedException)
- {
- // Stream is disposed, ignore.
- }
- }
- }
-
- public async ValueTask DisposeAsync()
- {
- await _cts.CancelAsync();
-
- if (_serverTask != null)
- {
- try
- {
- await Task.WhenAny(_serverTask, Task.Delay(2000));
- }
- catch (TaskCanceledException)
- {
- // ignore
- }
- }
-
- _listener.Close();
- _cts.Dispose();
-
- _logger.LogInformation("Test SSE server stopped");
- }
-
- ///
- /// Send a test notification to all connected SSE clients.
- ///
- ///
- ///
- public async Task SendTestNotificationAsync(string content)
- {
- var notification = new JsonRpcNotification
- {
- JsonRpc = "2.0",
- Method = "test/notification",
- Params = JsonSerializer.SerializeToNode(new { message = content }),
- };
-
- var serialized = JsonSerializer.Serialize(notification);
- await BroadcastMessageAsync(serialized);
- }
-}
diff --git a/tests/ModelContextProtocol.Tests/Utils/KestrelInMemoryConnection.cs b/tests/ModelContextProtocol.Tests/Utils/KestrelInMemoryConnection.cs
new file mode 100644
index 00000000..d641d876
--- /dev/null
+++ b/tests/ModelContextProtocol.Tests/Utils/KestrelInMemoryConnection.cs
@@ -0,0 +1,112 @@
+using Microsoft.AspNetCore.Connections;
+using Microsoft.AspNetCore.Http.Features;
+using System.IO.Pipelines;
+
+namespace ModelContextProtocol.Tests.Utils;
+
+public sealed class KestrelInMemoryConnection : ConnectionContext
+{
+ private readonly Pipe _clientToServerPipe = new();
+ private readonly Pipe _serverToClientPipe = new();
+ private readonly CancellationTokenSource _connectionClosedCts = new CancellationTokenSource();
+ private readonly IFeatureCollection _features = new FeatureCollection();
+
+ public KestrelInMemoryConnection()
+ {
+ ConnectionClosed = _connectionClosedCts.Token;
+ Transport = new DuplexPipe
+ {
+ Input = _clientToServerPipe.Reader,
+ Output = _serverToClientPipe.Writer,
+ };
+ Application = new DuplexPipe
+ {
+ Input = _serverToClientPipe.Reader,
+ Output = _clientToServerPipe.Writer,
+ };
+ ClientStream = new DuplexStream(Application, _connectionClosedCts);
+ }
+
+ public IDuplexPipe Application { get; }
+ public Stream ClientStream { get; }
+
+ public override IDuplexPipe Transport { get; set; }
+ public override string ConnectionId { get; set; } = Guid.NewGuid().ToString("N");
+
+ public override IFeatureCollection Features => _features;
+
+ public override IDictionary