diff --git a/ModelContextProtocol.sln b/ModelContextProtocol.sln index 1f7a475b..064dc40d 100644 --- a/ModelContextProtocol.sln +++ b/ModelContextProtocol.sln @@ -50,6 +50,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "QuickstartWeatherServer", " EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "QuickstartClient", "samples\QuickstartClient\QuickstartClient.csproj", "{0D1552DC-E6ED-4AAC-5562-12F8352F46AA}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "ModelContextProtocol.AspNetCore", "src\ModelContextProtocol.AspNetCore\ModelContextProtocol.AspNetCore.csproj", "{37B6A5E0-9995-497D-8B43-3BC6870CC716}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -92,6 +94,10 @@ Global {0D1552DC-E6ED-4AAC-5562-12F8352F46AA}.Debug|Any CPU.Build.0 = Debug|Any CPU {0D1552DC-E6ED-4AAC-5562-12F8352F46AA}.Release|Any CPU.ActiveCfg = Release|Any CPU {0D1552DC-E6ED-4AAC-5562-12F8352F46AA}.Release|Any CPU.Build.0 = Release|Any CPU + {37B6A5E0-9995-497D-8B43-3BC6870CC716}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {37B6A5E0-9995-497D-8B43-3BC6870CC716}.Debug|Any CPU.Build.0 = Debug|Any CPU + {37B6A5E0-9995-497D-8B43-3BC6870CC716}.Release|Any CPU.ActiveCfg = Release|Any CPU + {37B6A5E0-9995-497D-8B43-3BC6870CC716}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -107,6 +113,7 @@ Global {0C6D0512-D26D-63D3-5019-C5F7A657B28C} = {02EA681E-C7D8-13C7-8484-4AC65E1B71E8} {4653EB0C-8FC0-98F4-E9C8-220EDA7A69DF} = {02EA681E-C7D8-13C7-8484-4AC65E1B71E8} {0D1552DC-E6ED-4AAC-5562-12F8352F46AA} = {02EA681E-C7D8-13C7-8484-4AC65E1B71E8} + {37B6A5E0-9995-497D-8B43-3BC6870CC716} = {A2F1F52A-9107-4BF8-8C3F-2F6670E7D0AD} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {384A3888-751F-4D75-9AE5-587330582D89} diff --git a/samples/AspNetCoreSseServer/AspNetCoreSseServer.csproj b/samples/AspNetCoreSseServer/AspNetCoreSseServer.csproj index 35c2181f..c17cf9c4 100644 --- a/samples/AspNetCoreSseServer/AspNetCoreSseServer.csproj +++ b/samples/AspNetCoreSseServer/AspNetCoreSseServer.csproj @@ -8,6 +8,7 @@ + diff --git a/samples/AspNetCoreSseServer/McpEndpointRouteBuilderExtensions.cs b/samples/AspNetCoreSseServer/McpEndpointRouteBuilderExtensions.cs deleted file mode 100644 index 22306014..00000000 --- a/samples/AspNetCoreSseServer/McpEndpointRouteBuilderExtensions.cs +++ /dev/null @@ -1,62 +0,0 @@ -using ModelContextProtocol.Protocol.Messages; -using ModelContextProtocol.Server; -using ModelContextProtocol.Utils.Json; -using Microsoft.Extensions.Options; -using ModelContextProtocol.Protocol.Transport; - -namespace AspNetCoreSseServer; - -public static class McpEndpointRouteBuilderExtensions -{ - public static IEndpointConventionBuilder MapMcpSse(this IEndpointRouteBuilder endpoints) - { - SseResponseStreamTransport? transport = null; - var loggerFactory = endpoints.ServiceProvider.GetRequiredService(); - var mcpServerOptions = endpoints.ServiceProvider.GetRequiredService>(); - - var routeGroup = endpoints.MapGroup(""); - - routeGroup.MapGet("/sse", async (HttpResponse response, CancellationToken requestAborted) => - { - response.Headers.ContentType = "text/event-stream"; - response.Headers.CacheControl = "no-cache"; - - await using var localTransport = transport = new SseResponseStreamTransport(response.Body); - await using var server = McpServerFactory.Create(transport, mcpServerOptions.Value, loggerFactory, endpoints.ServiceProvider); - - try - { - var transportTask = transport.RunAsync(cancellationToken: requestAborted); - await server.RunAsync(cancellationToken: requestAborted); - await transportTask; - } - catch (OperationCanceledException) when (requestAborted.IsCancellationRequested) - { - // RequestAborted always triggers when the client disconnects before a complete response body is written, - // but this is how SSE connections are typically closed. - } - }); - - routeGroup.MapPost("/message", async context => - { - if (transport is null) - { - await Results.BadRequest("Connect to the /sse endpoint before sending messages.").ExecuteAsync(context); - return; - } - - var message = await context.Request.ReadFromJsonAsync(McpJsonUtilities.DefaultOptions, context.RequestAborted); - if (message is null) - { - await Results.BadRequest("No message in request body.").ExecuteAsync(context); - return; - } - - await transport.OnMessageReceivedAsync(message, context.RequestAborted); - context.Response.StatusCode = StatusCodes.Status202Accepted; - await context.Response.WriteAsync("Accepted"); - }); - - return routeGroup; - } -} diff --git a/samples/AspNetCoreSseServer/Program.cs b/samples/AspNetCoreSseServer/Program.cs index b66b2ce4..774957e8 100644 --- a/samples/AspNetCoreSseServer/Program.cs +++ b/samples/AspNetCoreSseServer/Program.cs @@ -1,10 +1,9 @@ -using AspNetCoreSseServer; +using ModelContextProtocol.AspNetCore; var builder = WebApplication.CreateBuilder(args); builder.Services.AddMcpServer().WithToolsFromAssembly(); var app = builder.Build(); -app.MapGet("/", () => "Hello World!"); -app.MapMcpSse(); +app.MapMcp(); app.Run(); diff --git a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs new file mode 100644 index 00000000..9ad1848c --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs @@ -0,0 +1,118 @@ +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Routing; +using Microsoft.AspNetCore.WebUtilities; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Protocol.Transport; +using ModelContextProtocol.Server; +using ModelContextProtocol.Utils.Json; +using System.Collections.Concurrent; +using System.Security.Cryptography; + +namespace ModelContextProtocol.AspNetCore; + +/// +/// Extension methods for to add MCP endpoints. +/// +public static class McpEndpointRouteBuilderExtensions +{ + /// + /// Sets up endpoints for handling MCP HTTP Streaming transport. + /// + /// The web application to attach MCP HTTP endpoints. + /// Provides an optional asynchronous callback for handling new MCP sessions. + /// Returns a builder for configuring additional endpoint conventions like authorization policies. + public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpoints, Func? runSession = null) + { + ConcurrentDictionary _sessions = new(StringComparer.Ordinal); + + var loggerFactory = endpoints.ServiceProvider.GetRequiredService(); + var mcpServerOptions = endpoints.ServiceProvider.GetRequiredService>(); + + var routeGroup = endpoints.MapGroup(""); + + routeGroup.MapGet("/sse", async context => + { + var response = context.Response; + var requestAborted = context.RequestAborted; + + response.Headers.ContentType = "text/event-stream"; + response.Headers.CacheControl = "no-cache"; + + var sessionId = MakeNewSessionId(); + await using var transport = new SseResponseStreamTransport(response.Body, $"/message?sessionId={sessionId}"); + if (!_sessions.TryAdd(sessionId, transport)) + { + throw new Exception($"Unreachable given good entropy! Session with ID '{sessionId}' has already been created."); + } + await using var server = McpServerFactory.Create(transport, mcpServerOptions.Value, loggerFactory, endpoints.ServiceProvider); + + try + { + var transportTask = transport.RunAsync(cancellationToken: requestAborted); + runSession ??= RunSession; + + try + { + await runSession(context, server, requestAborted); + } + finally + { + await transport.DisposeAsync(); + await transportTask; + } + } + catch (OperationCanceledException) when (requestAborted.IsCancellationRequested) + { + // RequestAborted always triggers when the client disconnects before a complete response body is written, + // but this is how SSE connections are typically closed. + } + finally + { + _sessions.TryRemove(sessionId, out _); + } + }); + + routeGroup.MapPost("/message", async context => + { + if (!context.Request.Query.TryGetValue("sessionId", out var sessionId)) + { + await Results.BadRequest("Missing sessionId query parameter.").ExecuteAsync(context); + return; + } + + if (!_sessions.TryGetValue(sessionId.ToString(), out var transport)) + { + await Results.BadRequest($"Session {sessionId} not found.").ExecuteAsync(context); + return; + } + + var message = await context.Request.ReadFromJsonAsync(McpJsonUtilities.DefaultOptions, context.RequestAborted); + if (message is null) + { + await Results.BadRequest("No message in request body.").ExecuteAsync(context); + return; + } + + await transport.OnMessageReceivedAsync(message, context.RequestAborted); + context.Response.StatusCode = StatusCodes.Status202Accepted; + await context.Response.WriteAsync("Accepted"); + }); + + return routeGroup; + } + + private static Task RunSession(HttpContext httpContext, IMcpServer session, CancellationToken requestAborted) + => session.RunAsync(requestAborted); + + private static string MakeNewSessionId() + { + // 128 bits + Span buffer = stackalloc byte[16]; + RandomNumberGenerator.Fill(buffer); + return WebEncoders.Base64UrlEncode(buffer); + } +} diff --git a/src/ModelContextProtocol.AspNetCore/ModelContextProtocol.AspNetCore.csproj b/src/ModelContextProtocol.AspNetCore/ModelContextProtocol.AspNetCore.csproj new file mode 100644 index 00000000..5dd10dbf --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/ModelContextProtocol.AspNetCore.csproj @@ -0,0 +1,25 @@ + + + + net8.0 + enable + enable + true + true + ModelContextProtocol.AspNetCore + ASP.NET Core extensions for the C# Model Context Protocol (MCP) SDK. + README.md + + + + + + + + + + + + + + diff --git a/src/ModelContextProtocol.AspNetCore/README.md b/src/ModelContextProtocol.AspNetCore/README.md new file mode 100644 index 00000000..dd7a5909 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/README.md @@ -0,0 +1,54 @@ +# ASP.NET Core extensions for the MCP C# SDK + +[![NuGet preview version](https://img.shields.io/nuget/vpre/ModelContextProtocol.svg)](https://www.nuget.org/packages/ModelContextProtocol/absoluteLatest) + +The official C# SDK for the [Model Context Protocol](https://modelcontextprotocol.io/), enabling .NET applications, services, and libraries to implement and interact with MCP clients and servers. Please visit our [API documentation](https://modelcontextprotocol.github.io/csharp-sdk/api/ModelContextProtocol.html) for more details on available functionality. + +> [!NOTE] +> This project is in preview; breaking changes can be introduced without prior notice. + +## About MCP + +The Model Context Protocol (MCP) is an open protocol that standardizes how applications provide context to Large Language Models (LLMs). It enables secure integration between LLMs and various data sources and tools. + +For more information about MCP: + +- [Official Documentation](https://modelcontextprotocol.io/) +- [Protocol Specification](https://spec.modelcontextprotocol.io/) +- [GitHub Organization](https://github.com/modelcontextprotocol) + +## Installation + +To get started, install the package from NuGet + +``` +dotnet new web +dotnet add package ModelContextProtocol.AspNetcore --prerelease +``` + +## Getting Started + +```csharp +// Program.cs +using ModelContextProtocol; +using ModelContextProtocol.AspNetCore; + +var builder = WebApplication.CreateBuilder(args); +builder.WebHost.ConfigureKestrel(options => +{ + options.ListenLocalhost(3001); +}); +builder.Services.AddMcpServer().WithToolsFromAssembly(); +var app = builder.Build(); + +app.MapMcp(); + +app.Run(); + +[McpServerToolType] +public static class EchoTool +{ + [McpServerTool, Description("Echoes the message back to the client.")] + public static string Echo(string message) => $"hello {message}"; +} +``` diff --git a/src/ModelContextProtocol/Client/McpClient.cs b/src/ModelContextProtocol/Client/McpClient.cs index 3b81c7e2..b326f3c5 100644 --- a/src/ModelContextProtocol/Client/McpClient.cs +++ b/src/ModelContextProtocol/Client/McpClient.cs @@ -79,10 +79,9 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default) { // Connect transport _sessionTransport = await _clientTransport.ConnectAsync(cancellationToken).ConfigureAwait(false); - InitializeSession(_sessionTransport); // We don't want the ConnectAsync token to cancel the session after we've successfully connected. // The base class handles cleaning up the session in DisposeAsync without our help. - StartSession(fullSessionCancellationToken: CancellationToken.None); + StartSession(_sessionTransport, fullSessionCancellationToken: CancellationToken.None); // Perform initialization sequence using var initializationCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); @@ -142,13 +141,14 @@ await SendMessageAsync( /// public override async ValueTask DisposeUnsynchronizedAsync() { - if (_connectCts is not null) - { - await _connectCts.CancelAsync().ConfigureAwait(false); - } - try { + if (_connectCts is not null) + { + await _connectCts.CancelAsync().ConfigureAwait(false); + _connectCts.Dispose(); + } + await base.DisposeUnsynchronizedAsync().ConfigureAwait(false); } finally @@ -157,8 +157,6 @@ public override async ValueTask DisposeUnsynchronizedAsync() { await _sessionTransport.DisposeAsync().ConfigureAwait(false); } - - _connectCts?.Dispose(); } } } diff --git a/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.cs b/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.cs index 06b111e6..3fbc4fea 100644 --- a/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.cs +++ b/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.cs @@ -337,7 +337,7 @@ public static IMcpServerBuilder WithStdioServerTransport(this IMcpServerBuilder Throw.IfNull(builder); builder.Services.AddSingleton(); - builder.Services.AddHostedService(); + builder.Services.AddHostedService(); builder.Services.AddSingleton(services => { @@ -350,18 +350,5 @@ public static IMcpServerBuilder WithStdioServerTransport(this IMcpServerBuilder return builder; } - - /// - /// Adds a server transport that uses SSE via a HttpListener for communication. - /// - /// The builder instance. - public static IMcpServerBuilder WithHttpListenerSseServerTransport(this IMcpServerBuilder builder) - { - Throw.IfNull(builder); - - builder.Services.AddSingleton(); - builder.Services.AddHostedService(); - return builder; - } #endregion } diff --git a/src/ModelContextProtocol/Hosting/McpServerMultiSessionHostedService.cs b/src/ModelContextProtocol/Hosting/McpServerMultiSessionHostedService.cs deleted file mode 100644 index 456f8944..00000000 --- a/src/ModelContextProtocol/Hosting/McpServerMultiSessionHostedService.cs +++ /dev/null @@ -1,43 +0,0 @@ -using Microsoft.Extensions.Hosting; -using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Options; -using ModelContextProtocol.Protocol.Transport; -using ModelContextProtocol.Server; - -namespace ModelContextProtocol.Hosting; - -/// -/// Hosted service for a multi-session (i.e. HTTP) MCP server. -/// -internal sealed class McpServerMultiSessionHostedService : BackgroundService -{ - private readonly IServerTransport _serverTransport; - private readonly McpServerOptions _serverOptions; - private readonly ILoggerFactory _loggerFactory; - private readonly IServiceProvider _serviceProvider; - - public McpServerMultiSessionHostedService( - IServerTransport serverTransport, - IOptions serverOptions, - ILoggerFactory loggerFactory, - IServiceProvider serviceProvider) - { - _serverTransport = serverTransport; - _serverOptions = serverOptions.Value; - _loggerFactory = loggerFactory; - _serviceProvider = serviceProvider; - } - - /// - protected override async Task ExecuteAsync(CancellationToken stoppingToken) - { - while (await AcceptSessionAsync(stoppingToken).ConfigureAwait(false) is { } server) - { - // TODO: Track all running sessions and wait for all sessions to complete for graceful shutdown. - _ = server.RunAsync(stoppingToken); - } - } - - private Task AcceptSessionAsync(CancellationToken cancellationToken) - => McpServerFactory.AcceptAsync(_serverTransport, _serverOptions, _loggerFactory, _serviceProvider, cancellationToken); -} diff --git a/src/ModelContextProtocol/Hosting/McpServerSingleSessionHostedService.cs b/src/ModelContextProtocol/Hosting/StdioMcpServerHostedService.cs similarity index 77% rename from src/ModelContextProtocol/Hosting/McpServerSingleSessionHostedService.cs rename to src/ModelContextProtocol/Hosting/StdioMcpServerHostedService.cs index e86c288e..ae13e19d 100644 --- a/src/ModelContextProtocol/Hosting/McpServerSingleSessionHostedService.cs +++ b/src/ModelContextProtocol/Hosting/StdioMcpServerHostedService.cs @@ -6,7 +6,7 @@ namespace ModelContextProtocol.Hosting; /// /// Hosted service for a single-session (i.e stdio) MCP server. /// -internal sealed class McpServerSingleSessionHostedService(IMcpServer session) : BackgroundService +internal sealed class StdioMcpServerHostedService(IMcpServer session) : BackgroundService { /// protected override Task ExecuteAsync(CancellationToken stoppingToken) => session.RunAsync(stoppingToken); diff --git a/src/ModelContextProtocol/Protocol/Transport/HttpListenerServerProvider.cs b/src/ModelContextProtocol/Protocol/Transport/HttpListenerServerProvider.cs deleted file mode 100644 index 73a1dbd3..00000000 --- a/src/ModelContextProtocol/Protocol/Transport/HttpListenerServerProvider.cs +++ /dev/null @@ -1,210 +0,0 @@ -using System.Net; - -namespace ModelContextProtocol.Protocol.Transport; - -/// -/// HTTP server provider using HttpListener. -/// -internal sealed class HttpListenerServerProvider : IAsyncDisposable -{ - private static readonly byte[] s_accepted = "Accepted"u8.ToArray(); - - private const string SseEndpoint = "/sse"; - private const string MessageEndpoint = "/message"; - - private readonly HttpListener _listener; - private readonly CancellationTokenSource _shutdownTokenSource = new(); - private Task _listeningTask = Task.CompletedTask; - - private readonly TaskCompletionSource _completed = new(); - private int _outstandingOperations; - - private int _state; - private const int StateNotStarted = 0; - private const int StateRunning = 1; - private const int StateStopped = 2; - - /// - /// Creates a new instance of the HTTP server provider. - /// - /// The port to listen on - public HttpListenerServerProvider(int port) - { - if (port < 0) - { - throw new ArgumentOutOfRangeException(nameof(port)); - } - - _listener = new(); - _listener.Prefixes.Add($"http://localhost:{port}/"); - } - - public required Func OnSseConnectionAsync { get; set; } - public required Func> OnMessageAsync { get; set; } - - public void Start() - { - if (Interlocked.CompareExchange(ref _state, StateRunning, StateNotStarted) != StateNotStarted) - { - throw new ObjectDisposedException("Server may not be started twice."); - } - - // Start listening for connections - _listener.Start(); - - OperationAdded(); // for the listening task - _listeningTask = Task.Run(async () => - { - try - { - using var cts = CancellationTokenSource.CreateLinkedTokenSource(_shutdownTokenSource.Token); - cts.Token.Register(_listener.Stop); - while (!cts.IsCancellationRequested) - { - try - { - var context = await _listener.GetContextAsync().ConfigureAwait(false); - - // Process the request in a separate task - OperationAdded(); // for the processing task; decremented in ProcessRequestAsync - _ = Task.Run(() => ProcessRequestAsync(context, cts.Token), CancellationToken.None); - } - catch (Exception) - { - if (cts.IsCancellationRequested) - { - // Shutdown requested, exit gracefully - break; - } - } - } - } - finally - { - OperationCompleted(); // for the listening task - } - }, CancellationToken.None); - } - - /// - public async ValueTask DisposeAsync() - { - if (Interlocked.CompareExchange(ref _state, StateStopped, StateRunning) != StateRunning) - { - return; - } - - await _shutdownTokenSource.CancelAsync().ConfigureAwait(false); - _listener.Stop(); - await _listeningTask.ConfigureAwait(false); - await _completed.Task.ConfigureAwait(false); - } - - /// Gets a that completes when the server has finished its work. - public Task Completed => _completed.Task; - - private void OperationAdded() => Interlocked.Increment(ref _outstandingOperations); - - private void OperationCompleted() - { - if (Interlocked.Decrement(ref _outstandingOperations) == 0) - { - // All operations completed - _completed.TrySetResult(true); - } - } - - private async Task ProcessRequestAsync(HttpListenerContext context, CancellationToken cancellationToken) - { - var request = context.Request; - var response = context.Response; - try - { - if (request is null || response is null) - { - return; - } - - // Handle SSE connection - if (request.HttpMethod == "GET" && request.Url?.LocalPath == SseEndpoint) - { - await HandleSseConnectionAsync(context, cancellationToken).ConfigureAwait(false); - } - // Handle message POST - else if (request.HttpMethod == "POST" && request.Url?.LocalPath == MessageEndpoint) - { - await HandleMessageAsync(context, cancellationToken).ConfigureAwait(false); - } - else - { - // Not found - response.StatusCode = 404; - response.Close(); - } - } - catch (Exception) - { - try - { - response.StatusCode = 500; - response.Close(); - } - catch { /* Ignore errors during error handling */ } - } - finally - { - OperationCompleted(); - } - } - - private async Task HandleSseConnectionAsync(HttpListenerContext context, CancellationToken cancellationToken) - { - var response = context.Response; - - // Set SSE headers - response.ContentType = "text/event-stream"; - response.Headers.Add("Cache-Control", "no-cache"); - response.Headers.Add("Connection", "keep-alive"); - - // Keep the connection open until cancelled - try - { - await OnSseConnectionAsync(response.OutputStream, cancellationToken).ConfigureAwait(false); - } - catch (Exception) - { - } - finally - { - // Remove client on disconnect - try - { - response.Close(); - } - catch { /* Ignore errors during cleanup */ } - } - } - - private async Task HandleMessageAsync(HttpListenerContext context, CancellationToken cancellationToken) - { - var request = context.Request; - var response = context.Response; - - // Process the message asynchronously - if (await OnMessageAsync(request.InputStream, cancellationToken)) - { - // Return 202 Accepted - response.StatusCode = 202; - - // Write "accepted" response - await response.OutputStream.WriteAsync(s_accepted, cancellationToken).ConfigureAwait(false); - } - else - { - // Return 400 Bad Request - response.StatusCode = 400; - } - - response.Close(); - } -} diff --git a/src/ModelContextProtocol/Protocol/Transport/HttpListenerSseServerSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/HttpListenerSseServerSessionTransport.cs deleted file mode 100644 index 6b2219e5..00000000 --- a/src/ModelContextProtocol/Protocol/Transport/HttpListenerSseServerSessionTransport.cs +++ /dev/null @@ -1,76 +0,0 @@ -using Microsoft.Extensions.Logging; -using ModelContextProtocol.Logging; -using ModelContextProtocol.Protocol.Messages; -using ModelContextProtocol.Utils; -using ModelContextProtocol.Utils.Json; -using System.Net; -using System.Text.Json; - -namespace ModelContextProtocol.Protocol.Transport; - -/// -/// Implements the MCP transport protocol using . -/// -internal sealed class HttpListenerSseServerSessionTransport : TransportBase -{ - private readonly string _serverName; - private readonly ILogger _logger; - private SseResponseStreamTransport _responseStreamTransport; - - private string EndpointName => $"Server (SSE) ({_serverName})"; - - public HttpListenerSseServerSessionTransport(string serverName, SseResponseStreamTransport responseStreamTransport, ILoggerFactory loggerFactory) - : base(loggerFactory) - { - Throw.IfNull(serverName); - - _serverName = serverName; - _responseStreamTransport = responseStreamTransport; - _logger = loggerFactory.CreateLogger(); - SetConnected(true); - } - - /// - public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) - { - if (!IsConnected) - { - _logger.TransportNotConnected(EndpointName); - throw new McpTransportException("Transport is not connected"); - } - - string id = "(no id)"; - if (message is IJsonRpcMessageWithId messageWithId) - { - id = messageWithId.Id.ToString(); - } - - try - { - if (_logger.IsEnabled(LogLevel.Debug)) - { - var json = JsonSerializer.Serialize(message, McpJsonUtilities.DefaultOptions.GetTypeInfo()); - _logger.TransportSendingMessage(EndpointName, id, json); - } - - await _responseStreamTransport.SendMessageAsync(message, cancellationToken).ConfigureAwait(false); - - _logger.TransportSentMessage(EndpointName, id); - } - catch (Exception ex) - { - _logger.TransportSendFailed(EndpointName, id, ex); - throw new McpTransportException("Failed to send message", ex); - } - } - - public Task OnMessageReceivedAsync(IJsonRpcMessage message, CancellationToken cancellationToken) - => WriteMessageAsync(message, cancellationToken); - - /// - public override ValueTask DisposeAsync() - { - SetConnected(false); - return default; - } -} diff --git a/src/ModelContextProtocol/Protocol/Transport/HttpListenerSseServerTransport.cs b/src/ModelContextProtocol/Protocol/Transport/HttpListenerSseServerTransport.cs deleted file mode 100644 index 4e0aa7ed..00000000 --- a/src/ModelContextProtocol/Protocol/Transport/HttpListenerSseServerTransport.cs +++ /dev/null @@ -1,180 +0,0 @@ -using System.Net; -using System.Text.Json; -using Microsoft.Extensions.Logging; -using ModelContextProtocol.Protocol.Messages; -using ModelContextProtocol.Utils.Json; -using ModelContextProtocol.Logging; -using ModelContextProtocol.Server; -using ModelContextProtocol.Utils; -using System.Threading.Channels; - -namespace ModelContextProtocol.Protocol.Transport; - -/// -/// Implements the MCP transport protocol using . -/// -public sealed class HttpListenerSseServerTransport : IServerTransport, IAsyncDisposable -{ - private readonly string _serverName; - private readonly HttpListenerServerProvider _httpServerProvider; - private readonly ILoggerFactory _loggerFactory; - private readonly ILogger _logger; - - private readonly Channel _incomingSessions; - - private HttpListenerSseServerSessionTransport? _sessionTransport; - - private string EndpointName => $"Server (SSE) ({_serverName})"; - - /// - /// Initializes a new instance of the SseServerTransport class. - /// - /// The server options. - /// The port to listen on. - /// A logger factory for creating loggers. - public HttpListenerSseServerTransport(McpServerOptions serverOptions, int port, ILoggerFactory loggerFactory) - : this(GetServerName(serverOptions), port, loggerFactory) - { - } - - /// - /// Initializes a new instance of the SseServerTransport class. - /// - /// The name of the server. - /// The port to listen on. - /// A logger factory for creating loggers. - public HttpListenerSseServerTransport(string serverName, int port, ILoggerFactory loggerFactory) - { - Throw.IfNull(serverName); - - _serverName = serverName; - _loggerFactory = loggerFactory; - _logger = loggerFactory.CreateLogger(); - _httpServerProvider = new HttpListenerServerProvider(port) - { - OnSseConnectionAsync = OnSseConnectionAsync, - OnMessageAsync = OnMessageAsync, - }; - - // Until we support session IDs, there's no way to support more than one concurrent session. - // Any new SSE connection overwrites the old session and any new /messages go to the new session. - _incomingSessions = Channel.CreateBounded(new BoundedChannelOptions(1) - { - FullMode = BoundedChannelFullMode.DropOldest, - }); - - // REVIEW: We could add another layer of async for binding similar to Kestrel's IConnectionListenerFactory, - // but this wouldn't play well with a static factory method to accept new sessions. Ultimately, - // ASP.NET Core is not going to hand over binding to the MCP SDK, so I decided to just bind in the transport - // constructor for now. - _httpServerProvider.Start(); - } - - /// - public async Task AcceptAsync(CancellationToken cancellationToken = default) - { - while (await _incomingSessions.Reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false)) - { - if (_incomingSessions.Reader.TryRead(out var session)) - { - return session; - } - } - - return null; - } - - /// - public async ValueTask DisposeAsync() - { - _logger.TransportCleaningUp(EndpointName); - - await _httpServerProvider.DisposeAsync().ConfigureAwait(false); - _incomingSessions.Writer.TryComplete(); - - _logger.TransportCleanedUp(EndpointName); - } - - private async Task OnSseConnectionAsync(Stream responseStream, CancellationToken cancellationToken) - { - var sseResponseStreamTransport = new SseResponseStreamTransport(responseStream); - var sessionTransport = new HttpListenerSseServerSessionTransport(_serverName, sseResponseStreamTransport, _loggerFactory); - - await using (sseResponseStreamTransport.ConfigureAwait(false)) - await using (sseResponseStreamTransport.ConfigureAwait(false)) - { - _sessionTransport = sessionTransport; - await _incomingSessions.Writer.WriteAsync(sessionTransport).ConfigureAwait(false); - await sseResponseStreamTransport.RunAsync(cancellationToken).ConfigureAwait(false); - } - } - - /// - /// Handles HTTP messages received by the HTTP server provider. - /// - /// true if the message was accepted (return 202), false otherwise (return 400) - private async Task OnMessageAsync(Stream requestStream, CancellationToken cancellationToken) - { - string request; - IJsonRpcMessage? message = null; - - if (_logger.IsEnabled(LogLevel.Information)) - { - using var reader = new StreamReader(requestStream); - request = await reader.ReadToEndAsync(cancellationToken).ConfigureAwait(false); - message = JsonSerializer.Deserialize(request, McpJsonUtilities.DefaultOptions.GetTypeInfo()); - - _logger.TransportReceivedMessage(EndpointName, request); - } - else - { - request = "(Enable information-level logs to see the request)"; - } - - try - { - message ??= await JsonSerializer.DeserializeAsync(requestStream, McpJsonUtilities.DefaultOptions.GetTypeInfo()).ConfigureAwait(false); - if (message != null) - { - string messageId = "(no id)"; - if (message is IJsonRpcMessageWithId messageWithId) - { - messageId = messageWithId.Id.ToString(); - } - - _logger.TransportReceivedMessageParsed(EndpointName, messageId); - - if (_sessionTransport is null) - { - return false; - } - - await _sessionTransport.OnMessageReceivedAsync(message, cancellationToken).ConfigureAwait(false); - - _logger.TransportMessageWritten(EndpointName, messageId); - - return true; - } - else - { - _logger.TransportMessageParseUnexpectedType(EndpointName, request); - return false; - } - } - catch (JsonException ex) - { - _logger.TransportMessageParseFailed(EndpointName, request, ex); - return false; - } - } - - /// Validates the and extracts from it the server name to use. - private static string GetServerName(McpServerOptions serverOptions) - { - Throw.IfNull(serverOptions); - Throw.IfNull(serverOptions.ServerInfo); - Throw.IfNull(serverOptions.ServerInfo.Name); - - return serverOptions.ServerInfo.Name; - } -} diff --git a/src/ModelContextProtocol/Protocol/Transport/IServerTransport.cs b/src/ModelContextProtocol/Protocol/Transport/IServerTransport.cs deleted file mode 100644 index 0d1a9774..00000000 --- a/src/ModelContextProtocol/Protocol/Transport/IServerTransport.cs +++ /dev/null @@ -1,14 +0,0 @@ -namespace ModelContextProtocol.Protocol.Transport; - -/// -/// Represents a transport mechanism for MCP communication (from the server). -/// -public interface IServerTransport -{ - /// - /// Asynchronously accepts a transport session initiated by an MCP client and returns an interface for the duplex JSON-RPC message stream. - /// - /// Used to signal the cancellation of the asynchronous operation. - /// Returns an interface for the duplex JSON-RPC message stream. - Task AcceptAsync(CancellationToken cancellationToken = default); -} diff --git a/src/ModelContextProtocol/Protocol/Transport/SseResponseStreamTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseResponseStreamTransport.cs index 6c016f70..eafd1f61 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseResponseStreamTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseResponseStreamTransport.cs @@ -15,8 +15,8 @@ namespace ModelContextProtocol.Protocol.Transport; /// The endpoint to send JSON-RPC messages to. Defaults to "/message". public sealed class SseResponseStreamTransport(Stream sseResponseStream, string messageEndpoint = "/message") : ITransport { - private readonly Channel _incomingChannel = CreateSingleItemChannel(); - private readonly Channel> _outgoingSseChannel = CreateSingleItemChannel>(); + private readonly Channel _incomingChannel = CreateBoundedChannel(); + private readonly Channel> _outgoingSseChannel = CreateBoundedChannel>(); private Task? _sseWriteTask; private Utf8JsonWriter? _jsonWriter; @@ -47,7 +47,10 @@ void WriteJsonRpcMessageToBuffer(SseItem item, IBufferWriter(null, "endpoint")); + if (!_outgoingSseChannel.Writer.TryWrite(new SseItem(null, "endpoint"))) + { + throw new InvalidOperationException($"You must call ${nameof(RunAsync)} before calling ${nameof(SendMessageAsync)}."); + } var sseItems = _outgoingSseChannel.Reader.ReadAllAsync(cancellationToken); return _sseWriteTask = SseFormatter.WriteAsync(sseItems, sseResponseStream, WriteJsonRpcMessageToBuffer, cancellationToken); @@ -73,7 +76,7 @@ public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken ca throw new InvalidOperationException($"Transport is not connected. Make sure to call {nameof(RunAsync)} first."); } - await _outgoingSseChannel.Writer.WriteAsync(new SseItem(message), cancellationToken).AsTask(); + await _outgoingSseChannel.Writer.WriteAsync(new SseItem(message), cancellationToken); } /// @@ -90,11 +93,11 @@ public async Task OnMessageReceivedAsync(IJsonRpcMessage message, CancellationTo throw new InvalidOperationException($"Transport is not connected. Make sure to call {nameof(RunAsync)} first."); } - await _incomingChannel.Writer.WriteAsync(message, cancellationToken).AsTask(); + await _incomingChannel.Writer.WriteAsync(message, cancellationToken); } - private static Channel CreateSingleItemChannel() => - Channel.CreateBounded(new BoundedChannelOptions(1) + private static Channel CreateBoundedChannel(int capacity = 1) => + Channel.CreateBounded(new BoundedChannelOptions(capacity) { SingleReader = true, SingleWriter = false, diff --git a/src/ModelContextProtocol/Server/McpServer.cs b/src/ModelContextProtocol/Server/McpServer.cs index d9e456c7..b9548319 100644 --- a/src/ModelContextProtocol/Server/McpServer.cs +++ b/src/ModelContextProtocol/Server/McpServer.cs @@ -12,30 +12,12 @@ namespace ModelContextProtocol.Server; /// internal sealed class McpServer : McpJsonRpcEndpoint, IMcpServer { - private readonly IServerTransport? _serverTransport; private readonly EventHandler? _toolsChangedDelegate; private readonly EventHandler? _promptsChangedDelegate; - private ITransport? _sessionTransport; + private ITransport _sessionTransport; private string _endpointName; - /// - /// Creates a new instance of . - /// - /// Transport to use for the server that is ready to accept new sessions asynchronously. - /// Configuration options for this server, including capabilities. - /// Make sure to accurately reflect exactly what capabilities the server supports and does not support. - /// Logger factory to use for logging - /// Optional service provider to use for dependency injection - /// - public McpServer(IServerTransport serverTransport, McpServerOptions options, ILoggerFactory? loggerFactory, IServiceProvider? serviceProvider) - : this(options, loggerFactory, serviceProvider) - { - Throw.IfNull(serverTransport); - - _serverTransport = serverTransport; - } - /// /// Creates a new instance of . /// @@ -46,27 +28,12 @@ public McpServer(IServerTransport serverTransport, McpServerOptions options, ILo /// Optional service provider to use for dependency injection /// public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? loggerFactory, IServiceProvider? serviceProvider) - : this(options, loggerFactory, serviceProvider) - { - Throw.IfNull(transport); - - _sessionTransport = transport; - InitializeSession(transport); - } - - /// - /// Creates a new instance of . - /// - /// Configuration options for this server, including capabilities. - /// Make sure to accurately reflect exactly what capabilities the server supports and does not support. - /// Logger factory to use for logging - /// Optional service provider to use for dependency injection - /// - private McpServer(McpServerOptions options, ILoggerFactory? loggerFactory, IServiceProvider? serviceProvider) : base(loggerFactory) { + Throw.IfNull(transport); Throw.IfNull(options); + _sessionTransport = transport; ServerOptions = options; Services = serviceProvider; _endpointName = $"Server ({options.ServerInfo.Name} {options.ServerInfo.Version})"; @@ -128,36 +95,13 @@ private McpServer(McpServerOptions options, ILoggerFactory? loggerFactory, IServ /// public override string EndpointName => _endpointName; - public async Task AcceptSessionAsync(CancellationToken cancellationToken = default) - { - // Below is effectively an assertion. The McpServerFactory should only use this with the IServerTransport constructor. - Throw.IfNull(_serverTransport); - - try - { - _sessionTransport = await _serverTransport.AcceptAsync(cancellationToken).ConfigureAwait(false); - - if (_sessionTransport is null) - { - throw new McpServerException("The server transport closed before a client started a new session."); - } - - InitializeSession(_sessionTransport); - } - catch (Exception e) - { - _logger.ServerInitializationError(EndpointName, e); - throw; - } - } - /// public async Task RunAsync(CancellationToken cancellationToken = default) { try { // Start processing messages - StartSession(fullSessionCancellationToken: cancellationToken); + StartSession(_sessionTransport, fullSessionCancellationToken: cancellationToken); await MessageProcessingTask.ConfigureAwait(false); } finally @@ -178,18 +122,7 @@ public override async ValueTask DisposeUnsynchronizedAsync() prompts.Changed -= _promptsChangedDelegate; } - try - { - await base.DisposeUnsynchronizedAsync().ConfigureAwait(false); - } - finally - { - if (_serverTransport is not null && _sessionTransport is not null) - { - // We created the _sessionTransport from the _serverTransport, so we own it. - await _sessionTransport.DisposeAsync().ConfigureAwait(false); - } - } + await base.DisposeUnsynchronizedAsync().ConfigureAwait(false); } private void SetPingHandler() @@ -208,7 +141,7 @@ private void SetInitializeHandler(McpServerOptions options) // Use the ClientInfo to update the session EndpointName for logging. _endpointName = $"{_endpointName}, Client ({ClientInfo?.Name} {ClientInfo?.Version})"; - GetSessionOrThrow().EndpointName = EndpointName; + GetSessionOrThrow().EndpointName = _endpointName; return Task.FromResult(new InitializeResult() { diff --git a/src/ModelContextProtocol/Server/McpServerFactory.cs b/src/ModelContextProtocol/Server/McpServerFactory.cs index 953b7430..3f140792 100644 --- a/src/ModelContextProtocol/Server/McpServerFactory.cs +++ b/src/ModelContextProtocol/Server/McpServerFactory.cs @@ -33,42 +33,4 @@ public static IMcpServer Create( return new McpServer(transport, serverOptions, loggerFactory, serviceProvider); } - - /// - /// Waits for the client to establish a new MCP session, then initializes a new instance of the class. - /// - /// Transport to use for the server that is ready to accept new MCP sessions asynchronously. - /// - /// Configuration options for this server, including capabilities. - /// Make sure to accurately reflect exactly what capabilities the server supports and does not support. - /// - /// Logger factory to use for logging - /// Optional service provider to create new instances. - /// Cancel waiting for a client to establish a new MCP session. - /// An . - /// is . - /// is . - public static async Task AcceptAsync( - IServerTransport serverTransport, - McpServerOptions serverOptions, - ILoggerFactory? loggerFactory = null, - IServiceProvider? serviceProvider = null, - CancellationToken cancellationToken = default) - { - Throw.IfNull(serverTransport); - Throw.IfNull(serverOptions); - - var mcpServer = new McpServer(serverTransport, serverOptions, loggerFactory, serviceProvider); - - try - { - await mcpServer.AcceptSessionAsync(cancellationToken).ConfigureAwait(false); - return mcpServer; - } - catch - { - await mcpServer.DisposeAsync().ConfigureAwait(false); - throw; - } - } } diff --git a/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs b/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs index 98567d08..ac0337f9 100644 --- a/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs +++ b/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs @@ -60,13 +60,8 @@ public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancella /// protected Task? MessageProcessingTask { get; set; } - protected void InitializeSession(ITransport sessionTransport) - { - _session = new McpSession(sessionTransport, EndpointName, _requestHandlers, _notificationHandlers, _logger); - } - [MemberNotNull(nameof(MessageProcessingTask))] - protected void StartSession(CancellationToken fullSessionCancellationToken = default) + protected void StartSession(ITransport sessionTransport, CancellationToken fullSessionCancellationToken = default) { if (Interlocked.Exchange(ref _started, 1) != 0) { @@ -74,7 +69,8 @@ protected void StartSession(CancellationToken fullSessionCancellationToken = def } _sessionCts = CancellationTokenSource.CreateLinkedTokenSource(fullSessionCancellationToken); - MessageProcessingTask = GetSessionOrThrow().ProcessMessagesAsync(_sessionCts.Token); + _session = new McpSession(sessionTransport, EndpointName, _requestHandlers, _notificationHandlers, _logger); + MessageProcessingTask = _session.ProcessMessagesAsync(_sessionCts.Token); } public async ValueTask DisposeAsync() @@ -122,5 +118,5 @@ public virtual async ValueTask DisposeUnsynchronizedAsync() } protected McpSession GetSessionOrThrow() - => _session ?? throw new InvalidOperationException($"This should be unreachable from public API! Call {nameof(InitializeSession)} before sending messages."); + => _session ?? throw new InvalidOperationException($"This should be unreachable from public API! Call {nameof(StartSession)} before sending messages."); } diff --git a/tests/ModelContextProtocol.TestSseServer/ModelContextProtocol.TestSseServer.csproj b/tests/ModelContextProtocol.TestSseServer/ModelContextProtocol.TestSseServer.csproj index 544be2e0..6633ad4a 100644 --- a/tests/ModelContextProtocol.TestSseServer/ModelContextProtocol.TestSseServer.csproj +++ b/tests/ModelContextProtocol.TestSseServer/ModelContextProtocol.TestSseServer.csproj @@ -1,4 +1,4 @@ - + Exe @@ -15,6 +15,7 @@ + diff --git a/tests/ModelContextProtocol.TestSseServer/Program.cs b/tests/ModelContextProtocol.TestSseServer/Program.cs index b8751121..4bbc5bfc 100644 --- a/tests/ModelContextProtocol.TestSseServer/Program.cs +++ b/tests/ModelContextProtocol.TestSseServer/Program.cs @@ -1,7 +1,6 @@ -using ModelContextProtocol.Protocol.Transport; +using ModelContextProtocol.AspNetCore; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server; -using Microsoft.Extensions.Logging; using Serilog; using System.Text; using System.Text.Json; @@ -10,9 +9,7 @@ namespace ModelContextProtocol.TestSseServer; public class Program { - private static ILoggerFactory CreateLoggerFactory() => LoggerFactory.Create(ConfigureSerilog); - - public static void ConfigureSerilog(ILoggingBuilder loggingBuilder) + private static void ConfigureSerilog(ILoggingBuilder loggingBuilder) { Log.Logger = new LoggerConfiguration() .MinimumLevel.Verbose() // Capture all log levels @@ -27,22 +24,17 @@ public static void ConfigureSerilog(ILoggingBuilder loggingBuilder) public static Task Main(string[] args) => MainAsync(args); - public static async Task MainAsync(string[] args, ILoggerFactory? loggerFactory = null, CancellationToken cancellationToken = default) + private static void ConfigureOptions(McpServerOptions options) { - Console.WriteLine("Starting server..."); - - McpServerOptions options = new() + options.ServerInfo = new Implementation() { Name = "TestServer", Version = "1.0.0" }; + options.Capabilities = new ServerCapabilities() { - ServerInfo = new Implementation() { Name = "TestServer", Version = "1.0.0" }, - Capabilities = new ServerCapabilities() - { - Tools = new(), - Resources = new(), - Prompts = new(), - }, - ProtocolVersion = "2024-11-05", - ServerInstructions = "This is a test server with only stub functionality" + Tools = new(), + Resources = new(), + Prompts = new(), }; + options.ProtocolVersion = "2024-11-05"; + options.ServerInstructions = "This is a test server with only stub functionality"; Console.WriteLine("Registering handlers."); @@ -380,17 +372,31 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st } }, }; + } + + public static async Task MainAsync(string[] args, ILoggerProvider? loggerProvider = null, CancellationToken cancellationToken = default) + { + Console.WriteLine("Starting server..."); - loggerFactory ??= CreateLoggerFactory(); - await using var httpListenerSseTransport = new HttpListenerSseServerTransport("TestServer", 3001, loggerFactory); - Console.WriteLine("Server running..."); + var builder = WebApplication.CreateSlimBuilder(args); + builder.WebHost.ConfigureKestrel(options => + { + options.ListenLocalhost(3001); + }); - // Each IMcpServer represents a new SSE session. - while (true) + ConfigureSerilog(builder.Logging); + if (loggerProvider is not null) { - var server = await McpServerFactory.AcceptAsync(httpListenerSseTransport, options, loggerFactory, cancellationToken: cancellationToken); - _ = server.RunAsync(cancellationToken: cancellationToken); + builder.Logging.AddProvider(loggerProvider); } + + builder.Services.AddMcpServer(ConfigureOptions); + + var app = builder.Build(); + + app.MapMcp(); + + await app.RunAsync(cancellationToken); } const string MCP_TINY_IMAGE = diff --git a/tests/ModelContextProtocol.TestSseServer/Properties/launchSettings.json b/tests/ModelContextProtocol.TestSseServer/Properties/launchSettings.json new file mode 100644 index 00000000..439d5361 --- /dev/null +++ b/tests/ModelContextProtocol.TestSseServer/Properties/launchSettings.json @@ -0,0 +1,12 @@ +{ + "profiles": { + "ModelContextProtocol.TestSseServer": { + "commandName": "Project", + "launchBrowser": false, + "environmentVariables": { + "ASPNETCORE_ENVIRONMENT": "Development" + }, + "applicationUrl": "http://localhost:3001" + } + } +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs index b3cdfea3..9b598d33 100644 --- a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs @@ -1,20 +1,20 @@ -using ModelContextProtocol.Client; -using Microsoft.Extensions.AI; -using OpenAI; -using ModelContextProtocol.Protocol.Types; +using Microsoft.Extensions.AI; +using ModelContextProtocol.Client; using ModelContextProtocol.Protocol.Messages; -using System.Text.Json; using ModelContextProtocol.Protocol.Transport; +using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Tests.Utils; +using OpenAI; using System.Text.Encodings.Web; -using System.Text.Json.Serialization.Metadata; +using System.Text.Json; using System.Text.Json.Serialization; +using System.Text.Json.Serialization.Metadata; namespace ModelContextProtocol.Tests; public class ClientIntegrationTests : LoggedTest, IClassFixture { - private static readonly string? s_openAIKey = Environment.GetEnvironmentVariable("AI:OpenAI:ApiKey")!; + private static readonly string? s_openAIKey = Environment.GetEnvironmentVariable("AI:OpenAI:ApiKey"); public static bool NoOpenAIKeySet => string.IsNullOrWhiteSpace(s_openAIKey); diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsTransportsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsTransportsTests.cs index 1f8acb38..93d546b0 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsTransportsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsTransportsTests.cs @@ -19,18 +19,4 @@ public void WithStdioServerTransport_Sets_Transport() Assert.NotNull(transportType); Assert.Equal(typeof(StdioServerTransport), transportType.ImplementationType); } - - [Fact] - public void WithHttpListenerSseServerTransport_Sets_Transport() - { - var services = new ServiceCollection(); - var builder = new Mock(); - builder.SetupGet(b => b.Services).Returns(services); - - builder.Object.WithHttpListenerSseServerTransport(); - - var transportType = services.FirstOrDefault(s => s.ServiceType == typeof(IServerTransport)); - Assert.NotNull(transportType); - Assert.Equal(typeof(HttpListenerSseServerTransport), transportType.ImplementationType); - } } diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerFactoryTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerFactoryTests.cs index 8eae35af..25c5123e 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerFactoryTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerFactoryTests.cs @@ -32,22 +32,16 @@ public async Task Create_Should_Initialize_With_Valid_Parameters() } [Fact] - public async Task Create_Throws_For_Null_ServerTransport() + public void Create_Throws_For_Null_ServerTransport() { // Arrange, Act & Assert Assert.Throws("transport", () => McpServerFactory.Create(null!, _options, LoggerFactory)); - - await Assert.ThrowsAsync("serverTransport", () => - McpServerFactory.AcceptAsync(null!, _options, LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); } [Fact] - public async Task Create_Throws_For_Null_Options() + public void Create_Throws_For_Null_Options() { // Arrange, Act & Assert Assert.Throws("serverOptions", () => McpServerFactory.Create(Mock.Of(), null!, LoggerFactory)); - - await Assert.ThrowsAsync("serverOptions", () => - McpServerFactory.AcceptAsync(Mock.Of(), null!, LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); } } diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs index 30a32257..65ed07df 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs @@ -168,18 +168,6 @@ public async Task RequestRootsAsync_Should_SendRequest() await runTask; } - [Fact] - public async Task Throws_Exception_If_Not_Connected() - { - await using var server = McpServerFactory.Create(_serverTransport.Object, _options, LoggerFactory, _serviceProvider); - SetClientCapabilities(server, new ClientCapabilities { Roots = new RootsCapability() }); - _serverTransport.SetupGet(t => t.IsConnected).Returns(false); - - var action = async () => await server.RequestRootsAsync(new ListRootsRequestParams(), CancellationToken.None); - - await Assert.ThrowsAsync(action); - } - [Fact] public async Task Can_Handle_Ping_Requests() { diff --git a/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs b/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs index c3b0fbbf..75aefc08 100644 --- a/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs +++ b/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs @@ -13,18 +13,11 @@ public class SseServerIntegrationTestFixture : IAsyncDisposable private readonly CancellationTokenSource _stopCts = new(); private readonly DelegatingTestOutputHelper _delegatingTestOutputHelper = new(); - private readonly ILoggerFactory _redirectingLoggerFactory; public McpServerConfig DefaultConfig { get; } public SseServerIntegrationTestFixture() { - _redirectingLoggerFactory = LoggerFactory.Create(builder => - { - Program.ConfigureSerilog(builder); - builder.AddProvider(new XunitLoggerProvider(_delegatingTestOutputHelper)); - }); - DefaultConfig = new McpServerConfig { Id = "test_server", @@ -34,7 +27,7 @@ public SseServerIntegrationTestFixture() Location = "http://localhost:3001/sse" }; - _serverTask = Program.MainAsync([], _redirectingLoggerFactory, _stopCts.Token); + _serverTask = Program.MainAsync([], new XunitLoggerProvider(_delegatingTestOutputHelper), _stopCts.Token); } public static McpClientOptions CreateDefaultClientOptions() => new() @@ -63,7 +56,6 @@ public async ValueTask DisposeAsync() catch (OperationCanceledException) { } - _redirectingLoggerFactory.Dispose(); _stopCts.Dispose(); } } diff --git a/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs b/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs index 0f6d90d6..44befcd1 100644 --- a/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs @@ -36,7 +36,7 @@ public async Task ConnectAndPing_Sse_TestServer() // Arrange // Act - var client = await GetClientAsync(); + await using var client = await GetClientAsync(); await client.PingAsync(TestContext.Current.CancellationToken); // Assert @@ -49,7 +49,7 @@ public async Task Connect_TestServer_ShouldProvideServerFields() // Arrange // Act - var client = await GetClientAsync(); + await using var client = await GetClientAsync(); // Assert Assert.NotNull(client.ServerCapabilities); @@ -62,7 +62,7 @@ public async Task ListTools_Sse_TestServer() // arrange // act - var client = await GetClientAsync(); + await using var client = await GetClientAsync(); var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); // assert @@ -75,7 +75,7 @@ public async Task CallTool_Sse_EchoServer() // arrange // act - var client = await GetClientAsync(); + await using var client = await GetClientAsync(); var result = await client.CallToolAsync( "echo", new Dictionary @@ -98,7 +98,7 @@ public async Task ListResources_Sse_TestServer() // arrange // act - var client = await GetClientAsync(); + await using var client = await GetClientAsync(); IList allResources = await client.ListResourcesAsync(TestContext.Current.CancellationToken); @@ -112,7 +112,7 @@ public async Task ReadResource_Sse_TextResource() // arrange // act - var client = await GetClientAsync(); + await using var client = await GetClientAsync(); // Odd numbered resources are text in the everything server (despite the docs saying otherwise) // 1 is index 0, which is "even" in the 0-based index // We copied this oddity to the test server @@ -131,7 +131,7 @@ public async Task ReadResource_Sse_BinaryResource() // arrange // act - var client = await GetClientAsync(); + await using var client = await GetClientAsync(); // Even numbered resources are binary in the everything server (despite the docs saying otherwise) // 2 is index 1, which is "odd" in the 0-based index // We copied this oddity to the test server @@ -150,7 +150,7 @@ public async Task ListPrompts_Sse_TestServer() // arrange // act - var client = await GetClientAsync(); + await using var client = await GetClientAsync(); var prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken); // assert @@ -167,7 +167,7 @@ public async Task GetPrompt_Sse_SimplePrompt() // arrange // act - var client = await GetClientAsync(); + await using var client = await GetClientAsync(); var result = await client.GetPromptAsync("simple_prompt", null, TestContext.Current.CancellationToken); // assert @@ -181,7 +181,7 @@ public async Task GetPrompt_Sse_ComplexPrompt() // arrange // act - var client = await GetClientAsync(); + await using var client = await GetClientAsync(); var arguments = new Dictionary { { "temperature", "0.7" }, @@ -200,7 +200,7 @@ public async Task GetPrompt_Sse_NonExistent_ThrowsException() // arrange // act - var client = await GetClientAsync(); + await using var client = await GetClientAsync(); await Assert.ThrowsAsync(() => client.GetPromptAsync("non_existent_prompt", null, TestContext.Current.CancellationToken)); } @@ -229,7 +229,7 @@ public async Task Sampling_Sse_TestServer() } }; }; - var client = await GetClientAsync(options); + await using var client = await GetClientAsync(options); #pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously // Call the server's sampleLLM tool which should trigger our sampling handler @@ -246,4 +246,29 @@ public async Task Sampling_Sse_TestServer() Assert.Equal("text", textContent.Type); Assert.False(string.IsNullOrEmpty(textContent.Text)); } + + [Fact] + public async Task CallTool_Sse_EchoServer_Concurrently() + { + await using var client1 = await GetClientAsync(); + await using var client2 = await GetClientAsync(); + + for (int i = 0; i < 4; i++) + { + var client = (i % 2 == 0) ? client1 : client2; + var result = await client.CallToolAsync( + "echo", + new Dictionary + { + ["message"] = $"Hello MCP! {i}" + }, + TestContext.Current.CancellationToken + ); + + Assert.NotNull(result); + Assert.False(result.IsError); + var textContent = Assert.Single(result.Content, c => c.Type == "text"); + Assert.Equal($"Echo: Hello MCP! {i}", textContent.Text); + } + } } diff --git a/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs index c1df1887..5857f3c4 100644 --- a/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs @@ -6,7 +6,6 @@ using ModelContextProtocol.Tests.Utils; using ModelContextProtocol.Utils.Json; using System.IO.Pipelines; -using System.Runtime.InteropServices; using System.Text; using System.Text.Json; @@ -32,11 +31,9 @@ public StdioServerTransportTests(ITestOutputHelper testOutputHelper) }; } - [Fact] + [Fact(Skip="https://github.com/modelcontextprotocol/csharp-sdk/issues/143")] public async Task Constructor_Should_Initialize_With_Valid_Parameters() { - Assert.SkipWhen(RuntimeInformation.IsOSPlatform(OSPlatform.Linux), "https://github.com/modelcontextprotocol/csharp-sdk/issues/143"); - // Act await using var transport = new StdioServerTransport(_serverOptions); diff --git a/tests/ModelContextProtocol.Tests/Utils/LoggedTest.cs b/tests/ModelContextProtocol.Tests/Utils/LoggedTest.cs index f2e61f76..aa1ecbc2 100644 --- a/tests/ModelContextProtocol.Tests/Utils/LoggedTest.cs +++ b/tests/ModelContextProtocol.Tests/Utils/LoggedTest.cs @@ -13,14 +13,16 @@ public LoggedTest(ITestOutputHelper testOutputHelper) { CurrentTestOutputHelper = testOutputHelper, }; + LoggerProvider = new XunitLoggerProvider(_delegatingTestOutputHelper); LoggerFactory = Microsoft.Extensions.Logging.LoggerFactory.Create(builder => { - builder.AddProvider(new XunitLoggerProvider(_delegatingTestOutputHelper)); + builder.AddProvider(LoggerProvider); }); } public ITestOutputHelper TestOutputHelper => _delegatingTestOutputHelper; public ILoggerFactory LoggerFactory { get; } + public ILoggerProvider LoggerProvider { get; } public virtual void Dispose() {