Skip to content

Use Kestrel for all in-memory HTTP tests #225

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Apr 7, 2025
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Routing;
using Microsoft.AspNetCore.Routing.Patterns;
using Microsoft.AspNetCore.WebUtilities;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using ModelContextProtocol.Protocol.Messages;
using ModelContextProtocol.Protocol.Transport;
using ModelContextProtocol.Server;
using ModelContextProtocol.Utils.Json;
using System.Collections.Concurrent;
using System.Diagnostics.CodeAnalysis;
using System.Security.Cryptography;

namespace Microsoft.AspNetCore.Builder;
Expand All @@ -23,53 +26,87 @@ public static class McpEndpointRouteBuilderExtensions
/// Sets up endpoints for handling MCP HTTP Streaming transport.
/// </summary>
/// <param name="endpoints">The web application to attach MCP HTTP endpoints.</param>
/// <param name="runSession">Provides an optional asynchronous callback for handling new MCP sessions.</param>
/// <param name="pattern">The route pattern prefix to map to.</param>
/// <param name="configureOptionsAsync">Configure per-session options.</param>
/// <param name="runSessionAsync">Provides an optional asynchronous callback for handling new MCP sessions.</param>
/// <returns>Returns a builder for configuring additional endpoint conventions like authorization policies.</returns>
public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpoints, Func<HttpContext, IMcpServer, CancellationToken, Task>? runSession = null)
public static IEndpointConventionBuilder MapMcp(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern = "",
Func<HttpContext, McpServerOptions, CancellationToken, Task>? configureOptionsAsync = null,
Func<HttpContext, IMcpServer, CancellationToken, Task>? runSessionAsync = null)
=> endpoints.MapMcp(RoutePatternFactory.Parse(pattern), configureOptionsAsync, runSessionAsync);

/// <summary>
/// Sets up endpoints for handling MCP HTTP Streaming transport.
/// </summary>
/// <param name="endpoints">The web application to attach MCP HTTP endpoints.</param>
/// <param name="pattern">The route pattern prefix to map to.</param>
/// <param name="configureOptionsAsync">Configure per-session options.</param>
/// <param name="runSessionAsync">Provides an optional asynchronous callback for handling new MCP sessions.</param>
/// <returns>Returns a builder for configuring additional endpoint conventions like authorization policies.</returns>
public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpoints,
RoutePattern pattern,
Func<HttpContext, McpServerOptions, CancellationToken, Task>? configureOptionsAsync = null,
Func<HttpContext, IMcpServer, CancellationToken, Task>? runSessionAsync = null)
{
ConcurrentDictionary<string, SseResponseStreamTransport> _sessions = new(StringComparer.Ordinal);

var loggerFactory = endpoints.ServiceProvider.GetRequiredService<ILoggerFactory>();
var mcpServerOptions = endpoints.ServiceProvider.GetRequiredService<IOptions<McpServerOptions>>();
var optionsSnapshot = endpoints.ServiceProvider.GetRequiredService<IOptions<McpServerOptions>>();
var optionsFactory = endpoints.ServiceProvider.GetRequiredService<IOptionsFactory<McpServerOptions>>();
var hostApplicationLifetime = endpoints.ServiceProvider.GetRequiredService<IHostApplicationLifetime>();

var routeGroup = endpoints.MapGroup("");
var routeGroup = endpoints.MapGroup(pattern);

routeGroup.MapGet("/sse", async context =>
{
var response = context.Response;
var requestAborted = context.RequestAborted;
// If the server is shutting down, we need to cancel all SSE connections immediately without waiting for HostOptions.ShutdownTimeout
// which defaults to 30 seconds.
using var sseCts = CancellationTokenSource.CreateLinkedTokenSource(context.RequestAborted, hostApplicationLifetime.ApplicationStopping);
var cancellationToken = sseCts.Token;

var response = context.Response;
response.Headers.ContentType = "text/event-stream";
response.Headers.CacheControl = "no-cache,no-store";

// Make sure we disable all response buffering for SSE
context.Response.Headers.ContentEncoding = "identity";
context.Features.GetRequiredFeature<IHttpResponseBodyFeature>().DisableBuffering();

var sessionId = MakeNewSessionId();
await using var transport = new SseResponseStreamTransport(response.Body, $"/message?sessionId={sessionId}");
if (!_sessions.TryAdd(sessionId, transport))
{
throw new Exception($"Unreachable given good entropy! Session with ID '{sessionId}' has already been created.");
}

try
var options = optionsSnapshot.Value;
if (configureOptionsAsync is not null)
{
// Make sure we disable all response buffering for SSE
context.Response.Headers.ContentEncoding = "identity";
context.Features.GetRequiredFeature<IHttpResponseBodyFeature>().DisableBuffering();
options = optionsFactory.Create(Options.DefaultName);
await configureOptionsAsync.Invoke(context, options, cancellationToken);
}

var transportTask = transport.RunAsync(cancellationToken: requestAborted);
await using var server = McpServerFactory.Create(transport, mcpServerOptions.Value, loggerFactory, endpoints.ServiceProvider);
try
{
var transportTask = transport.RunAsync(cancellationToken);

try
{
runSession ??= RunSession;
await runSession(context, server, requestAborted);
await using var mcpServer = McpServerFactory.Create(transport, options, loggerFactory, endpoints.ServiceProvider);
context.Features.Set(mcpServer);

runSessionAsync ??= RunSession;
await runSessionAsync(context, mcpServer, cancellationToken);
}
finally
{
await transport.DisposeAsync();
await transportTask;
}
}
catch (OperationCanceledException) when (requestAborted.IsCancellationRequested)
catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
{
// RequestAborted always triggers when the client disconnects before a complete response body is written,
// but this is how SSE connections are typically closed.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,16 +151,14 @@ private async Task CloseAsync()
{
try
{
if (!_connectionCts.IsCancellationRequested)
{
await _connectionCts.CancelAsync().ConfigureAwait(false);
_connectionCts.Dispose();
}
await _connectionCts.CancelAsync().ConfigureAwait(false);

if (_receiveTask != null)
{
await _receiveTask.ConfigureAwait(false);
}

_connectionCts.Dispose();
}
finally
{
Expand Down
33 changes: 26 additions & 7 deletions tests/ModelContextProtocol.TestSseServer/Program.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using ModelContextProtocol.Protocol.Types;
using Microsoft.AspNetCore.Connections;
using ModelContextProtocol.Protocol.Types;
using ModelContextProtocol.Server;
using Serilog;
using System.Text;
Expand Down Expand Up @@ -372,18 +373,34 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st
};
}

public static async Task MainAsync(string[] args, ILoggerProvider? loggerProvider = null, CancellationToken cancellationToken = default)
public static async Task MainAsync(string[] args, ILoggerProvider? loggerProvider = null, IConnectionListenerFactory? kestrelTransport = null, CancellationToken cancellationToken = default)
{
Console.WriteLine("Starting server...");

int port = args.Length > 0 && uint.TryParse(args[0], out var parsedPort) ? (int)parsedPort : 3001;

var builder = WebApplication.CreateSlimBuilder(args);
builder.WebHost.ConfigureKestrel(options =>
var builder = WebApplication.CreateEmptyBuilder(new()
{
options.ListenLocalhost(port);
Args = args,
});

if (kestrelTransport is null)
{
int port = args.Length > 0 && uint.TryParse(args[0], out var parsedPort) ? (int)parsedPort : 3001;
builder.WebHost.ConfigureKestrel(options =>
{
options.ListenLocalhost(port);
});
}
else
{
// Add passed-in transport before calling UseKestrelCore() to avoid the SocketsHttpHandler getting added.
builder.Services.AddSingleton(kestrelTransport);
}

builder.WebHost.UseKestrelCore();
builder.Services.AddLogging();
builder.Services.AddRoutingCore();

builder.Logging.AddConsole();
ConfigureSerilog(builder.Logging);
if (loggerProvider is not null)
{
Expand All @@ -393,6 +410,8 @@ public static async Task MainAsync(string[] args, ILoggerProvider? loggerProvide
builder.Services.AddMcpServer(ConfigureOptions);

var app = builder.Build();
app.UseRouting();
app.UseEndpoints(_ => { });

app.MapMcp();

Expand Down
123 changes: 123 additions & 0 deletions tests/ModelContextProtocol.Tests/DockerEverythingServerTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
using ModelContextProtocol.Client;
using ModelContextProtocol.Protocol.Transport;
using ModelContextProtocol.Protocol.Types;
using ModelContextProtocol.Tests.Utils;

namespace ModelContextProtocol.Tests;

public class DockerEverythingServerTests(ITestOutputHelper testOutputHelper) : LoggedTest(testOutputHelper)
{
/// <summary>Port number to be grabbed by the next test.</summary>
private static int s_nextPort = 3000;

// If the tests run concurrently against different versions of the runtime, tests can conflict with
// each other in the ports set up for interacting with containers. Ensure that such suites running
// against different TFMs use different port numbers.
private static readonly int s_portOffset = 1000 * (Environment.Version.Major switch
{
int v when v >= 8 => Environment.Version.Major - 7,
_ => 0,
});

private static int CreatePortNumber() => Interlocked.Increment(ref s_nextPort) + s_portOffset;

public static bool IsDockerAvailable => EverythingSseServerFixture.IsDockerAvailable;

[Fact(Skip = "docker is not available", SkipUnless = nameof(IsDockerAvailable))]
[Trait("Execution", "Manual")]
public async Task ConnectAndReceiveMessage_EverythingServerWithSse()
{
int port = CreatePortNumber();

await using var fixture = new EverythingSseServerFixture(port);
await fixture.StartAsync();

var defaultOptions = new McpClientOptions
{
ClientInfo = new() { Name = "IntegrationTestClient", Version = "1.0.0" }
};

var defaultConfig = new McpServerConfig
{
Id = "everything",
Name = "Everything",
TransportType = TransportTypes.Sse,
TransportOptions = [],
Location = $"http://localhost:{port}/sse"
};

// Create client and run tests
await using var client = await McpClientFactory.CreateAsync(
defaultConfig,
defaultOptions,
loggerFactory: LoggerFactory,
cancellationToken: TestContext.Current.CancellationToken);
var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken);

// assert
Assert.NotEmpty(tools);
}

[Fact(Skip = "docker is not available", SkipUnless = nameof(IsDockerAvailable))]
[Trait("Execution", "Manual")]
public async Task Sampling_Sse_EverythingServer()
{
int port = CreatePortNumber();

await using var fixture = new EverythingSseServerFixture(port);
await fixture.StartAsync();

var defaultConfig = new McpServerConfig
{
Id = "everything",
Name = "Everything",
TransportType = TransportTypes.Sse,
TransportOptions = [],
Location = $"http://localhost:{port}/sse"
};

int samplingHandlerCalls = 0;
var defaultOptions = new McpClientOptions
{
Capabilities = new()
{
Sampling = new()
{
SamplingHandler = (_, _, _) =>
{
samplingHandlerCalls++;
return Task.FromResult(new CreateMessageResult
{
Model = "test-model",
Role = "assistant",
Content = new Content
{
Type = "text",
Text = "Test response"
}
});
},
},
},
};

await using var client = await McpClientFactory.CreateAsync(
defaultConfig,
defaultOptions,
loggerFactory: LoggerFactory,
cancellationToken: TestContext.Current.CancellationToken);

// Call the server's sampleLLM tool which should trigger our sampling handler
var result = await client.CallToolAsync("sampleLLM", new Dictionary<string, object?>
{
["prompt"] = "Test prompt",
["maxTokens"] = 100
}, cancellationToken: TestContext.Current.CancellationToken);

// assert
Assert.NotNull(result);
var textContent = Assert.Single(result.Content);
Assert.Equal("text", textContent.Type);
Assert.False(string.IsNullOrEmpty(textContent.Text));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ public async Task DisposeAsyncCompletesImmediatelyWhenInvokedFromHandler(int num
var releaseHandler = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);

IAsyncDisposable? registration = null;
registration = client.RegisterNotificationHandler(NotificationName, async (notification, cancellationToken) =>
await using var _ = registration = client.RegisterNotificationHandler(NotificationName, async (notification, cancellationToken) =>
{
for (int i = 0; i < numberOfDisposals; i++)
{
Expand All @@ -240,9 +240,5 @@ public async Task DisposeAsyncCompletesImmediatelyWhenInvokedFromHandler(int num

await _server.SendNotificationAsync(NotificationName, TestContext.Current.CancellationToken);
await handlerRunning.Task;

ValueTask disposal = registration.DisposeAsync();
Assert.True(disposal.IsCompletedSuccessfully);
await disposal;
}
}
19 changes: 19 additions & 0 deletions tests/ModelContextProtocol.Tests/Server/MapMcpTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using Microsoft.AspNetCore.Builder;
using ModelContextProtocol.Tests.Utils;

namespace ModelContextProtocol.Tests.Server;

public class MapMcpTests(ITestOutputHelper testOutputHelper) : KestrelInMemoryTest(testOutputHelper)
{
[Fact]
public async Task Allows_Customizing_Route()
{
await using var app = Builder.Build();
app.MapMcp("/mcp");
await app.StartAsync(TestContext.Current.CancellationToken);

using var httpClient = CreateHttpClient();
using var response = await httpClient.GetAsync("http://localhost/mcp/sse", HttpCompletionOption.ResponseHeadersRead, TestContext.Current.CancellationToken);
response.EnsureSuccessStatusCode();
}
}
Loading
Loading