From 4ae26c04817d271f657c004b0c1ccb6c943cd370 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Thu, 27 Mar 2025 18:07:20 -0400 Subject: [PATCH 01/10] Enable graceful shutdown of servers --- README.md | 6 +- .../McpEndpointRouteBuilderExtensions.cs | 11 +- src/ModelContextProtocol/Client/McpClient.cs | 122 ++++++------- .../Hosting/McpServerHostedService.cs | 2 +- .../Transport/HttpListenerServerProvider.cs | 169 +++++++++--------- .../HttpListenerSseServerTransport.cs | 33 ++-- .../Protocol/Transport/IServerTransport.cs | 3 + .../Protocol/Transport/SseClientTransport.cs | 44 ++--- .../Transport/StdioClientTransport.cs | 1 - .../Transport/StdioServerTransport.cs | 157 ++++++++-------- src/ModelContextProtocol/Server/IMcpServer.cs | 9 +- src/ModelContextProtocol/Server/McpServer.cs | 38 ++-- .../Shared/McpJsonRpcEndpoint.cs | 10 +- .../Program.cs | 95 +++++----- .../Program.cs | 23 +-- .../Client/McpClientExtensionsTests.cs | 13 +- .../ClientIntegrationTests.cs | 39 ++-- .../McpServerBuilderExtensionsToolsTests.cs | 40 ++--- .../EverythingSseServerFixture.cs | 2 - .../Server/McpServerTests.cs | 48 ++--- .../Server/McpServerToolTests.cs | 3 +- .../TestAttributes.cs | 2 + .../Transport/SseClientTransportTests.cs | 6 +- .../Transport/StdioServerTransportTests.cs | 6 +- .../Utils/TestServerTransport.cs | 2 + 25 files changed, 391 insertions(+), 493 deletions(-) create mode 100644 tests/ModelContextProtocol.Tests/TestAttributes.cs diff --git a/README.md b/README.md index 63784d6e..502dd6b5 100644 --- a/README.md +++ b/README.md @@ -198,11 +198,7 @@ McpServerOptions options = new() }; await using IMcpServer server = McpServerFactory.Create(new StdioServerTransport("MyServer"), options); - -await server.StartAsync(); - -// Run until process is stopped by the client (parent process) -await Task.Delay(Timeout.Infinite); +await server.RunAsync(); ``` ## Acknowledgements diff --git a/samples/AspNetCoreSseServer/McpEndpointRouteBuilderExtensions.cs b/samples/AspNetCoreSseServer/McpEndpointRouteBuilderExtensions.cs index 9bf11dd8..8a37afba 100644 --- a/samples/AspNetCoreSseServer/McpEndpointRouteBuilderExtensions.cs +++ b/samples/AspNetCoreSseServer/McpEndpointRouteBuilderExtensions.cs @@ -10,7 +10,6 @@ public static class McpEndpointRouteBuilderExtensions { public static IEndpointConventionBuilder MapMcpSse(this IEndpointRouteBuilder endpoints) { - IMcpServer? server = null; SseResponseStreamTransport? transport = null; var loggerFactory = endpoints.ServiceProvider.GetRequiredService(); var mcpServerOptions = endpoints.ServiceProvider.GetRequiredService>(); @@ -19,17 +18,15 @@ public static IEndpointConventionBuilder MapMcpSse(this IEndpointRouteBuilder en routeGroup.MapGet("/sse", async (HttpResponse response, CancellationToken requestAborted) => { - await using var localTransport = transport = new SseResponseStreamTransport(response.Body); - await using var localServer = server = McpServerFactory.Create(transport, mcpServerOptions.Value, loggerFactory, endpoints.ServiceProvider); - - await localServer.StartAsync(requestAborted); - response.Headers.ContentType = "text/event-stream"; response.Headers.CacheControl = "no-cache"; + await using var localTransport = transport = new SseResponseStreamTransport(response.Body); + await using var server = McpServerFactory.Create(transport, mcpServerOptions.Value, loggerFactory, endpoints.ServiceProvider); + try { - await transport.RunAsync(requestAborted); + await transport.RunAsync(cancellationToken: requestAborted); } catch (OperationCanceledException) when (requestAborted.IsCancellationRequested) { diff --git a/src/ModelContextProtocol/Client/McpClient.cs b/src/ModelContextProtocol/Client/McpClient.cs index 1ae358ef..80126375 100644 --- a/src/ModelContextProtocol/Client/McpClient.cs +++ b/src/ModelContextProtocol/Client/McpClient.cs @@ -6,7 +6,6 @@ using ModelContextProtocol.Shared; using ModelContextProtocol.Utils.Json; using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Abstractions; using System.Text.Json; namespace ModelContextProtocol.Client; @@ -17,7 +16,7 @@ internal sealed class McpClient : McpJsonRpcEndpoint, IMcpClient private readonly McpClientOptions _options; private readonly IClientTransport _clientTransport; - private volatile bool _isInitializing; + private int _connecting; /// /// Initializes a new instance of the class. @@ -74,33 +73,69 @@ public McpClient(IClientTransport transport, McpClientOptions options, McpServer /// public async Task ConnectAsync(CancellationToken cancellationToken = default) { - if (IsInitialized) - { - _logger.ClientAlreadyInitialized(EndpointName); - return; - } - - if (_isInitializing) + if (Interlocked.Exchange(ref _connecting, 1) != 0) { _logger.ClientAlreadyInitializing(EndpointName); - throw new InvalidOperationException("Client is already initializing"); + throw new InvalidOperationException("Client is already in use."); } - _isInitializing = true; + CancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + cancellationToken = CancellationTokenSource.Token; + try { - CancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - // Connect transport - await _clientTransport.ConnectAsync(CancellationTokenSource.Token).ConfigureAwait(false); + await _clientTransport.ConnectAsync(cancellationToken).ConfigureAwait(false); // Start processing messages - MessageProcessingTask = ProcessMessagesAsync(CancellationTokenSource.Token); + MessageProcessingTask = ProcessMessagesAsync(cancellationToken); // Perform initialization sequence - await InitializeAsync(CancellationTokenSource.Token).ConfigureAwait(false); + using var initializationCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + initializationCts.CancelAfter(_options.InitializationTimeout); - IsInitialized = true; + try + { + // Send initialize request + var initializeResponse = await SendRequestAsync( + new JsonRpcRequest + { + Method = "initialize", + Params = new + { + protocolVersion = _options.ProtocolVersion, + capabilities = _options.Capabilities ?? new ClientCapabilities(), + clientInfo = _options.ClientInfo + } + }, + initializationCts.Token).ConfigureAwait(false); + + // Store server information + _logger.ServerCapabilitiesReceived(EndpointName, + capabilities: JsonSerializer.Serialize(initializeResponse.Capabilities, McpJsonUtilities.JsonContext.Default.ServerCapabilities), + serverInfo: JsonSerializer.Serialize(initializeResponse.ServerInfo, McpJsonUtilities.JsonContext.Default.Implementation)); + + ServerCapabilities = initializeResponse.Capabilities; + ServerInfo = initializeResponse.ServerInfo; + ServerInstructions = initializeResponse.Instructions; + + // Validate protocol version + if (initializeResponse.ProtocolVersion != _options.ProtocolVersion) + { + _logger.ServerProtocolVersionMismatch(EndpointName, _options.ProtocolVersion, initializeResponse.ProtocolVersion); + throw new McpClientException($"Server protocol version mismatch. Expected {_options.ProtocolVersion}, got {initializeResponse.ProtocolVersion}"); + } + + // Send initialized notification + await SendMessageAsync( + new JsonRpcNotification { Method = "notifications/initialized" }, + initializationCts.Token).ConfigureAwait(false); + } + catch (OperationCanceledException) when (initializationCts.IsCancellationRequested) + { + _logger.ClientInitializationTimeout(EndpointName); + throw new McpClientException("Initialization timed out"); + } } catch (Exception e) { @@ -108,58 +143,5 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default) await CleanupAsync().ConfigureAwait(false); throw; } - finally - { - _isInitializing = false; - } - } - - private async Task InitializeAsync(CancellationToken cancellationToken) - { - using var initializationCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - initializationCts.CancelAfter(_options.InitializationTimeout); - - try - { - // Send initialize request - var initializeResponse = await SendRequestAsync( - new JsonRpcRequest - { - Method = "initialize", - Params = new - { - protocolVersion = _options.ProtocolVersion, - capabilities = _options.Capabilities ?? new ClientCapabilities(), - clientInfo = _options.ClientInfo - } - }, - initializationCts.Token).ConfigureAwait(false); - - // Store server information - _logger.ServerCapabilitiesReceived(EndpointName, - capabilities: JsonSerializer.Serialize(initializeResponse.Capabilities, McpJsonUtilities.JsonContext.Default.ServerCapabilities), - serverInfo: JsonSerializer.Serialize(initializeResponse.ServerInfo, McpJsonUtilities.JsonContext.Default.Implementation)); - - ServerCapabilities = initializeResponse.Capabilities; - ServerInfo = initializeResponse.ServerInfo; - ServerInstructions = initializeResponse.Instructions; - - // Validate protocol version - if (initializeResponse.ProtocolVersion != _options.ProtocolVersion) - { - _logger.ServerProtocolVersionMismatch(EndpointName, _options.ProtocolVersion, initializeResponse.ProtocolVersion); - throw new McpClientException($"Server protocol version mismatch. Expected {_options.ProtocolVersion}, got {initializeResponse.ProtocolVersion}"); - } - - // Send initialized notification - await SendMessageAsync( - new JsonRpcNotification { Method = "notifications/initialized" }, - initializationCts.Token).ConfigureAwait(false); - } - catch (OperationCanceledException) when (initializationCts.IsCancellationRequested) - { - _logger.ClientInitializationTimeout(EndpointName); - throw new McpClientException("Initialization timed out"); - } } } diff --git a/src/ModelContextProtocol/Hosting/McpServerHostedService.cs b/src/ModelContextProtocol/Hosting/McpServerHostedService.cs index 8aa6218f..51c0d997 100644 --- a/src/ModelContextProtocol/Hosting/McpServerHostedService.cs +++ b/src/ModelContextProtocol/Hosting/McpServerHostedService.cs @@ -26,6 +26,6 @@ public McpServerHostedService(IMcpServer server) /// protected override async Task ExecuteAsync(CancellationToken stoppingToken) { - await _server.StartAsync(stoppingToken).ConfigureAwait(false); + await _server.RunAsync(cancellationToken: stoppingToken).ConfigureAwait(false); } } diff --git a/src/ModelContextProtocol/Protocol/Transport/HttpListenerServerProvider.cs b/src/ModelContextProtocol/Protocol/Transport/HttpListenerServerProvider.cs index 7b923c59..dbeb4e6c 100644 --- a/src/ModelContextProtocol/Protocol/Transport/HttpListenerServerProvider.cs +++ b/src/ModelContextProtocol/Protocol/Transport/HttpListenerServerProvider.cs @@ -1,22 +1,30 @@ using System.Net; -using ModelContextProtocol.Server; + +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously namespace ModelContextProtocol.Protocol.Transport; /// /// HTTP server provider using HttpListener. /// -internal class HttpListenerServerProvider : IDisposable +internal sealed class HttpListenerServerProvider : IAsyncDisposable { private static readonly byte[] s_accepted = "Accepted"u8.ToArray(); private const string SseEndpoint = "/sse"; private const string MessageEndpoint = "/message"; - private readonly int _port; - private HttpListener? _listener; - private CancellationTokenSource? _cts; - private bool _isRunning; + private readonly HttpListener _listener; + private readonly CancellationTokenSource _shutdownTokenSource = new(); + private Task _listeningTask = Task.CompletedTask; + + private readonly TaskCompletionSource _completed = new(); + private int _outstandingOperations; + + private int _state; + private const int StateNotStarted = 0; + private const int StateRunning = 1; + private const int StateStopped = 2; /// /// Creates a new instance of the HTTP server provider. @@ -24,84 +32,100 @@ internal class HttpListenerServerProvider : IDisposable /// The port to listen on public HttpListenerServerProvider(int port) { - _port = port; + if (port < 0) + { + throw new ArgumentOutOfRangeException(nameof(port)); + } + + _listener = new(); + _listener.Prefixes.Add($"http://localhost:{port}/"); } public required Func OnSseConnectionAsync { get; set; } public required Func> OnMessageAsync { get; set; } /// - public Task StartAsync(CancellationToken cancellationToken = default) + public async Task StartAsync(CancellationToken cancellationToken = default) { - if (_isRunning) - return Task.CompletedTask; + if (Interlocked.CompareExchange(ref _state, StateRunning, StateNotStarted) != StateNotStarted) + { + throw new ObjectDisposedException("Server may not be started twice."); + } - _cts = new CancellationTokenSource(); - _listener = new HttpListener(); - _listener.Prefixes.Add($"http://localhost:{_port}/"); + // Start listening for connections _listener.Start(); - _isRunning = true; - // Start listening for connections - _ = Task.Run(() => ListenForConnectionsAsync(_cts.Token), cancellationToken).ConfigureAwait(false); - return Task.CompletedTask; + OperationAdded(); // for the listening task + _listeningTask = Task.Run(async () => + { + try + { + using var cts = CancellationTokenSource.CreateLinkedTokenSource(_shutdownTokenSource.Token, cancellationToken); + cts.Token.Register(_listener.Stop); + while (!cts.IsCancellationRequested) + { + try + { + var context = await _listener.GetContextAsync().ConfigureAwait(false); + + // Process the request in a separate task + OperationAdded(); // for the processing task; decremented in ProcessRequestAsync + _ = Task.Run(() => ProcessRequestAsync(context, cts.Token), CancellationToken.None); + } + catch (Exception) + { + if (cts.IsCancellationRequested) + { + // Shutdown requested, exit gracefully + break; + } + } + } + } + finally + { + OperationCompleted(); // for the listening task + } + }, CancellationToken.None); } /// - public Task StopAsync(CancellationToken cancellationToken = default) + public async ValueTask DisposeAsync() { - if (!_isRunning) - return Task.CompletedTask; - - _cts?.Cancel(); - _listener?.Stop(); + if (Interlocked.CompareExchange(ref _state, StateStopped, StateRunning) != StateRunning) + { + return; + } - _isRunning = false; - return Task.CompletedTask; + await _shutdownTokenSource.CancelAsync().ConfigureAwait(false); + _listener.Stop(); + await _listeningTask.ConfigureAwait(false); } - private async Task ListenForConnectionsAsync(CancellationToken cancellationToken) - { - if (_listener == null) - { - throw new McpServerException("Listener not initialized"); - } + /// Gets a that completes when the server has finished its work. + public Task Completed => _completed.Task; - while (!cancellationToken.IsCancellationRequested) - { - try - { - var context = await _listener.GetContextAsync().ConfigureAwait(false); + private void OperationAdded() => Interlocked.Increment(ref _outstandingOperations); - // Process the request in a separate task - _ = Task.Run(() => ProcessRequestAsync(context, cancellationToken), cancellationToken); - } - catch (Exception) when (cancellationToken.IsCancellationRequested) - { - // Shutdown requested, exit gracefully - break; - } - catch (Exception) - { - // Log error but continue listening - if (!cancellationToken.IsCancellationRequested) - { - // Continue listening if not shutting down - continue; - } - } + private void OperationCompleted() + { + if (Interlocked.Decrement(ref _outstandingOperations) == 0) + { + // All operations completed + _completed.TrySetResult(true); } } private async Task ProcessRequestAsync(HttpListenerContext context, CancellationToken cancellationToken) { + var request = context.Request; + var response = context.Response; try { - var request = context.Request; - var response = context.Response; - - if (request == null) - throw new McpServerException("Request is null"); + if (request is null || response is null) + { + return; + } // Handle SSE connection if (request.HttpMethod == "GET" && request.Url?.LocalPath == SseEndpoint) @@ -124,11 +148,15 @@ private async Task ProcessRequestAsync(HttpListenerContext context, Cancellation { try { - context.Response.StatusCode = 500; - context.Response.Close(); + response.StatusCode = 500; + response.Close(); } catch { /* Ignore errors during error handling */ } } + finally + { + OperationCompleted(); + } } private async Task HandleSseConnectionAsync(HttpListenerContext context, CancellationToken cancellationToken) @@ -145,13 +173,8 @@ private async Task HandleSseConnectionAsync(HttpListenerContext context, Cancell { await OnSseConnectionAsync(response.OutputStream, cancellationToken).ConfigureAwait(false); } - catch (TaskCanceledException) - { - // Normal shutdown - } catch (Exception) { - // Client disconnected or other error } finally { @@ -186,20 +209,4 @@ private async Task HandleMessageAsync(HttpListenerContext context, CancellationT response.Close(); } - - - /// - public void Dispose() - { - Dispose(true); - GC.SuppressFinalize(this); - } - - /// - protected virtual void Dispose(bool disposing) - { - StopAsync().GetAwaiter().GetResult(); - _cts?.Dispose(); - _listener?.Close(); - } } diff --git a/src/ModelContextProtocol/Protocol/Transport/HttpListenerSseServerTransport.cs b/src/ModelContextProtocol/Protocol/Transport/HttpListenerSseServerTransport.cs index c10b751d..44280608 100644 --- a/src/ModelContextProtocol/Protocol/Transport/HttpListenerSseServerTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/HttpListenerSseServerTransport.cs @@ -7,6 +7,8 @@ using ModelContextProtocol.Server; using ModelContextProtocol.Utils; +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + namespace ModelContextProtocol.Protocol.Transport; /// @@ -28,7 +30,7 @@ public sealed class HttpListenerSseServerTransport : TransportBase, IServerTrans /// The port to listen on. /// A logger factory for creating loggers. public HttpListenerSseServerTransport(McpServerOptions serverOptions, int port, ILoggerFactory loggerFactory) - : this(GetServerName(serverOptions), port, loggerFactory) + : this(serverOptions?.ServerInfo?.Name!, port, loggerFactory) { } @@ -41,6 +43,8 @@ public HttpListenerSseServerTransport(McpServerOptions serverOptions, int port, public HttpListenerSseServerTransport(string serverName, int port, ILoggerFactory loggerFactory) : base(loggerFactory) { + Throw.IfNull(serverName); + _serverName = serverName; _logger = loggerFactory.CreateLogger(); _httpServerProvider = new HttpListenerServerProvider(port) @@ -51,10 +55,11 @@ public HttpListenerSseServerTransport(string serverName, int port, ILoggerFactor } /// - public Task StartListeningAsync(CancellationToken cancellationToken = default) - { - return _httpServerProvider.StartAsync(cancellationToken); - } + public Task StartListeningAsync(CancellationToken cancellationToken = default) => + _httpServerProvider.StartAsync(cancellationToken); + + /// + public Task Completion => _httpServerProvider.Completed; /// public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) @@ -92,20 +97,13 @@ public override async Task SendMessageAsync(IJsonRpcMessage message, Cancellatio /// public override async ValueTask DisposeAsync() - { - await CleanupAsync(CancellationToken.None).ConfigureAwait(false); - GC.SuppressFinalize(this); - } - - private Task CleanupAsync(CancellationToken cancellationToken) { _logger.TransportCleaningUp(EndpointName); - _httpServerProvider.Dispose(); SetConnected(false); + await _httpServerProvider.DisposeAsync().ConfigureAwait(false); _logger.TransportCleanedUp(EndpointName); - return Task.CompletedTask; } private async Task OnSseConnectionAsync(Stream responseStream, CancellationToken cancellationToken) @@ -167,13 +165,4 @@ private async Task OnMessageAsync(Stream requestStream, CancellationToken return false; } } - - /// Validates the and extracts from it the server name to use. - private static string GetServerName(McpServerOptions serverOptions) - { - Throw.IfNull(serverOptions); - Throw.IfNull(serverOptions.ServerInfo); - - return serverOptions.ServerInfo.Name; - } } diff --git a/src/ModelContextProtocol/Protocol/Transport/IServerTransport.cs b/src/ModelContextProtocol/Protocol/Transport/IServerTransport.cs index 742c8942..e204202e 100644 --- a/src/ModelContextProtocol/Protocol/Transport/IServerTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/IServerTransport.cs @@ -10,4 +10,7 @@ public interface IServerTransport : ITransport /// /// Token to cancel the operation. Task StartListeningAsync(CancellationToken cancellationToken = default); + + /// Gets a that will complete when the server transport has completed all work. + Task Completion { get; } } diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs index 5d72d0c0..654b3a6b 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs @@ -21,7 +21,7 @@ public sealed class SseClientTransport : TransportBase, IClientTransport private readonly SseClientTransportOptions _options; private readonly Uri _sseEndpoint; private Uri? _messageEndpoint; - private CancellationTokenSource? _connectionCts; + private readonly CancellationTokenSource _connectionCts; private Task? _receiveTask; private readonly ILogger _logger; private readonly McpServerConfig _serverConfig; @@ -82,7 +82,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default) } // Start message receiving loop - _receiveTask = ReceiveMessagesAsync(_connectionCts!.Token); + _receiveTask = ReceiveMessagesAsync(_connectionCts.Token); _logger.TransportReadingMessages(EndpointName); @@ -165,19 +165,25 @@ public override async Task SendMessageAsync( } } - /// - public async Task CloseAsync() + private async Task CloseAsync() { - if (_connectionCts != null) + try { - await _connectionCts.CancelAsync().ConfigureAwait(false); - _connectionCts.Dispose(); - _connectionCts = null; - } - if (_receiveTask != null) - await _receiveTask.ConfigureAwait(false); + if (!_connectionCts.IsCancellationRequested) + { + await _connectionCts.CancelAsync().ConfigureAwait(false); + _connectionCts.Dispose(); + } - SetConnected(false); + if (_receiveTask != null) + { + await _receiveTask.ConfigureAwait(false); + } + } + finally + { + SetConnected(false); + } } /// @@ -185,19 +191,17 @@ public override async ValueTask DisposeAsync() { try { - await CloseAsync().ConfigureAwait(false); + if (_ownsHttpClient) + { + _httpClient?.Dispose(); + } + + await CloseAsync(); } catch (Exception) { // Ignore exceptions on close } - - if (_ownsHttpClient) - _httpClient?.Dispose(); - - _connectionCts?.Dispose(); - - GC.SuppressFinalize(this); } internal Uri? MessageEndpoint => _messageEndpoint; diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs index 49e43301..00f4dfb4 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs @@ -183,7 +183,6 @@ public override async Task SendMessageAsync(IJsonRpcMessage message, Cancellatio public override async ValueTask DisposeAsync() { await CleanupAsync(CancellationToken.None).ConfigureAwait(false); - GC.SuppressFinalize(this); } private async Task ReadMessagesAsync(CancellationToken cancellationToken) diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioServerTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StdioServerTransport.cs index e1f517e5..05cd5a28 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StdioServerTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StdioServerTransport.cs @@ -25,9 +25,12 @@ public sealed class StdioServerTransport : TransportBase, IServerTransport private readonly TextReader _stdInReader; private readonly Stream _stdOutStream; - private SemaphoreSlim _sendLock = new(1, 1); - private Task? _readTask; - private CancellationTokenSource? _shutdownCts; + private readonly SemaphoreSlim _sendLock = new(1, 1); + private readonly CancellationTokenSource _shutdownCts = new(); + + private Task _readLoopCompleted = Task.CompletedTask; + private TaskCompletionSource _serverCompleted = new(); + private int _disposed = 0; private string EndpointName => $"Server (stdio) ({_serverName})"; @@ -44,8 +47,8 @@ public sealed class StdioServerTransport : TransportBase, IServerTransport /// to , as that will interfere with the transport's output. /// /// - public StdioServerTransport(McpServerOptions serverOptions, ILoggerFactory? loggerFactory = null) - : this(GetServerName(serverOptions), loggerFactory) + public StdioServerTransport(IOptions serverOptions, ILoggerFactory? loggerFactory = null) + : this(serverOptions?.Value!, loggerFactory: loggerFactory) { } @@ -62,9 +65,17 @@ public StdioServerTransport(McpServerOptions serverOptions, ILoggerFactory? logg /// to , as that will interfere with the transport's output. /// /// - public StdioServerTransport(IOptions serverOptions, ILoggerFactory? loggerFactory = null) - : this(GetServerName(serverOptions.Value), loggerFactory) + public StdioServerTransport(McpServerOptions serverOptions, ILoggerFactory? loggerFactory = null) + : this(GetServerName(serverOptions), loggerFactory: loggerFactory) + { + } + + private static string GetServerName(McpServerOptions serverOptions) { + Throw.IfNull(serverOptions); + Throw.IfNull(serverOptions.ServerInfo); + Throw.IfNull(serverOptions.ServerInfo.Name); + return serverOptions.ServerInfo.Name; } /// @@ -80,25 +91,17 @@ public StdioServerTransport(IOptions serverOptions, ILoggerFac /// to , as that will interfere with the transport's output. /// /// - public StdioServerTransport(string serverName, ILoggerFactory? loggerFactory = null) - : base(loggerFactory) + public StdioServerTransport(string serverName, ILoggerFactory? loggerFactory) + : this(serverName, stdinStream: null, stdoutStream: null, loggerFactory) { - Throw.IfNull(serverName); - - _serverName = serverName; - _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; - - // Get raw console streams and wrap them with UTF-8 encoding - _stdInReader = new StreamReader(Console.OpenStandardInput(), Encoding.UTF8); - _stdOutStream = new BufferedStream(Console.OpenStandardOutput()); } /// /// Initializes a new instance of the class with explicit input/output streams. /// /// The name of the server. - /// The input TextReader to use. - /// The output TextWriter to use. + /// The input to use as standard input. If , will be used. + /// The output to use as standard output. If , will be used. /// Optional logger factory used for logging employed by the transport. /// is . /// @@ -106,35 +109,34 @@ public StdioServerTransport(string serverName, ILoggerFactory? loggerFactory = n /// This constructor is useful for testing scenarios where you want to redirect input/output. /// /// - public StdioServerTransport(string serverName, Stream stdinStream, Stream stdoutStream, ILoggerFactory? loggerFactory = null) + public StdioServerTransport(string serverName, Stream? stdinStream = null, Stream? stdoutStream = null, ILoggerFactory? loggerFactory = null) : base(loggerFactory) { Throw.IfNull(serverName); - Throw.IfNull(stdinStream); - Throw.IfNull(stdoutStream); _serverName = serverName; _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; - - _stdInReader = new StreamReader(stdinStream, Encoding.UTF8); - _stdOutStream = stdoutStream; + + _stdInReader = new StreamReader(stdinStream ?? Console.OpenStandardInput(), Encoding.UTF8); + _stdOutStream = stdoutStream ?? new BufferedStream(Console.OpenStandardOutput()); } /// public Task StartListeningAsync(CancellationToken cancellationToken = default) { _logger.LogDebug("Starting StdioServerTransport listener for {EndpointName}", EndpointName); - - _shutdownCts = new CancellationTokenSource(); - - _readTask = Task.Run(async () => await ReadMessagesAsync(_shutdownCts.Token).ConfigureAwait(false), CancellationToken.None); SetConnected(true); + _readLoopCompleted = Task.Run(ReadMessagesAsync, _shutdownCts.Token); + _logger.LogDebug("StdioServerTransport now connected for {EndpointName}", EndpointName); return Task.CompletedTask; } + /// Gets a that completes when the server transport has finished its work. + public Task Completion => _serverCompleted.Task; + /// public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) { @@ -169,33 +171,26 @@ public override async Task SendMessageAsync(IJsonRpcMessage message, Cancellatio } } - /// - public override async ValueTask DisposeAsync() - { - await CleanupAsync(CancellationToken.None).ConfigureAwait(false); - GC.SuppressFinalize(this); - } - - private async Task ReadMessagesAsync(CancellationToken cancellationToken) + private async Task ReadMessagesAsync() { + CancellationToken shutdownToken = _shutdownCts.Token; try { _logger.TransportEnteringReadMessagesLoop(EndpointName); - while (!cancellationToken.IsCancellationRequested) + while (!shutdownToken.IsCancellationRequested) { _logger.TransportWaitingForMessage(EndpointName); - var reader = _stdInReader; - var line = await reader.ReadLineAsync(cancellationToken).ConfigureAwait(false); - if (line == null) - { - _logger.TransportEndOfStream(EndpointName); - break; - } - + var line = await _stdInReader.ReadLineAsync(shutdownToken).ConfigureAwait(false); if (string.IsNullOrWhiteSpace(line)) { + if (line is null) + { + _logger.TransportEndOfStream(EndpointName); + break; + } + continue; } @@ -204,8 +199,7 @@ private async Task ReadMessagesAsync(CancellationToken cancellationToken) try { - var message = JsonSerializer.Deserialize(line, _jsonOptions.GetTypeInfo()); - if (message != null) + if (JsonSerializer.Deserialize(line, _jsonOptions.GetTypeInfo()) is { } message) { string messageId = "(no id)"; if (message is IJsonRpcMessageWithId messageWithId) @@ -213,7 +207,8 @@ private async Task ReadMessagesAsync(CancellationToken cancellationToken) messageId = messageWithId.Id.ToString(); } _logger.TransportReceivedMessageParsed(EndpointName, messageId); - await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false); + + await WriteMessageAsync(message, shutdownToken).ConfigureAwait(false); _logger.TransportMessageWritten(EndpointName, messageId); } else @@ -227,76 +222,70 @@ private async Task ReadMessagesAsync(CancellationToken cancellationToken) // Continue reading even if we fail to parse a message } } + _logger.TransportExitingReadMessagesLoop(EndpointName); } - catch (OperationCanceledException) + catch (OperationCanceledException oce) { _logger.TransportReadMessagesCancelled(EndpointName); - // Normal shutdown + _serverCompleted.TrySetCanceled(oce.CancellationToken); } catch (Exception ex) { _logger.TransportReadMessagesFailed(EndpointName, ex); + _serverCompleted.TrySetException(ex); } finally { - await CleanupAsync(cancellationToken).ConfigureAwait(false); + _serverCompleted.TrySetResult(true); } } - private async Task CleanupAsync(CancellationToken cancellationToken) + /// + public override async ValueTask DisposeAsync() { - _logger.TransportCleaningUp(EndpointName); - - if (_shutdownCts is { } shutdownCts) + if (Interlocked.Exchange(ref _disposed, 1) != 0) { - await shutdownCts.CancelAsync().ConfigureAwait(false); - shutdownCts.Dispose(); - - _shutdownCts = null; + return; } - if (_readTask is { } readTask) + try { + _logger.TransportCleaningUp(EndpointName); + + // Signal to the stdin reading loop to stop. + await _shutdownCts.CancelAsync().ConfigureAwait(false); + _shutdownCts.Dispose(); + + // Dispose of stdin/out. Cancellation may not be able to wake up operations + // synchronously blocked in a syscall; we need to forcefully close the handle / file descriptor. + _stdInReader?.Dispose(); + _stdOutStream?.Dispose(); + + // Make sure the work has quiesced. try { _logger.TransportWaitingForReadTask(EndpointName); - await readTask.WaitAsync(TimeSpan.FromSeconds(5), cancellationToken).ConfigureAwait(false); + await _readLoopCompleted.ConfigureAwait(false); + _logger.TransportReadTaskCleanedUp(EndpointName); } catch (TimeoutException) { _logger.TransportCleanupReadTaskTimeout(EndpointName); - // Continue with cleanup } catch (OperationCanceledException) { _logger.TransportCleanupReadTaskCancelled(EndpointName); - // Ignore cancellation } catch (Exception ex) { _logger.TransportCleanupReadTaskFailed(EndpointName, ex); } - finally - { - _logger.TransportReadTaskCleanedUp(EndpointName); - _readTask = null; - } } - - _stdInReader?.Dispose(); - _stdOutStream?.Dispose(); - - SetConnected(false); - _logger.TransportCleanedUp(EndpointName); - } - - /// Validates the and extracts from it the server name to use. - private static string GetServerName(McpServerOptions serverOptions) - { - Throw.IfNull(serverOptions); - Throw.IfNull(serverOptions.ServerInfo); - - return serverOptions.ServerInfo.Name; + finally + { + SetConnected(false); + _logger.TransportCleanedUp(EndpointName); + } } } \ No newline at end of file diff --git a/src/ModelContextProtocol/Server/IMcpServer.cs b/src/ModelContextProtocol/Server/IMcpServer.cs index 08d4d990..e311b04a 100644 --- a/src/ModelContextProtocol/Server/IMcpServer.cs +++ b/src/ModelContextProtocol/Server/IMcpServer.cs @@ -8,11 +8,6 @@ namespace ModelContextProtocol.Server; /// public interface IMcpServer : IAsyncDisposable { - /// - /// Gets a value indicating whether the server has been initialized. - /// - bool IsInitialized { get; } - /// /// Gets the capabilities supported by the client. /// @@ -48,9 +43,9 @@ public interface IMcpServer : IAsyncDisposable void AddNotificationHandler(string method, Func handler); /// - /// Starts the server and begins listening for client requests. + /// Runs the server, listening for and handling client requests. /// - Task StartAsync(CancellationToken cancellationToken = default); + Task RunAsync(Func? onInitialized = null, CancellationToken cancellationToken = default); /// /// Sends a generic JSON-RPC request to the client. diff --git a/src/ModelContextProtocol/Server/McpServer.cs b/src/ModelContextProtocol/Server/McpServer.cs index 7f425b1f..5a8f50f6 100644 --- a/src/ModelContextProtocol/Server/McpServer.cs +++ b/src/ModelContextProtocol/Server/McpServer.cs @@ -1,5 +1,4 @@ using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Abstractions; using ModelContextProtocol.Logging; using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; @@ -13,11 +12,10 @@ namespace ModelContextProtocol.Server; /// internal sealed class McpServer : McpJsonRpcEndpoint, IMcpServer { - private readonly IServerTransport? _serverTransport; private readonly string _serverDescription; private readonly EventHandler? _toolsChangedDelegate; - private volatile bool _isInitializing; + private volatile bool _ran; /// /// Creates a new instance of . @@ -33,7 +31,6 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? { Throw.IfNull(options); - _serverTransport = transport as IServerTransport; ServerOptions = options; Services = serviceProvider; _serverDescription = $"{options.ServerInfo.Name} {options.ServerInfo.Version}"; @@ -52,7 +49,6 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? tools.Changed += _toolsChangedDelegate; } - IsInitialized = true; return Task.CompletedTask; }); @@ -68,6 +64,7 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? public ServerCapabilities? ServerCapabilities { get; set; } + /// public ClientCapabilities? ClientCapabilities { get; set; } /// @@ -84,35 +81,40 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? $"Server ({_serverDescription}), Client ({ClientInfo?.Name} {ClientInfo?.Version})"; /// - public async Task StartAsync(CancellationToken cancellationToken = default) + public async Task RunAsync(Func? onInitialized = null, CancellationToken cancellationToken = default) { - if (_isInitializing) - { - _logger.ServerAlreadyInitializing(EndpointName); - throw new InvalidOperationException("Server is already initializing"); - } - _isInitializing = true; - - if (IsInitialized) + if (_ran) { _logger.ServerAlreadyInitializing(EndpointName); - return; + throw new InvalidOperationException("Server is already running."); } + _ran = true; try { CancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - if (_serverTransport is not null) + IServerTransport? serverTransport = Transport as IServerTransport; + if (serverTransport is not null) { // Start listening for messages - await _serverTransport.StartListeningAsync(CancellationTokenSource.Token).ConfigureAwait(false); + await serverTransport.StartListeningAsync(CancellationTokenSource.Token).ConfigureAwait(false); } // Start processing messages MessageProcessingTask = ProcessMessagesAsync(CancellationTokenSource.Token); - // Unlike McpClient, we're not done initializing until we've received a message from the client, so we don't set IsInitialized here + // Invoke any post-initialization logic. + if (onInitialized is not null) + { + await onInitialized(CancellationTokenSource.Token).ConfigureAwait(false); + } + + // Wait for transport completion. + if (serverTransport is not null) + { + await serverTransport.Completion; + } } catch (Exception e) { diff --git a/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs b/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs index 90824c58..7052953d 100644 --- a/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs +++ b/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs @@ -49,11 +49,6 @@ protected McpJsonRpcEndpoint(ITransport transport, ILoggerFactory? loggerFactory _logger = loggerFactory?.CreateLogger(GetType()) ?? NullLogger.Instance; } - /// - /// Gets whether the endpoint is initialized and ready to process messages. - /// - public bool IsInitialized { get; set; } - /// /// Gets the name of the endpoint for logging and debug purposes. /// @@ -320,7 +315,6 @@ public void AddNotificationHandler(string method, Func @@ -369,7 +363,9 @@ protected virtual async Task CleanupAsync() _logger.CleaningUpEndpoint(EndpointName); if (CancellationTokenSource != null) + { await CancellationTokenSource.CancelAsync().ConfigureAwait(false); + } if (MessageProcessingTask != null) { @@ -393,8 +389,6 @@ protected virtual async Task CleanupAsync() await _transport.DisposeAsync().ConfigureAwait(false); CancellationTokenSource?.Dispose(); - IsInitialized = false; - _logger.EndpointCleanedUp(EndpointName); } } diff --git a/tests/ModelContextProtocol.TestServer/Program.cs b/tests/ModelContextProtocol.TestServer/Program.cs index e0558447..923a2b9d 100644 --- a/tests/ModelContextProtocol.TestServer/Program.cs +++ b/tests/ModelContextProtocol.TestServer/Program.cs @@ -4,6 +4,7 @@ using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server; using Serilog; +using System.Collections.Concurrent; using System.Text; using System.Text.Json; @@ -49,59 +50,50 @@ private static async Task Main(string[] args) using var loggerFactory = CreateLoggerFactory(); await using IMcpServer server = McpServerFactory.Create(new StdioServerTransport("TestServer", loggerFactory), options, loggerFactory); - Log.Logger.Information("Server initialized."); - - await server.StartAsync(); - - Log.Logger.Information("Server started."); - - // everything server sends random log level messages every 15 seconds - int loggingSeconds = 0; - Random random = Random.Shared; - var loggingLevels = Enum.GetValues().ToList(); - - // Run until process is stopped by the client (parent process) - while (true) + Log.Logger.Information("Server running..."); + await server.RunAsync(async cancellationToken => { - await Task.Delay(5000); - if (_minimumLoggingLevel is not null) - { - loggingSeconds += 5; + var loggingLevels = Enum.GetValues(); - // Send random log messages every 15 seconds - if (loggingSeconds >= 15) + // Run until process is stopped by the client (parent process) + while (true) + { + await Task.Delay(1000, cancellationToken); + try { - var logLevelIndex = random.Next(loggingLevels.Count); - var logLevel = loggingLevels[logLevelIndex]; - await server.SendMessageAsync(new JsonRpcNotification() + // Send random log messages every few seconds + if (_minimumLoggingLevel is not null) { - Method = NotificationMethods.LoggingMessageNotification, - Params = new LoggingMessageNotificationParams + var logLevel = loggingLevels[Random.Shared.Next(loggingLevels.Length)]; + await server.SendMessageAsync(new JsonRpcNotification() { - Level = logLevel, - Data = JsonSerializer.Deserialize("\"Random log message\"") - } - }); - } - } + Method = NotificationMethods.LoggingMessageNotification, + Params = new LoggingMessageNotificationParams + { + Level = logLevel, + Data = JsonSerializer.Deserialize("\"Random log message\"") + } + }, cancellationToken); + } - // Snapshot the subscribed resources, rather than locking while sending notifications - List resources; - lock (_subscribedResourcesLock) - { - resources = _subscribedResources.ToList(); - } - - foreach (var resource in resources) - { - ResourceUpdatedNotificationParams notificationParams = new() { Uri = resource }; - await server.SendMessageAsync(new JsonRpcNotification() + // Snapshot the subscribed resources, rather than locking while sending notifications + foreach (var resource in _subscribedResources) + { + ResourceUpdatedNotificationParams notificationParams = new() { Uri = resource.Key }; + await server.SendMessageAsync(new JsonRpcNotification() + { + Method = NotificationMethods.ResourceUpdatedNotification, + Params = notificationParams + }, cancellationToken); + } + } + catch (Exception ex) { - Method = NotificationMethods.ResourceUpdatedNotification, - Params = notificationParams - }); + Log.Logger.Error(ex, "Error sending log message"); + break; + } } - } + }); } private static ToolsCapability ConfigureTools() @@ -312,8 +304,7 @@ private static LoggingCapability ConfigureLogging() }; } - private static readonly HashSet _subscribedResources = new(); - private static readonly object _subscribedResourcesLock = new(); + private static readonly ConcurrentDictionary _subscribedResources = new(); private static ResourcesCapability ConfigureResources() { @@ -451,10 +442,7 @@ private static ResourcesCapability ConfigureResources() throw new McpServerException("Invalid resource URI"); } - lock (_subscribedResourcesLock) - { - _subscribedResources.Add(request.Params.Uri); - } + _subscribedResources.TryAdd(request.Params.Uri, true); return Task.FromResult(new EmptyResult()); }, @@ -471,10 +459,7 @@ private static ResourcesCapability ConfigureResources() throw new McpServerException("Invalid resource URI"); } - lock (_subscribedResourcesLock) - { - _subscribedResources.Remove(request.Params.Uri); - } + _subscribedResources.Remove(request.Params.Uri, out _); return Task.FromResult(new EmptyResult()); }, diff --git a/tests/ModelContextProtocol.TestSseServer/Program.cs b/tests/ModelContextProtocol.TestSseServer/Program.cs index 598b4adf..98d5dfaf 100644 --- a/tests/ModelContextProtocol.TestSseServer/Program.cs +++ b/tests/ModelContextProtocol.TestSseServer/Program.cs @@ -44,8 +44,6 @@ public static async Task MainAsync(string[] args, ILoggerFactory? loggerFactory ServerInstructions = "This is a test server with only stub functionality" }; - IMcpServer? server = null; - Console.WriteLine("Registering handlers."); #region Helped method @@ -186,7 +184,7 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st { throw new McpServerException("Missing required arguments 'prompt' and 'maxTokens'"); } - var sampleResult = await server!.RequestSamplingAsync(CreateRequestSamplingParams(prompt?.ToString() ?? "", "sampleLLM", Convert.ToInt32(maxTokens?.ToString())), + var sampleResult = await request.Server.RequestSamplingAsync(CreateRequestSamplingParams(prompt?.ToString() ?? "", "sampleLLM", Convert.ToInt32(maxTokens?.ToString())), cancellationToken); return new CallToolResponse() @@ -384,23 +382,10 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st }; loggerFactory ??= CreateLoggerFactory(); - server = McpServerFactory.Create(new HttpListenerSseServerTransport("TestServer", 3001, loggerFactory), options, loggerFactory); - - Console.WriteLine("Server initialized."); - - await server.StartAsync(cancellationToken); - - Console.WriteLine("Server started."); + await using IMcpServer server = McpServerFactory.Create(new HttpListenerSseServerTransport("TestServer", 3001, loggerFactory), options, loggerFactory); - try - { - // Run until process is stopped by the client (parent process) or test - await Task.Delay(Timeout.Infinite, cancellationToken); - } - finally - { - await server.DisposeAsync(); - } + Console.WriteLine("Server running..."); + await server.RunAsync(cancellationToken: cancellationToken); } const string MCP_TINY_IMAGE = diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs index 03d0e993..ec4808ff 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs @@ -13,6 +13,8 @@ public class McpClientExtensionsTests private Pipe _clientToServerPipe = new(); private Pipe _serverToClientPipe = new(); private readonly IMcpServer _server; + private readonly CancellationTokenSource _cts; + private readonly Task _serverTask; public McpClientExtensionsTests() { @@ -25,19 +27,22 @@ public McpClientExtensionsTests() sc.AddSingleton(McpServerTool.Create((int i) => $"{name} Result {i}", name)); } _server = sc.BuildServiceProvider().GetRequiredService(); + + _cts = CancellationTokenSource.CreateLinkedTokenSource(TestContext.Current.CancellationToken); + _serverTask = _server.RunAsync(cancellationToken: _cts.Token); } - public ValueTask DisposeAsync() + public async ValueTask DisposeAsync() { + _cts.Cancel(); _clientToServerPipe.Writer.Complete(); _serverToClientPipe.Writer.Complete(); - return _server.DisposeAsync(); + await _serverTask; + await _server.DisposeAsync(); } private async Task CreateMcpClientForServer() { - await _server.StartAsync(TestContext.Current.CancellationToken); - var serverStdinWriter = new StreamWriter(_clientToServerPipe.Writer.AsStream()); var serverStdoutReader = new StreamReader(_serverToClientPipe.Reader.AsStream()); diff --git a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs index 69321675..9ecfdd57 100644 --- a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs @@ -252,21 +252,17 @@ public async Task SubscribeResource_Stdio() var clientId = "test_server"; // act - int counter = 0; + TaskCompletionSource tcs = new(); await using var client = await _fixture.CreateClientAsync(clientId); client.AddNotificationHandler(NotificationMethods.ResourceUpdatedNotification, (notification) => { var notificationParams = JsonSerializer.Deserialize(notification.Params!.ToString() ?? string.Empty); - ++counter; + tcs.TrySetResult(true); return Task.CompletedTask; }); await client.SubscribeToResourceAsync("test://static/resource/1", CancellationToken.None); - // notifications happen every 5 seconds, so we wait for 10 seconds to ensure we get at least one notification - await Task.Delay(10000, TestContext.Current.CancellationToken); - - // assert - Assert.True(counter > 0); + await tcs.Task; } // Not supported by "everything" server version on npx @@ -277,32 +273,26 @@ public async Task UnsubscribeResource_Stdio() var clientId = "test_server"; // act - int counter = 0; + TaskCompletionSource receivedNotification = new(); await using var client = await _fixture.CreateClientAsync(clientId); client.AddNotificationHandler(NotificationMethods.ResourceUpdatedNotification, (notification) => { var notificationParams = JsonSerializer.Deserialize(notification.Params!.ToString() ?? string.Empty); - ++counter; + receivedNotification.TrySetResult(true); return Task.CompletedTask; }); await client.SubscribeToResourceAsync("test://static/resource/1", CancellationToken.None); - // notifications happen every 5 seconds, so we wait for 10 seconds to ensure we get at least one notification - await Task.Delay(10000, TestContext.Current.CancellationToken); - - // reset counter - int counterAfterSubscribe = counter; + // wait until we received a notification + await receivedNotification.Task; // unsubscribe await client.UnsubscribeFromResourceAsync("test://static/resource/1", CancellationToken.None); - counter = 0; + receivedNotification = new(); - // notifications happen every 5 seconds, so we wait for 10 seconds to ensure we would've gotten at least one notification - await Task.Delay(10000, TestContext.Current.CancellationToken); - - // assert - Assert.True(counterAfterSubscribe > 0); - Assert.Equal(0, counter); + // wait a bit to validate we don't receive another. this is best effort only; + // false negatives are possible. + await Assert.ThrowsAsync(() => receivedNotification.Task.WaitAsync(TimeSpan.FromSeconds(1), TestContext.Current.CancellationToken)); } [Theory] @@ -556,7 +546,7 @@ public async Task SetLoggingLevel_ReceivesLoggingMessages(string clientId) Encoder = JavaScriptEncoder.UnsafeRelaxedJsonEscaping, }; - int logCounter = 0; + TaskCompletionSource receivedNotification = new(); await using var client = await _fixture.CreateClientAsync(clientId); client.AddNotificationHandler(NotificationMethods.LoggingMessageNotification, (notification) => { @@ -564,16 +554,15 @@ public async Task SetLoggingLevel_ReceivesLoggingMessages(string clientId) jsonSerializerOptions); if (loggingMessageNotificationParameters is not null) { - ++logCounter; + receivedNotification.TrySetResult(true); } return Task.CompletedTask; }); // act await client.SetLoggingLevel(LoggingLevel.Debug, CancellationToken.None); - await Task.Delay(16000, TestContext.Current.CancellationToken); // assert - Assert.True(logCounter > 0); + await receivedNotification.Task; } } diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index 3ae31301..18c08576 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -24,6 +24,8 @@ public class McpServerBuilderExtensionsToolsTests : LoggedTest, IAsyncDisposable private readonly ServiceProvider _serviceProvider; private readonly IMcpServerBuilder _builder; private readonly IMcpServer _server; + private readonly CancellationTokenSource _cts; + private readonly Task _serverTask; public McpServerBuilderExtensionsToolsTests(ITestOutputHelper testOutputHelper) : base(testOutputHelper) @@ -35,19 +37,24 @@ public McpServerBuilderExtensionsToolsTests(ITestOutputHelper testOutputHelper) _builder = sc.AddMcpServer().WithTools(); _serviceProvider = sc.BuildServiceProvider(); _server = _serviceProvider.GetRequiredService(); + + _cts = CancellationTokenSource.CreateLinkedTokenSource(TestContext.Current.CancellationToken); + _serverTask = _server.RunAsync(cancellationToken: _cts.Token); } - public ValueTask DisposeAsync() + public async ValueTask DisposeAsync() { + await _cts.CancelAsync(); + _cts.Dispose(); + _clientToServerPipe.Writer.Complete(); _serverToClientPipe.Writer.Complete(); - return _serviceProvider.DisposeAsync(); + + await _serviceProvider.DisposeAsync(); } private async Task CreateMcpClientForServer() { - await _server.StartAsync(TestContext.Current.CancellationToken); - var serverStdinWriter = new StreamWriter(_clientToServerPipe.Writer.AsStream()); var serverStdoutReader = new StreamReader(_serverToClientPipe.Reader.AsStream()); @@ -105,15 +112,13 @@ public async Task Can_Create_Multiple_Servers_From_Options_And_List_Registered_T var stdinPipe = new Pipe(); var stdoutPipe = new Pipe(); - try - { - var transport = new StdioServerTransport($"TestServer_{i}", stdinPipe.Reader.AsStream(), stdoutPipe.Writer.AsStream()); - var server = McpServerFactory.Create(transport, options, loggerFactory, _serviceProvider); - - await server.StartAsync(TestContext.Current.CancellationToken); + var transport = new StdioServerTransport($"TestServer_{i}", stdinPipe.Reader.AsStream(), stdoutPipe.Writer.AsStream()); + var server = McpServerFactory.Create(transport, options, loggerFactory, _serviceProvider); - var serverStdinWriter = new StreamWriter(stdinPipe.Writer.AsStream()); - var serverStdoutReader = new StreamReader(stdoutPipe.Reader.AsStream()); + await server.RunAsync(async cancellationToken => + { + using var serverStdinWriter = new StreamWriter(stdinPipe.Writer.AsStream()); + using var serverStdoutReader = new StreamReader(stdoutPipe.Reader.AsStream()); var serverConfig = new McpServerConfig() { @@ -125,9 +130,9 @@ public async Task Can_Create_Multiple_Servers_From_Options_And_List_Registered_T var client = await McpClientFactory.CreateAsync( serverConfig, createTransportFunc: (_, _) => new StreamClientTransport(serverStdinWriter, serverStdoutReader), - cancellationToken: TestContext.Current.CancellationToken); + cancellationToken: cancellationToken); - var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); + var tools = await client.ListToolsAsync(cancellationToken); Assert.Equal(11, tools.Count); McpClientTool echoTool = tools.First(t => t.Name == "Echo"); @@ -141,12 +146,7 @@ public async Task Can_Create_Multiple_Servers_From_Options_And_List_Registered_T McpClientTool doubleEchoTool = tools.First(t => t.Name == "double_echo"); Assert.Equal("double_echo", doubleEchoTool.Name); Assert.Equal("Echoes the input back to the client.", doubleEchoTool.Description); - } - finally - { - stdinPipe.Writer.Complete(); - stdoutPipe.Writer.Complete(); - } + }, TestContext.Current.CancellationToken); } } diff --git a/tests/ModelContextProtocol.Tests/EverythingSseServerFixture.cs b/tests/ModelContextProtocol.Tests/EverythingSseServerFixture.cs index fdd701d7..d5fff66f 100644 --- a/tests/ModelContextProtocol.Tests/EverythingSseServerFixture.cs +++ b/tests/ModelContextProtocol.Tests/EverythingSseServerFixture.cs @@ -56,8 +56,6 @@ public async ValueTask DisposeAsync() // Log the exception but don't throw await Console.Error.WriteLineAsync($"Error stopping Docker container: {ex.Message}"); } - - GC.SuppressFinalize(this); } private static bool CheckIsDockerAvailable() diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs index 3f4dd1d8..469d0afc 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs @@ -84,51 +84,38 @@ public async Task Constructor_Does_Not_Throw_For_Null_ServiceProvider() } [Fact] - public async Task StartAsync_Should_Throw_InvalidOperationException_If_Already_Initializing() + public async Task RunAsync_Should_Throw_InvalidOperationException_If_Already_Running() { // Arrange await using var server = McpServerFactory.Create(_serverTransport.Object, _options, LoggerFactory, _serviceProvider); - var task = server.StartAsync(TestContext.Current.CancellationToken); + var task = server.RunAsync(cancellationToken: TestContext.Current.CancellationToken); // Act & Assert - await Assert.ThrowsAsync(() => server.StartAsync(TestContext.Current.CancellationToken)); + await Assert.ThrowsAsync(() => server.RunAsync(cancellationToken: TestContext.Current.CancellationToken)); await task; } [Fact] - public async Task StartAsync_Should_Do_Nothing_If_Already_Initialized() - { - // Arrange - await using var server = McpServerFactory.Create(_serverTransport.Object, _options, LoggerFactory, _serviceProvider); - SetInitialized(server, true); - - await server.StartAsync(TestContext.Current.CancellationToken); - - // Assert - _serverTransport.Verify(t => t.StartListeningAsync(It.IsAny()), Times.Never); - } - - [Fact] - public async Task StartAsync_ShouldStartListening() + public async Task RunAsync_ShouldStartListening() { // Arrange await using var server = McpServerFactory.Create(_serverTransport.Object, _options, LoggerFactory, _serviceProvider); // Act - await server.StartAsync(TestContext.Current.CancellationToken); + await server.RunAsync(cancellationToken: TestContext.Current.CancellationToken); // Assert _serverTransport.Verify(t => t.StartListeningAsync(It.IsAny()), Times.Once); } [Fact] - public async Task StartAsync_Sets_Initialized_After_Transport_Responses_Initialized_Notification() + public async Task RunAsync_Sets_Initialized_After_Transport_Responses_Initialized_Notification() { await using var transport = new TestServerTransport(); await using var server = McpServerFactory.Create(transport, _options, LoggerFactory, _serviceProvider); - await server.StartAsync(TestContext.Current.CancellationToken); + await server.RunAsync(cancellationToken: TestContext.Current.CancellationToken); // Send initialized notification await transport.SendMessageAsync(new JsonRpcNotification @@ -137,8 +124,6 @@ await transport.SendMessageAsync(new JsonRpcNotification }, TestContext.Current.CancellationToken); await Task.Delay(50, TestContext.Current.CancellationToken); - - Assert.True(server.IsInitialized); } [Fact] @@ -162,7 +147,7 @@ public async Task RequestSamplingAsync_Should_SendRequest() await using var server = McpServerFactory.Create(transport, _options, LoggerFactory, _serviceProvider); SetClientCapabilities(server, new ClientCapabilities { Sampling = new SamplingCapability() }); - await server.StartAsync(TestContext.Current.CancellationToken); + await server.RunAsync(cancellationToken: TestContext.Current.CancellationToken); // Act var result = await server.RequestSamplingAsync(new CreateMessageRequestParams { Messages = [] }, CancellationToken.None); @@ -191,7 +176,7 @@ public async Task RequestRootsAsync_Should_SendRequest() await using var transport = new TestServerTransport(); await using var server = McpServerFactory.Create(transport, _options, LoggerFactory, _serviceProvider); SetClientCapabilities(server, new ClientCapabilities { Roots = new RootsCapability() }); - await server.StartAsync(TestContext.Current.CancellationToken); + await server.RunAsync(cancellationToken: TestContext.Current.CancellationToken); // Act var result = await server.RequestRootsAsync(new ListRootsRequestParams(), CancellationToken.None); @@ -555,7 +540,7 @@ private async Task Can_Handle_Requests(ServerCapabilities? serverCapabilities, s await using var server = McpServerFactory.Create(transport, options, LoggerFactory, _serviceProvider); - await server.StartAsync(); + await server.RunAsync(); var receivedMessage = new TaskCompletionSource(); @@ -633,13 +618,6 @@ private static void SetClientCapabilities(IMcpServer server, ClientCapabilities property.SetValue(server, capabilities); } - private static void SetInitialized(IMcpServer server, bool isInitialized) - { - PropertyInfo? property = server.GetType().GetProperty("IsInitialized", BindingFlags.Public | BindingFlags.Instance); - Assert.NotNull(property); - property.SetValue(server, isInitialized); - } - private sealed class TestServerForIChatClient(bool supportsSampling) : IMcpServer { public ClientCapabilities? ClientCapabilities => @@ -675,8 +653,6 @@ public Task SendRequestAsync(JsonRpcRequest request, CancellationToken can public ValueTask DisposeAsync() => default; - public bool IsInitialized => throw new NotImplementedException(); - public Implementation? ClientInfo => throw new NotImplementedException(); public McpServerOptions ServerOptions => throw new NotImplementedException(); public IServiceProvider? Services => throw new NotImplementedException(); @@ -684,9 +660,7 @@ public void AddNotificationHandler(string method, Func throw new NotImplementedException(); - public Task StartAsync(CancellationToken cancellationToken = default) => + public Task RunAsync(Func? onInitialized = null, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - - public object? GetService(Type serviceType, object? serviceKey = null) => null; } } diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs index 1ce4b36c..3f066dd5 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs @@ -1,5 +1,4 @@ -using Microsoft.Extensions.AI; -using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server; using Moq; diff --git a/tests/ModelContextProtocol.Tests/TestAttributes.cs b/tests/ModelContextProtocol.Tests/TestAttributes.cs new file mode 100644 index 00000000..4edbce6e --- /dev/null +++ b/tests/ModelContextProtocol.Tests/TestAttributes.cs @@ -0,0 +1,2 @@ +// Uncomment to disable parallel test execution +//[assembly: CollectionBehavior(DisableTestParallelization = true)] \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs index 761c55e7..84f848fa 100644 --- a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs @@ -90,7 +90,7 @@ public async Task ConnectAsync_Should_Connect_Successfully() if (!firstCall) { Assert.True(transport.IsConnected); - await transport.CloseAsync(); + await transport.DisposeAsync(); } firstCall = false; @@ -148,7 +148,7 @@ public async Task ConnectAsync_Throws_If_Already_Connected() var exception = await Assert.ThrowsAsync(action); Assert.Equal("Transport is already connected", exception.Message); tcsDone.SetResult(); - await transport.CloseAsync(); + await transport.DisposeAsync(); await task; } @@ -305,7 +305,7 @@ public async Task CloseAsync_Should_Dispose_Resources() { await using var transport = new SseClientTransport(_transportOptions, _serverConfig, LoggerFactory); - await transport.CloseAsync(); + await transport.DisposeAsync(); Assert.False(transport.IsConnected); } diff --git a/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs index 94ae6272..03994ee7 100644 --- a/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs @@ -1,4 +1,5 @@ -using ModelContextProtocol.Protocol.Messages; +using Microsoft.Extensions.Options; +using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server; @@ -45,9 +46,10 @@ public void Constructor_Throws_For_Null_Options() { Assert.Throws("serverName", () => new StdioServerTransport((string)null!)); + Assert.Throws("serverOptions", () => new StdioServerTransport((IOptions)null!)); Assert.Throws("serverOptions", () => new StdioServerTransport((McpServerOptions)null!)); Assert.Throws("serverOptions.ServerInfo", () => new StdioServerTransport(new McpServerOptions() { ServerInfo = null! })); - Assert.Throws("serverName", () => new StdioServerTransport(new McpServerOptions() { ServerInfo = new() { Name = null!, Version = "" } })); + Assert.Throws("serverOptions.ServerInfo.Name", () => new StdioServerTransport(new McpServerOptions() { ServerInfo = new() { Name = null!, Version = "" } })); } [Fact] diff --git a/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs b/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs index f38d933e..4fca42a3 100644 --- a/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs +++ b/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs @@ -12,6 +12,8 @@ public class TestServerTransport : IServerTransport public bool IsConnected => _isStarted; + public Task Completion => Task.CompletedTask; + public ChannelReader MessageReader => _messageChannel; public List SentMessages { get; } = []; From fbe1dbe5e4528e4cb9750a5a8dd4559192214b75 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Fri, 28 Mar 2025 09:37:23 -0700 Subject: [PATCH 02/10] Call IMcpServer.RunAsync in AspNetCoreSseServer --- .../AspNetCoreSseServer/McpEndpointRouteBuilderExtensions.cs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/samples/AspNetCoreSseServer/McpEndpointRouteBuilderExtensions.cs b/samples/AspNetCoreSseServer/McpEndpointRouteBuilderExtensions.cs index 8a37afba..18ce5d03 100644 --- a/samples/AspNetCoreSseServer/McpEndpointRouteBuilderExtensions.cs +++ b/samples/AspNetCoreSseServer/McpEndpointRouteBuilderExtensions.cs @@ -26,7 +26,9 @@ public static IEndpointConventionBuilder MapMcpSse(this IEndpointRouteBuilder en try { + var serverTask = server.RunAsync(cancellationToken: requestAborted); await transport.RunAsync(cancellationToken: requestAborted); + await serverTask; } catch (OperationCanceledException) when (requestAborted.IsCancellationRequested) { From 4b7aabb679453bf49b7767ffb918a9c315a48985 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Fri, 28 Mar 2025 16:58:29 -0700 Subject: [PATCH 03/10] Refactor transports to enable graceful shutdown - This also paves the way for better multi-session support - We should definitely rethink names for the transport API - For now, I kept the names similar as possible, so we can focus on the API shape --- src/ModelContextProtocol/Client/McpClient.cs | 59 +-- .../Client/McpClientFactory.cs | 18 +- .../McpServerBuilderExtensions.Transports.cs | 19 +- .../McpServerServiceCollectionExtension.cs | 11 +- .../Hosting/McpServerHostedService.cs | 31 -- .../McpServerMultiSessionHostedService.cs | 43 ++ .../McpServerSingleSessionHostedService.cs | 13 + .../ModelContextProtocol.csproj | 2 +- .../Transport/HttpListenerServerProvider.cs | 6 +- .../HttpListenerSseServerSessionTransport.cs | 76 ++++ .../HttpListenerSseServerTransport.cs | 94 ++--- .../Protocol/Transport/IClientTransport.cs | 7 +- .../Protocol/Transport/IServerTransport.cs | 12 +- .../Transport/SseClientSessionTransport.cs | 332 ++++++++++++++++ .../Protocol/Transport/SseClientTransport.cs | 303 +------------- .../Transport/StdioClientStreamTransport.cs | 321 +++++++++++++++ .../Transport/StdioClientTransport.cs | 296 +------------- .../Transport/StdioServerTransport.cs | 28 +- .../Protocol/Transport/TransportBase.cs | 2 +- src/ModelContextProtocol/Server/IMcpServer.cs | 10 +- src/ModelContextProtocol/Server/McpServer.cs | 124 ++++-- .../Server/McpServerFactory.cs | 52 ++- .../Shared/McpJsonRpcEndpoint.cs | 368 +++--------------- src/ModelContextProtocol/Shared/McpSession.cs | 299 ++++++++++++++ .../Shared/NotificationHandlers.cs | 16 + .../Shared/RequestHandlers.cs | 31 ++ .../Program.cs | 76 ++-- .../Program.cs | 11 +- .../Client/McpClientExtensionsTests.cs | 29 +- .../Client/McpClientFactoryTests.cs | 5 +- .../McpServerBuilderExtensionsToolsTests.cs | 87 +++-- ...pServerBuilderExtensionsTransportsTests.cs | 2 +- .../Server/McpServerFactoryTests.cs | 18 +- .../Server/McpServerTests.cs | 67 ++-- .../SseIntegrationTests.cs | 87 +---- .../SseServerIntegrationTests.cs | 17 +- .../Transport/SseClientTransportTests.cs | 165 ++------ .../Transport/StdioServerTransportTests.cs | 78 ++-- .../Transport/StreamClientTransport.cs | 24 +- .../Utils/TestServerTransport.cs | 21 +- 40 files changed, 1728 insertions(+), 1532 deletions(-) delete mode 100644 src/ModelContextProtocol/Hosting/McpServerHostedService.cs create mode 100644 src/ModelContextProtocol/Hosting/McpServerMultiSessionHostedService.cs create mode 100644 src/ModelContextProtocol/Hosting/McpServerSingleSessionHostedService.cs create mode 100644 src/ModelContextProtocol/Protocol/Transport/HttpListenerSseServerSessionTransport.cs create mode 100644 src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs create mode 100644 src/ModelContextProtocol/Protocol/Transport/StdioClientStreamTransport.cs create mode 100644 src/ModelContextProtocol/Shared/McpSession.cs create mode 100644 src/ModelContextProtocol/Shared/NotificationHandlers.cs create mode 100644 src/ModelContextProtocol/Shared/RequestHandlers.cs diff --git a/src/ModelContextProtocol/Client/McpClient.cs b/src/ModelContextProtocol/Client/McpClient.cs index 8ee2d1db..4ba44d72 100644 --- a/src/ModelContextProtocol/Client/McpClient.cs +++ b/src/ModelContextProtocol/Client/McpClient.cs @@ -1,11 +1,11 @@ -using ModelContextProtocol.Configuration; +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Configuration; using ModelContextProtocol.Logging; using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Shared; using ModelContextProtocol.Utils.Json; -using Microsoft.Extensions.Logging; using System.Text.Json; namespace ModelContextProtocol.Client; @@ -13,23 +13,25 @@ namespace ModelContextProtocol.Client; /// internal sealed class McpClient : McpJsonRpcEndpoint, IMcpClient { - private readonly McpClientOptions _options; private readonly IClientTransport _clientTransport; + private readonly McpClientOptions _options; - private int _connecting; + private ITransport? _sessionTransport; + private CancellationTokenSource? _connectCts; + private int _disposed; /// /// Initializes a new instance of the class. /// - /// The transport to use for communication with the server. + /// The transport to use for communication with the server. /// Options for the client, defining protocol version and capabilities. /// The server configuration. /// The logger factory. - public McpClient(IClientTransport transport, McpClientOptions options, McpServerConfig serverConfig, ILoggerFactory? loggerFactory) - : base(transport, loggerFactory) + public McpClient(IClientTransport clientTransport, McpClientOptions options, McpServerConfig serverConfig, ILoggerFactory? loggerFactory) + : base(loggerFactory) { + _clientTransport = clientTransport; _options = options; - _clientTransport = transport; EndpointName = $"Client ({serverConfig.Id}: {serverConfig.Name})"; @@ -70,25 +72,19 @@ public McpClient(IClientTransport transport, McpClientOptions options, McpServer /// public override string EndpointName { get; } - /// public async Task ConnectAsync(CancellationToken cancellationToken = default) { - if (Interlocked.Exchange(ref _connecting, 1) != 0) - { - _logger.ClientAlreadyInitializing(EndpointName); - throw new InvalidOperationException("Client is already in use."); - } - - CancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - cancellationToken = CancellationTokenSource.Token; + _connectCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + cancellationToken = _connectCts.Token; try { // Connect transport - await _clientTransport.ConnectAsync(cancellationToken).ConfigureAwait(false); - - // Start processing messages - MessageProcessingTask = ProcessMessagesAsync(cancellationToken); + _sessionTransport = await _clientTransport.ConnectAsync(cancellationToken).ConfigureAwait(false); + InitializeSession(_sessionTransport); + // We don't want the ConnectAsync token to cancel the session after we've successfully connected. + // The base class handles cleaning up the session in DisposeAsync without our help. + StartSession(fullSessionCancellationToken: CancellationToken.None); // Perform initialization sequence using var initializationCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); @@ -140,8 +136,27 @@ await SendMessageAsync( catch (Exception e) { _logger.ClientInitializationError(EndpointName, e); - await CleanupAsync().ConfigureAwait(false); + await DisposeAsync().ConfigureAwait(false); throw; } } + + /// + public override async ValueTask DisposeAsync() + { + if (Interlocked.Exchange(ref _disposed, 1) != 0) + { + // TODO: It's more correct to await the last DisposeAsync before returning if it's still ongoing. + return; + } + + if (_connectCts is not null) + { + await _connectCts.CancelAsync().ConfigureAwait(false); + } + + await base.DisposeAsync().ConfigureAwait(false); + + _connectCts?.Dispose(); + } } diff --git a/src/ModelContextProtocol/Client/McpClientFactory.cs b/src/ModelContextProtocol/Client/McpClientFactory.cs index 691d165a..08153faa 100644 --- a/src/ModelContextProtocol/Client/McpClientFactory.cs +++ b/src/ModelContextProtocol/Client/McpClientFactory.cs @@ -65,24 +65,16 @@ public static async Task CreateAsync( createTransportFunc(serverConfig, loggerFactory) ?? throw new InvalidOperationException($"{nameof(createTransportFunc)} returned a null transport."); + McpClient client = new(transport, clientOptions, serverConfig, loggerFactory); try { - McpClient client = new(transport, clientOptions, serverConfig, loggerFactory); - try - { - await client.ConnectAsync(cancellationToken).ConfigureAwait(false); - logger.ClientCreated(endpointName); - return client; - } - catch - { - await client.DisposeAsync().ConfigureAwait(false); - throw; - } + await client.ConnectAsync(cancellationToken).ConfigureAwait(false); + logger.ClientCreated(endpointName); + return client; } catch { - await transport.DisposeAsync().ConfigureAwait(false); + await client.DisposeAsync().ConfigureAwait(false); throw; } } diff --git a/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Transports.cs b/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Transports.cs index 1f357d32..938d1e10 100644 --- a/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Transports.cs +++ b/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Transports.cs @@ -3,6 +3,9 @@ using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Utils; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using ModelContextProtocol.Server; namespace ModelContextProtocol; @@ -19,8 +22,18 @@ public static IMcpServerBuilder WithStdioServerTransport(this IMcpServerBuilder { Throw.IfNull(builder); - builder.Services.AddSingleton(); - builder.Services.AddHostedService(); + builder.Services.AddSingleton(); + builder.Services.AddHostedService(); + + builder.Services.AddSingleton(services => + { + ITransport serverTransport = services.GetRequiredService(); + IOptions options = services.GetRequiredService>(); + ILoggerFactory? loggerFactory = services.GetService(); + + return McpServerFactory.Create(serverTransport, options.Value, loggerFactory, services); + }); + return builder; } @@ -33,7 +46,7 @@ public static IMcpServerBuilder WithHttpListenerSseServerTransport(this IMcpServ Throw.IfNull(builder); builder.Services.AddSingleton(); - builder.Services.AddHostedService(); + builder.Services.AddHostedService(); return builder; } } diff --git a/src/ModelContextProtocol/Configuration/McpServerServiceCollectionExtension.cs b/src/ModelContextProtocol/Configuration/McpServerServiceCollectionExtension.cs index 767c1440..55bbb4ea 100644 --- a/src/ModelContextProtocol/Configuration/McpServerServiceCollectionExtension.cs +++ b/src/ModelContextProtocol/Configuration/McpServerServiceCollectionExtension.cs @@ -8,7 +8,7 @@ namespace ModelContextProtocol; /// -/// Extension to host the MCP server +/// Extension to host an MCP server /// public static class McpServerServiceCollectionExtension { @@ -20,15 +20,6 @@ public static class McpServerServiceCollectionExtension /// public static IMcpServerBuilder AddMcpServer(this IServiceCollection services, Action? configureOptions = null) { - services.AddSingleton(services => - { - IServerTransport serverTransport = services.GetRequiredService(); - IOptions options = services.GetRequiredService>(); - ILoggerFactory? loggerFactory = services.GetService(); - - return McpServerFactory.Create(serverTransport, options.Value, loggerFactory, services); - }); - services.AddOptions(); services.AddTransient, McpServerOptionsSetup>(); if (configureOptions is not null) diff --git a/src/ModelContextProtocol/Hosting/McpServerHostedService.cs b/src/ModelContextProtocol/Hosting/McpServerHostedService.cs deleted file mode 100644 index 51c0d997..00000000 --- a/src/ModelContextProtocol/Hosting/McpServerHostedService.cs +++ /dev/null @@ -1,31 +0,0 @@ -using ModelContextProtocol.Server; -using ModelContextProtocol.Utils; -using Microsoft.Extensions.Hosting; - -namespace ModelContextProtocol.Hosting; - -/// -/// Hosted service for the MCP server. -/// -public class McpServerHostedService : BackgroundService -{ - private readonly IMcpServer _server; - - /// - /// Creates a new instance of the McpServerHostedService. - /// - /// The MCP server instance - /// - public McpServerHostedService(IMcpServer server) - { - Throw.IfNull(server); - - _server = server; - } - - /// - protected override async Task ExecuteAsync(CancellationToken stoppingToken) - { - await _server.RunAsync(cancellationToken: stoppingToken).ConfigureAwait(false); - } -} diff --git a/src/ModelContextProtocol/Hosting/McpServerMultiSessionHostedService.cs b/src/ModelContextProtocol/Hosting/McpServerMultiSessionHostedService.cs new file mode 100644 index 00000000..e465a725 --- /dev/null +++ b/src/ModelContextProtocol/Hosting/McpServerMultiSessionHostedService.cs @@ -0,0 +1,43 @@ +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using ModelContextProtocol.Protocol.Transport; +using ModelContextProtocol.Server; + +namespace ModelContextProtocol.Hosting; + +/// +/// Hosted service for a multi-session (i.e. HTTP) MCP server. +/// +internal class McpServerMultiSessionHostedService : BackgroundService +{ + private readonly IServerTransport _serverTransport; + private readonly McpServerOptions _serverOptions; + private readonly ILoggerFactory _loggerFactory; + private readonly IServiceProvider _serviceProvider; + + public McpServerMultiSessionHostedService( + IServerTransport serverTransport, + IOptions serverOptions, + ILoggerFactory loggerFactory, + IServiceProvider serviceProvider) + { + _serverTransport = serverTransport; + _serverOptions = serverOptions.Value; + _loggerFactory = loggerFactory; + _serviceProvider = serviceProvider; + } + + /// + protected override async Task ExecuteAsync(CancellationToken stoppingToken) + { + while (await AcceptSessionAsync(stoppingToken).ConfigureAwait(false) is { } server) + { + // TODO: Track all running sessions and wait for all sessions to complete for graceful shutdown. + _ = server.RunAsync(stoppingToken); + } + } + + private Task AcceptSessionAsync(CancellationToken cancellationToken) + => McpServerFactory.AcceptAsync(_serverTransport, _serverOptions, _loggerFactory, _serviceProvider, cancellationToken); +} diff --git a/src/ModelContextProtocol/Hosting/McpServerSingleSessionHostedService.cs b/src/ModelContextProtocol/Hosting/McpServerSingleSessionHostedService.cs new file mode 100644 index 00000000..a59d783f --- /dev/null +++ b/src/ModelContextProtocol/Hosting/McpServerSingleSessionHostedService.cs @@ -0,0 +1,13 @@ +using Microsoft.Extensions.Hosting; +using ModelContextProtocol.Server; + +namespace ModelContextProtocol.Hosting; + +/// +/// Hosted service for a single-session (i.e stdio) MCP server. +/// +internal class McpServerSingleSessionHostedService(IMcpServer session) : BackgroundService +{ + /// + protected override Task ExecuteAsync(CancellationToken stoppingToken) => session.RunAsync(stoppingToken); +} diff --git a/src/ModelContextProtocol/ModelContextProtocol.csproj b/src/ModelContextProtocol/ModelContextProtocol.csproj index dcee6278..6860381d 100644 --- a/src/ModelContextProtocol/ModelContextProtocol.csproj +++ b/src/ModelContextProtocol/ModelContextProtocol.csproj @@ -14,7 +14,7 @@ - + diff --git a/src/ModelContextProtocol/Protocol/Transport/HttpListenerServerProvider.cs b/src/ModelContextProtocol/Protocol/Transport/HttpListenerServerProvider.cs index dbeb4e6c..2071a716 100644 --- a/src/ModelContextProtocol/Protocol/Transport/HttpListenerServerProvider.cs +++ b/src/ModelContextProtocol/Protocol/Transport/HttpListenerServerProvider.cs @@ -44,8 +44,7 @@ public HttpListenerServerProvider(int port) public required Func OnSseConnectionAsync { get; set; } public required Func> OnMessageAsync { get; set; } - /// - public async Task StartAsync(CancellationToken cancellationToken = default) + public void Start() { if (Interlocked.CompareExchange(ref _state, StateRunning, StateNotStarted) != StateNotStarted) { @@ -60,7 +59,7 @@ public async Task StartAsync(CancellationToken cancellationToken = default) { try { - using var cts = CancellationTokenSource.CreateLinkedTokenSource(_shutdownTokenSource.Token, cancellationToken); + using var cts = CancellationTokenSource.CreateLinkedTokenSource(_shutdownTokenSource.Token); cts.Token.Register(_listener.Stop); while (!cts.IsCancellationRequested) { @@ -100,6 +99,7 @@ public async ValueTask DisposeAsync() await _shutdownTokenSource.CancelAsync().ConfigureAwait(false); _listener.Stop(); await _listeningTask.ConfigureAwait(false); + await _completed.Task.ConfigureAwait(false); } /// Gets a that completes when the server has finished its work. diff --git a/src/ModelContextProtocol/Protocol/Transport/HttpListenerSseServerSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/HttpListenerSseServerSessionTransport.cs new file mode 100644 index 00000000..6b2219e5 --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Transport/HttpListenerSseServerSessionTransport.cs @@ -0,0 +1,76 @@ +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Logging; +using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Utils; +using ModelContextProtocol.Utils.Json; +using System.Net; +using System.Text.Json; + +namespace ModelContextProtocol.Protocol.Transport; + +/// +/// Implements the MCP transport protocol using . +/// +internal sealed class HttpListenerSseServerSessionTransport : TransportBase +{ + private readonly string _serverName; + private readonly ILogger _logger; + private SseResponseStreamTransport _responseStreamTransport; + + private string EndpointName => $"Server (SSE) ({_serverName})"; + + public HttpListenerSseServerSessionTransport(string serverName, SseResponseStreamTransport responseStreamTransport, ILoggerFactory loggerFactory) + : base(loggerFactory) + { + Throw.IfNull(serverName); + + _serverName = serverName; + _responseStreamTransport = responseStreamTransport; + _logger = loggerFactory.CreateLogger(); + SetConnected(true); + } + + /// + public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) + { + if (!IsConnected) + { + _logger.TransportNotConnected(EndpointName); + throw new McpTransportException("Transport is not connected"); + } + + string id = "(no id)"; + if (message is IJsonRpcMessageWithId messageWithId) + { + id = messageWithId.Id.ToString(); + } + + try + { + if (_logger.IsEnabled(LogLevel.Debug)) + { + var json = JsonSerializer.Serialize(message, McpJsonUtilities.DefaultOptions.GetTypeInfo()); + _logger.TransportSendingMessage(EndpointName, id, json); + } + + await _responseStreamTransport.SendMessageAsync(message, cancellationToken).ConfigureAwait(false); + + _logger.TransportSentMessage(EndpointName, id); + } + catch (Exception ex) + { + _logger.TransportSendFailed(EndpointName, id, ex); + throw new McpTransportException("Failed to send message", ex); + } + } + + public Task OnMessageReceivedAsync(IJsonRpcMessage message, CancellationToken cancellationToken) + => WriteMessageAsync(message, cancellationToken); + + /// + public override ValueTask DisposeAsync() + { + SetConnected(false); + return default; + } +} diff --git a/src/ModelContextProtocol/Protocol/Transport/HttpListenerSseServerTransport.cs b/src/ModelContextProtocol/Protocol/Transport/HttpListenerSseServerTransport.cs index 44280608..614be147 100644 --- a/src/ModelContextProtocol/Protocol/Transport/HttpListenerSseServerTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/HttpListenerSseServerTransport.cs @@ -6,20 +6,23 @@ using ModelContextProtocol.Logging; using ModelContextProtocol.Server; using ModelContextProtocol.Utils; - -#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously +using System.Threading.Channels; namespace ModelContextProtocol.Protocol.Transport; /// /// Implements the MCP transport protocol using . /// -public sealed class HttpListenerSseServerTransport : TransportBase, IServerTransport +public sealed class HttpListenerSseServerTransport : IServerTransport, IAsyncDisposable { private readonly string _serverName; private readonly HttpListenerServerProvider _httpServerProvider; + private readonly ILoggerFactory _loggerFactory; private readonly ILogger _logger; - private SseResponseStreamTransport? _sseResponseStreamTransport; + + private readonly Channel _incomingSessions; + + private HttpListenerSseServerSessionTransport? _sessionTransport; private string EndpointName => $"Server (SSE) ({_serverName})"; @@ -41,77 +44,69 @@ public HttpListenerSseServerTransport(McpServerOptions serverOptions, int port, /// The port to listen on. /// A logger factory for creating loggers. public HttpListenerSseServerTransport(string serverName, int port, ILoggerFactory loggerFactory) - : base(loggerFactory) { Throw.IfNull(serverName); _serverName = serverName; + _loggerFactory = loggerFactory; _logger = loggerFactory.CreateLogger(); _httpServerProvider = new HttpListenerServerProvider(port) { OnSseConnectionAsync = OnSseConnectionAsync, OnMessageAsync = OnMessageAsync, }; - } - /// - public Task StartListeningAsync(CancellationToken cancellationToken = default) => - _httpServerProvider.StartAsync(cancellationToken); - - /// - public Task Completion => _httpServerProvider.Completed; + // Until we support session IDs, there's no way to support more than one concurrent session. + // Any new SSE connection overwrites the old session and any new /messages go to the new session. + _incomingSessions = Channel.CreateBounded(new BoundedChannelOptions(1) + { + FullMode = BoundedChannelFullMode.DropOldest, + }); + + // REVIEW: We could add another layer of async for binding similar to Kestrel's IConnectionListenerFactory, + // but this wouldn't play well with a static factory method to accept new sessions. Ultimately, + // ASP.NET Core is not going to hand over binding to the MCP SDK, so I decided to just bind in the transport + // constructor for now. + _httpServerProvider.Start(); + } /// - public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) + public async Task AcceptAsync(CancellationToken cancellationToken = default) { - if (!IsConnected || _sseResponseStreamTransport is null) - { - _logger.TransportNotConnected(EndpointName); - throw new McpTransportException("Transport is not connected"); - } - - string id = "(no id)"; - if (message is IJsonRpcMessageWithId messageWithId) - { - id = messageWithId.Id.ToString(); - } - - try + while (await _incomingSessions.Reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false)) { - if (_logger.IsEnabled(LogLevel.Debug)) + if (_incomingSessions.Reader.TryRead(out var session)) { - var json = JsonSerializer.Serialize(message, McpJsonUtilities.DefaultOptions.GetTypeInfo()); - _logger.TransportSendingMessage(EndpointName, id, json); + return session; } - - await _sseResponseStreamTransport.SendMessageAsync(message, cancellationToken).ConfigureAwait(false); - - _logger.TransportSentMessage(EndpointName, id); - } - catch (Exception ex) - { - _logger.TransportSendFailed(EndpointName, id, ex); - throw new McpTransportException("Failed to send message", ex); } + + return null; } /// - public override async ValueTask DisposeAsync() + public async ValueTask DisposeAsync() { _logger.TransportCleaningUp(EndpointName); - SetConnected(false); await _httpServerProvider.DisposeAsync().ConfigureAwait(false); + _incomingSessions.Writer.TryComplete(); _logger.TransportCleanedUp(EndpointName); } private async Task OnSseConnectionAsync(Stream responseStream, CancellationToken cancellationToken) { - await using var sseResponseStreamTransport = new SseResponseStreamTransport(responseStream); - _sseResponseStreamTransport = sseResponseStreamTransport; - SetConnected(true); - await sseResponseStreamTransport.RunAsync(cancellationToken); + var sseResponseStreamTransport = new SseResponseStreamTransport(responseStream); + var sessionTransport = new HttpListenerSseServerSessionTransport(_serverName, sseResponseStreamTransport, _loggerFactory); + + await using (sseResponseStreamTransport.ConfigureAwait(false)) + await using (sseResponseStreamTransport.ConfigureAwait(false)) + { + _sessionTransport = sessionTransport; + await _incomingSessions.Writer.WriteAsync(sessionTransport).ConfigureAwait(false); + await sseResponseStreamTransport.RunAsync(cancellationToken).ConfigureAwait(false); + } } /// @@ -138,7 +133,7 @@ private async Task OnMessageAsync(Stream requestStream, CancellationToken try { - message ??= await JsonSerializer.DeserializeAsync(requestStream, McpJsonUtilities.DefaultOptions.GetTypeInfo()); + message ??= await JsonSerializer.DeserializeAsync(requestStream, McpJsonUtilities.DefaultOptions.GetTypeInfo()).ConfigureAwait(false); if (message != null) { string messageId = "(no id)"; @@ -148,7 +143,14 @@ private async Task OnMessageAsync(Stream requestStream, CancellationToken } _logger.TransportReceivedMessageParsed(EndpointName, messageId); - await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false); + + if (_sessionTransport is null) + { + return false; + } + + await _sessionTransport.OnMessageReceivedAsync(message, cancellationToken).ConfigureAwait(false); + _logger.TransportMessageWritten(EndpointName, messageId); return true; diff --git a/src/ModelContextProtocol/Protocol/Transport/IClientTransport.cs b/src/ModelContextProtocol/Protocol/Transport/IClientTransport.cs index b7ad7f27..c68f6cef 100644 --- a/src/ModelContextProtocol/Protocol/Transport/IClientTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/IClientTransport.cs @@ -3,11 +3,12 @@ /// /// Represents a transport mechanism for MCP communication (from the client). /// -public interface IClientTransport : ITransport +public interface IClientTransport { /// - /// Establishes the transport connection. + /// Asynchronously establishes a transport session with an MCP server and returns an interface for the duplex JSON-RPC message stream. /// /// Token to cancel the operation. - Task ConnectAsync(CancellationToken cancellationToken = default); + /// Returns an interface for the duplex JSON-RPC message stream. + Task ConnectAsync(CancellationToken cancellationToken = default); } diff --git a/src/ModelContextProtocol/Protocol/Transport/IServerTransport.cs b/src/ModelContextProtocol/Protocol/Transport/IServerTransport.cs index e204202e..0d1a9774 100644 --- a/src/ModelContextProtocol/Protocol/Transport/IServerTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/IServerTransport.cs @@ -3,14 +3,12 @@ /// /// Represents a transport mechanism for MCP communication (from the server). /// -public interface IServerTransport : ITransport +public interface IServerTransport { /// - /// Starts listening for incoming messages. + /// Asynchronously accepts a transport session initiated by an MCP client and returns an interface for the duplex JSON-RPC message stream. /// - /// Token to cancel the operation. - Task StartListeningAsync(CancellationToken cancellationToken = default); - - /// Gets a that will complete when the server transport has completed all work. - Task Completion { get; } + /// Used to signal the cancellation of the asynchronous operation. + /// Returns an interface for the duplex JSON-RPC message stream. + Task AcceptAsync(CancellationToken cancellationToken = default); } diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs new file mode 100644 index 00000000..c3196b08 --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs @@ -0,0 +1,332 @@ +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using ModelContextProtocol.Configuration; +using ModelContextProtocol.Logging; +using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Utils; +using ModelContextProtocol.Utils.Json; +using System.Net.Http.Headers; +using System.Net.ServerSentEvents; +using System.Text; +using System.Text.Json; + +namespace ModelContextProtocol.Protocol.Transport; + +/// +/// The ServerSideEvents client transport implementation +/// +internal sealed class SseClientSessionTransport : TransportBase +{ + private readonly HttpClient _httpClient; + private readonly SseClientTransportOptions _options; + private readonly Uri _sseEndpoint; + private Uri? _messageEndpoint; + private readonly CancellationTokenSource _connectionCts; + private Task? _receiveTask; + private readonly ILogger _logger; + private readonly McpServerConfig _serverConfig; + private readonly JsonSerializerOptions _jsonOptions; + private readonly TaskCompletionSource _connectionEstablished; + private readonly bool _ownsHttpClient; + + private string EndpointName => $"Client (SSE) for ({_serverConfig.Id}: {_serverConfig.Name})"; + + /// + /// SSE transport for client endpoints. Unlike stdio it does not launch a process, but connects to an existing server. + /// The HTTP server can be local or remote, and must support the SSE protocol. + /// + /// Configuration options for the transport. + /// The configuration object indicating which server to connect to. + /// The HTTP client instance used for requests. + /// Logger factory for creating loggers. + /// True to dispose HTTP client on close connection. + public SseClientSessionTransport(SseClientTransportOptions transportOptions, McpServerConfig serverConfig, HttpClient httpClient, ILoggerFactory? loggerFactory, bool ownsHttpClient = false) + : base(loggerFactory) + { + Throw.IfNull(transportOptions); + Throw.IfNull(serverConfig); + Throw.IfNull(httpClient); + + _options = transportOptions; + _serverConfig = serverConfig; + _sseEndpoint = new Uri(serverConfig.Location!); + _httpClient = httpClient; + _connectionCts = new CancellationTokenSource(); + _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; + _jsonOptions = McpJsonUtilities.DefaultOptions; + _connectionEstablished = new TaskCompletionSource(); + _ownsHttpClient = ownsHttpClient; + } + + /// + public async Task ConnectAsync(CancellationToken cancellationToken = default) + { + try + { + if (IsConnected) + { + _logger.TransportAlreadyConnected(EndpointName); + throw new McpTransportException("Transport is already connected"); + } + + // Start message receiving loop + _receiveTask = ReceiveMessagesAsync(_connectionCts.Token); + + _logger.TransportReadingMessages(EndpointName); + + await _connectionEstablished.Task.WaitAsync(_options.ConnectionTimeout, cancellationToken).ConfigureAwait(false); + } + catch (McpTransportException) + { + // Rethrow transport exceptions + throw; + } + catch (Exception ex) + { + _logger.TransportConnectFailed(EndpointName, ex); + await CloseAsync().ConfigureAwait(false); + throw new McpTransportException("Failed to connect transport", ex); + } + } + + /// + public override async Task SendMessageAsync( + IJsonRpcMessage message, + CancellationToken cancellationToken = default) + { + if (_messageEndpoint == null) + throw new InvalidOperationException("Transport not connected"); + + using var content = new StringContent( + JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo()), + Encoding.UTF8, + "application/json" + ); + + string messageId = "(no id)"; + + if (message is IJsonRpcMessageWithId messageWithId) + { + messageId = messageWithId.Id.ToString(); + } + + var response = await _httpClient.PostAsync( + _messageEndpoint, + content, + cancellationToken + ).ConfigureAwait(false); + + response.EnsureSuccessStatusCode(); + + var responseContent = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false); + + // Check if the message was an initialize request + if (message is JsonRpcRequest request && request.Method == "initialize") + { + // If the response is not a JSON-RPC response, it is an SSE message + if (responseContent.Equals("accepted", StringComparison.OrdinalIgnoreCase)) + { + _logger.SSETransportPostAccepted(EndpointName, messageId); + // The response will arrive as an SSE message + } + else + { + JsonRpcResponse initializeResponse = JsonSerializer.Deserialize(responseContent, _jsonOptions.GetTypeInfo()) ?? + throw new McpTransportException("Failed to initialize client"); + + _logger.TransportReceivedMessageParsed(EndpointName, messageId); + await WriteMessageAsync(initializeResponse, cancellationToken).ConfigureAwait(false); + _logger.TransportMessageWritten(EndpointName, messageId); + } + return; + } + + // Otherwise, check if the response was accepted (the response will come as an SSE message) + if (responseContent.Equals("accepted", StringComparison.OrdinalIgnoreCase)) + { + _logger.SSETransportPostAccepted(EndpointName, messageId); + } + else + { + _logger.SSETransportPostNotAccepted(EndpointName, messageId, responseContent); + throw new McpTransportException("Failed to send message"); + } + } + + private async Task CloseAsync() + { + try + { + if (!_connectionCts.IsCancellationRequested) + { + await _connectionCts.CancelAsync().ConfigureAwait(false); + _connectionCts.Dispose(); + } + + if (_receiveTask != null) + { + await _receiveTask.ConfigureAwait(false); + } + } + finally + { + SetConnected(false); + } + } + + /// + public override async ValueTask DisposeAsync() + { + try + { + await CloseAsync(); + + if (_ownsHttpClient) + { + _httpClient?.Dispose(); + } + } + catch (Exception) + { + // Ignore exceptions on close + } + } + + internal Uri? MessageEndpoint => _messageEndpoint; + + internal SseClientTransportOptions Options => _options; + + private async Task ReceiveMessagesAsync(CancellationToken cancellationToken) + { + int reconnectAttempts = 0; + + while (!cancellationToken.IsCancellationRequested && !IsConnected) + { + try + { + using var request = new HttpRequestMessage(HttpMethod.Get, _sseEndpoint); + request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream")); + + using var response = await _httpClient.SendAsync( + request, + HttpCompletionOption.ResponseHeadersRead, + cancellationToken + ).ConfigureAwait(false); + + response.EnsureSuccessStatusCode(); + + using var stream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); + + await foreach (SseItem sseEvent in SseParser.Create(stream).EnumerateAsync(cancellationToken).ConfigureAwait(false)) + { + switch (sseEvent.EventType) + { + case "endpoint": + HandleEndpointEvent(sseEvent.Data); + break; + + case "message": + await ProcessSseMessage(sseEvent.Data, cancellationToken).ConfigureAwait(false); + break; + } + } + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + _logger.TransportReadMessagesCancelled(EndpointName); + // Normal shutdown + } + catch (IOException) when (cancellationToken.IsCancellationRequested) + { + _logger.TransportReadMessagesCancelled(EndpointName); + // Normal shutdown + } + catch (Exception ex) when (!cancellationToken.IsCancellationRequested) + { + _logger.TransportConnectionError(EndpointName, ex); + + reconnectAttempts++; + if (reconnectAttempts >= _options.MaxReconnectAttempts) + { + throw new McpTransportException("Exceeded reconnect limit", ex); + } + + await Task.Delay(_options.ReconnectDelay, cancellationToken).ConfigureAwait(false); + } + } + + SetConnected(false); + } + + private async Task ProcessSseMessage(string data, CancellationToken cancellationToken) + { + if (!IsConnected) + { + _logger.TransportMessageReceivedBeforeConnected(EndpointName, data); + return; + } + + try + { + var message = JsonSerializer.Deserialize(data, _jsonOptions.GetTypeInfo()); + if (message == null) + { + _logger.TransportMessageParseUnexpectedType(EndpointName, data); + return; + } + + string messageId = "(no id)"; + if (message is IJsonRpcMessageWithId messageWithId) + { + messageId = messageWithId.Id.ToString(); + } + + _logger.TransportReceivedMessageParsed(EndpointName, messageId); + await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false); + _logger.TransportMessageWritten(EndpointName, messageId); + } + catch (JsonException ex) + { + _logger.TransportMessageParseFailed(EndpointName, data, ex); + } + } + + private void HandleEndpointEvent(string data) + { + try + { + if (string.IsNullOrEmpty(data)) + { + _logger.TransportEndpointEventInvalid(EndpointName, data); + return; + } + + // Check if data is absolute URI + if (data.StartsWith("http://", StringComparison.OrdinalIgnoreCase) || data.StartsWith("https://", StringComparison.OrdinalIgnoreCase)) + { + // Since the endpoint is an absolute URI, we can use it directly + _messageEndpoint = new Uri(data); + } + else + { + // If the endpoint is a relative URI, we need to combine it with the relative path of the SSE endpoint + var hostUrl = _sseEndpoint.AbsoluteUri; + if (hostUrl.EndsWith("/sse", StringComparison.Ordinal)) + hostUrl = hostUrl[..^4]; + + var endpointUri = $"{hostUrl.TrimEnd('/')}/{data.TrimStart('/')}"; + + _messageEndpoint = new Uri(endpointUri); + } + + // Set connected state + SetConnected(true); + _connectionEstablished.TrySetResult(true); + } + catch (JsonException ex) + { + _logger.TransportEndpointEventParseFailed(EndpointName, data, ex); + throw new McpTransportException("Failed to parse endpoint event", ex); + } + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs index 654b3a6b..79db9a1d 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs @@ -1,36 +1,20 @@ -using System.Net.Http.Headers; -using System.Net.ServerSentEvents; -using System.Text; -using System.Text.Json; +using Microsoft.Extensions.Logging; using ModelContextProtocol.Configuration; -using ModelContextProtocol.Logging; -using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Utils; -using ModelContextProtocol.Utils.Json; -using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Abstractions; namespace ModelContextProtocol.Protocol.Transport; /// /// The ServerSideEvents client transport implementation /// -public sealed class SseClientTransport : TransportBase, IClientTransport +public sealed class SseClientTransport : IClientTransport { - private readonly HttpClient _httpClient; private readonly SseClientTransportOptions _options; - private readonly Uri _sseEndpoint; - private Uri? _messageEndpoint; - private readonly CancellationTokenSource _connectionCts; - private Task? _receiveTask; - private readonly ILogger _logger; private readonly McpServerConfig _serverConfig; - private readonly JsonSerializerOptions _jsonOptions; - private readonly TaskCompletionSource _connectionEstablished; + private readonly HttpClient _httpClient; + private readonly ILoggerFactory? _loggerFactory; private readonly bool _ownsHttpClient; - private string EndpointName => $"Client (SSE) for ({_serverConfig.Id}: {_serverConfig.Name})"; - /// /// SSE transport for client endpoints. Unlike stdio it does not launch a process, but connects to an existing server. /// The HTTP server can be local or remote, and must support the SSE protocol. @@ -53,7 +37,6 @@ public SseClientTransport(SseClientTransportOptions transportOptions, McpServerC /// Logger factory for creating loggers. /// True to dispose HTTP client on close connection. public SseClientTransport(SseClientTransportOptions transportOptions, McpServerConfig serverConfig, HttpClient httpClient, ILoggerFactory? loggerFactory, bool ownsHttpClient = false) - : base(loggerFactory) { Throw.IfNull(transportOptions); Throw.IfNull(serverConfig); @@ -61,285 +44,25 @@ public SseClientTransport(SseClientTransportOptions transportOptions, McpServerC _options = transportOptions; _serverConfig = serverConfig; - _sseEndpoint = new Uri(serverConfig.Location!); _httpClient = httpClient; - _connectionCts = new CancellationTokenSource(); - _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; - _jsonOptions = McpJsonUtilities.DefaultOptions; - _connectionEstablished = new TaskCompletionSource(); + _loggerFactory = loggerFactory; _ownsHttpClient = ownsHttpClient; } - /// - public async Task ConnectAsync(CancellationToken cancellationToken = default) - { - try - { - if (IsConnected) - { - _logger.TransportAlreadyConnected(EndpointName); - throw new McpTransportException("Transport is already connected"); - } - - // Start message receiving loop - _receiveTask = ReceiveMessagesAsync(_connectionCts.Token); - - _logger.TransportReadingMessages(EndpointName); - - await _connectionEstablished.Task.WaitAsync(_options.ConnectionTimeout, cancellationToken).ConfigureAwait(false); - } - catch (McpTransportException) - { - // Rethrow transport exceptions - throw; - } - catch (Exception ex) - { - _logger.TransportConnectFailed(EndpointName, ex); - await CloseAsync().ConfigureAwait(false); - throw new McpTransportException("Failed to connect transport", ex); - } - } - - /// - public override async Task SendMessageAsync( - IJsonRpcMessage message, - CancellationToken cancellationToken = default) - { - if (_messageEndpoint == null) - throw new InvalidOperationException("Transport not connected"); - - using var content = new StringContent( - JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo()), - Encoding.UTF8, - "application/json" - ); - - string messageId = "(no id)"; - - if (message is IJsonRpcMessageWithId messageWithId) - { - messageId = messageWithId.Id.ToString(); - } - - var response = await _httpClient.PostAsync( - _messageEndpoint, - content, - cancellationToken - ).ConfigureAwait(false); - - response.EnsureSuccessStatusCode(); - - var responseContent = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false); - - // Check if the message was an initialize request - if (message is JsonRpcRequest request && request.Method == "initialize") - { - // If the response is not a JSON-RPC response, it is an SSE message - if (responseContent.Equals("accepted", StringComparison.OrdinalIgnoreCase)) - { - _logger.SSETransportPostAccepted(EndpointName, messageId); - // The response will arrive as an SSE message - } - else - { - JsonRpcResponse initializeResponse = JsonSerializer.Deserialize(responseContent, _jsonOptions.GetTypeInfo()) ?? - throw new McpTransportException("Failed to initialize client"); - - _logger.TransportReceivedMessageParsed(EndpointName, messageId); - await WriteMessageAsync(initializeResponse, cancellationToken).ConfigureAwait(false); - _logger.TransportMessageWritten(EndpointName, messageId); - } - return; - } - - // Otherwise, check if the response was accepted (the response will come as an SSE message) - if (responseContent.Equals("accepted", StringComparison.OrdinalIgnoreCase)) - { - _logger.SSETransportPostAccepted(EndpointName, messageId); - } - else - { - _logger.SSETransportPostNotAccepted(EndpointName, messageId, responseContent); - throw new McpTransportException("Failed to send message"); - } - } - - private async Task CloseAsync() - { - try - { - if (!_connectionCts.IsCancellationRequested) - { - await _connectionCts.CancelAsync().ConfigureAwait(false); - _connectionCts.Dispose(); - } - - if (_receiveTask != null) - { - await _receiveTask.ConfigureAwait(false); - } - } - finally - { - SetConnected(false); - } - } - - /// - public override async ValueTask DisposeAsync() - { - try - { - if (_ownsHttpClient) - { - _httpClient?.Dispose(); - } - - await CloseAsync(); - } - catch (Exception) - { - // Ignore exceptions on close - } - } - - internal Uri? MessageEndpoint => _messageEndpoint; - - internal SseClientTransportOptions Options => _options; - - private async Task ReceiveMessagesAsync(CancellationToken cancellationToken) - { - int reconnectAttempts = 0; - - while (!cancellationToken.IsCancellationRequested) - { - try - { - using var request = new HttpRequestMessage(HttpMethod.Get, _sseEndpoint); - request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream")); - - using var response = await _httpClient.SendAsync( - request, - HttpCompletionOption.ResponseHeadersRead, - cancellationToken - ).ConfigureAwait(false); - - response.EnsureSuccessStatusCode(); - - using var stream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); - - // Reset reconnect attempts on successful connection - reconnectAttempts = 0; - - await foreach (SseItem sseEvent in SseParser.Create(stream).EnumerateAsync(cancellationToken).ConfigureAwait(false)) - { - switch (sseEvent.EventType) - { - case "endpoint": - HandleEndpointEvent(sseEvent.Data); - break; - - case "message": - await ProcessSseMessage(sseEvent.Data, cancellationToken).ConfigureAwait(false); - break; - } - } - } - catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) - { - _logger.TransportReadMessagesCancelled(EndpointName); - // Normal shutdown - } - catch (IOException) when (cancellationToken.IsCancellationRequested) - { - _logger.TransportReadMessagesCancelled(EndpointName); - // Normal shutdown - } - catch (Exception ex) when (!cancellationToken.IsCancellationRequested) - { - _logger.TransportConnectionError(EndpointName, ex); - - reconnectAttempts++; - if (reconnectAttempts >= _options.MaxReconnectAttempts) - { - throw new McpTransportException("Exceeded reconnect limit", ex); - } - - await Task.Delay(_options.ReconnectDelay, cancellationToken).ConfigureAwait(false); - } - } - } - - private async Task ProcessSseMessage(string data, CancellationToken cancellationToken) + /// + public async Task ConnectAsync(CancellationToken cancellationToken = default) { - if (!IsConnected) - { - _logger.TransportMessageReceivedBeforeConnected(EndpointName, data); - return; - } - - try - { - var message = JsonSerializer.Deserialize(data, _jsonOptions.GetTypeInfo()); - if (message == null) - { - _logger.TransportMessageParseUnexpectedType(EndpointName, data); - return; - } - - string messageId = "(no id)"; - if (message is IJsonRpcMessageWithId messageWithId) - { - messageId = messageWithId.Id.ToString(); - } + var sessionTransport = new SseClientSessionTransport(_options, _serverConfig, _httpClient, _loggerFactory, _ownsHttpClient); - _logger.TransportReceivedMessageParsed(EndpointName, messageId); - await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false); - _logger.TransportMessageWritten(EndpointName, messageId); - } - catch (JsonException ex) - { - _logger.TransportMessageParseFailed(EndpointName, data, ex); - } - } - - private void HandleEndpointEvent(string data) - { try { - if (string.IsNullOrEmpty(data)) - { - _logger.TransportEndpointEventInvalid(EndpointName, data); - return; - } - - // Check if data is absolute URI - if (data.StartsWith("http://", StringComparison.OrdinalIgnoreCase) || data.StartsWith("https://", StringComparison.OrdinalIgnoreCase)) - { - // Since the endpoint is an absolute URI, we can use it directly - _messageEndpoint = new Uri(data); - } - else - { - // If the endpoint is a relative URI, we need to combine it with the relative path of the SSE endpoint - var hostUrl = _sseEndpoint.AbsoluteUri; - if (hostUrl.EndsWith("/sse", StringComparison.Ordinal)) - hostUrl = hostUrl[..^4]; - - var endpointUri = $"{hostUrl.TrimEnd('/')}/{data.TrimStart('/')}"; - - _messageEndpoint = new Uri(endpointUri); - } - - // Set connected state - SetConnected(true); - _connectionEstablished.TrySetResult(true); + await sessionTransport.ConnectAsync(cancellationToken).ConfigureAwait(false); + return sessionTransport; } - catch (JsonException ex) + catch { - _logger.TransportEndpointEventParseFailed(EndpointName, data, ex); - throw new McpTransportException("Failed to parse endpoint event", ex); + await sessionTransport.DisposeAsync().ConfigureAwait(false); + throw; } } } \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioClientStreamTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StdioClientStreamTransport.cs new file mode 100644 index 00000000..8de9aa76 --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Transport/StdioClientStreamTransport.cs @@ -0,0 +1,321 @@ +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using ModelContextProtocol.Configuration; +using ModelContextProtocol.Logging; +using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Utils; +using ModelContextProtocol.Utils.Json; +using System.Diagnostics; +using System.Text; +using System.Text.Json; + +namespace ModelContextProtocol.Protocol.Transport; + +/// +/// Implements the MCP transport protocol over standard input/output streams. +/// +internal sealed class StdioClientStreamTransport : TransportBase +{ + private readonly StdioClientTransportOptions _options; + private readonly McpServerConfig _serverConfig; + private readonly ILogger _logger; + private readonly JsonSerializerOptions _jsonOptions; + private Process? _process; + private Task? _readTask; + private CancellationTokenSource? _shutdownCts; + private bool _processStarted; + + private string EndpointName => $"Client (stdio) for ({_serverConfig.Id}: {_serverConfig.Name})"; + + /// + /// Initializes a new instance of the StdioTransport class. + /// + /// Configuration options for the transport. + /// The server configuration for the transport. + /// A logger factory for creating loggers. + public StdioClientStreamTransport(StdioClientTransportOptions options, McpServerConfig serverConfig, ILoggerFactory? loggerFactory = null) + : base(loggerFactory) + { + Throw.IfNull(options); + Throw.IfNull(serverConfig); + + _options = options; + _serverConfig = serverConfig; + _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; + _jsonOptions = McpJsonUtilities.DefaultOptions; + } + + /// + public async Task ConnectAsync(CancellationToken cancellationToken = default) + { + if (IsConnected) + { + _logger.TransportAlreadyConnected(EndpointName); + throw new McpTransportException("Transport is already connected"); + } + + try + { + _logger.TransportConnecting(EndpointName); + + _shutdownCts = new CancellationTokenSource(); + + UTF8Encoding noBomUTF8 = new(encoderShouldEmitUTF8Identifier: false); + + var startInfo = new ProcessStartInfo + { + FileName = _options.Command, + RedirectStandardInput = true, + RedirectStandardOutput = true, + RedirectStandardError = true, + UseShellExecute = false, + CreateNoWindow = true, + WorkingDirectory = _options.WorkingDirectory ?? Environment.CurrentDirectory, + StandardOutputEncoding = noBomUTF8, + StandardErrorEncoding = noBomUTF8, +#if NET + StandardInputEncoding = noBomUTF8, +#endif + }; + + if (!string.IsNullOrWhiteSpace(_options.Arguments)) + { + startInfo.Arguments = _options.Arguments; + } + + if (_options.EnvironmentVariables != null) + { + foreach (var entry in _options.EnvironmentVariables) + { + startInfo.Environment[entry.Key] = entry.Value; + } + } + + _logger.CreateProcessForTransport(EndpointName, _options.Command, + startInfo.Arguments, string.Join(", ", startInfo.Environment.Select(kvp => kvp.Key + "=" + kvp.Value)), + startInfo.WorkingDirectory, _options.ShutdownTimeout.ToString()); + + _process = new Process { StartInfo = startInfo }; + + // Set up error logging + _process.ErrorDataReceived += (sender, args) => _logger.TransportError(EndpointName, args.Data ?? "(no data)"); + + // We need both stdin and stdout to use a no-BOM UTF-8 encoding. On .NET Core, + // we can use ProcessStartInfo.StandardOutputEncoding/StandardInputEncoding, but + // StandardInputEncoding doesn't exist on .NET Framework; instead, it always picks + // up the encoding from Console.InputEncoding. As such, when not targeting .NET Core, + // we temporarily change Console.InputEncoding to no-BOM UTF-8 around the Process.Start + // call, to ensure it picks up the correct encoding. +#if NET + _processStarted = _process.Start(); +#else + Encoding originalInputEncoding = Console.InputEncoding; + try + { + Console.InputEncoding = noBomUTF8; + _processStarted = _process.Start(); + } + finally + { + Console.InputEncoding = originalInputEncoding; + } +#endif + + if (!_processStarted) + { + _logger.TransportProcessStartFailed(EndpointName); + throw new McpTransportException("Failed to start MCP server process"); + } + + _logger.TransportProcessStarted(EndpointName, _process.Id); + + _process.BeginErrorReadLine(); + + // Start reading messages in the background + _readTask = Task.Run(() => ReadMessagesAsync(_shutdownCts.Token), CancellationToken.None); + _logger.TransportReadingMessages(EndpointName); + + SetConnected(true); + } + catch (Exception ex) + { + _logger.TransportConnectFailed(EndpointName, ex); + await CleanupAsync(cancellationToken).ConfigureAwait(false); + throw new McpTransportException("Failed to connect transport", ex); + } + } + + /// + public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) + { + if (!IsConnected || _process?.HasExited == true) + { + _logger.TransportNotConnected(EndpointName); + throw new McpTransportException("Transport is not connected"); + } + + string id = "(no id)"; + if (message is IJsonRpcMessageWithId messageWithId) + { + id = messageWithId.Id.ToString(); + } + + try + { + var json = JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo()); + _logger.TransportSendingMessage(EndpointName, id, json); + _logger.TransportMessageBytesUtf8(EndpointName, json); + + // Write the message followed by a newline using our UTF-8 writer + await _process!.StandardInput.WriteLineAsync(json).ConfigureAwait(false); + await _process.StandardInput.FlushAsync(cancellationToken).ConfigureAwait(false); + + _logger.TransportSentMessage(EndpointName, id); + } + catch (Exception ex) + { + _logger.TransportSendFailed(EndpointName, id, ex); + throw new McpTransportException("Failed to send message", ex); + } + } + + /// + public override async ValueTask DisposeAsync() + { + await CleanupAsync(CancellationToken.None).ConfigureAwait(false); + } + + private async Task ReadMessagesAsync(CancellationToken cancellationToken) + { + try + { + _logger.TransportEnteringReadMessagesLoop(EndpointName); + + while (!cancellationToken.IsCancellationRequested && !_process!.HasExited) + { + _logger.TransportWaitingForMessage(EndpointName); + var line = await _process.StandardOutput.ReadLineAsync(cancellationToken).ConfigureAwait(false); + if (line == null) + { + _logger.TransportEndOfStream(EndpointName); + break; + } + + if (string.IsNullOrWhiteSpace(line)) + { + continue; + } + + _logger.TransportReceivedMessage(EndpointName, line); + _logger.TransportMessageBytesUtf8(EndpointName, line); + + await ProcessMessageAsync(line, cancellationToken).ConfigureAwait(false); + } + _logger.TransportExitingReadMessagesLoop(EndpointName); + } + catch (OperationCanceledException) + { + _logger.TransportReadMessagesCancelled(EndpointName); + // Normal shutdown + } + catch (Exception ex) + { + _logger.TransportReadMessagesFailed(EndpointName, ex); + } + finally + { + await CleanupAsync(cancellationToken).ConfigureAwait(false); + } + } + + private async Task ProcessMessageAsync(string line, CancellationToken cancellationToken) + { + try + { + line=line.Trim();//Fixes an error when the service prefixes nonprintable characters + var message = JsonSerializer.Deserialize(line, _jsonOptions.GetTypeInfo()); + if (message != null) + { + string messageId = "(no id)"; + if (message is IJsonRpcMessageWithId messageWithId) + { + messageId = messageWithId.Id.ToString(); + } + _logger.TransportReceivedMessageParsed(EndpointName, messageId); + await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false); + _logger.TransportMessageWritten(EndpointName, messageId); + } + else + { + _logger.TransportMessageParseUnexpectedType(EndpointName, line); + } + } + catch (JsonException ex) + { + _logger.TransportMessageParseFailed(EndpointName, line, ex); + } + } + + private async Task CleanupAsync(CancellationToken cancellationToken) + { + _logger.TransportCleaningUp(EndpointName); + + if (_process is Process process && _processStarted && !process.HasExited) + { + try + { + // Wait for the process to exit + _logger.TransportWaitingForShutdown(EndpointName); + + // Kill the while process tree because the process may spawn child processes + // and Node.js does not kill its children when it exits properly + process.KillTree(_options.ShutdownTimeout); + } + catch (Exception ex) + { + _logger.TransportShutdownFailed(EndpointName, ex); + } + finally + { + process.Dispose(); + _process = null; + } + } + + if (_shutdownCts is { } shutdownCts) + { + await shutdownCts.CancelAsync().ConfigureAwait(false); + shutdownCts.Dispose(); + _shutdownCts = null; + } + + if (_readTask is Task readTask) + { + try + { + _logger.TransportWaitingForReadTask(EndpointName); + await readTask.WaitAsync(TimeSpan.FromSeconds(5), cancellationToken).ConfigureAwait(false); + } + catch (TimeoutException) + { + _logger.TransportCleanupReadTaskTimeout(EndpointName); + } + catch (OperationCanceledException) + { + _logger.TransportCleanupReadTaskCancelled(EndpointName); + } + catch (Exception ex) + { + _logger.TransportCleanupReadTaskFailed(EndpointName, ex); + } + finally + { + _logger.TransportReadTaskCleanedUp(EndpointName); + _readTask = null; + } + } + + SetConnected(false); + _logger.TransportCleanedUp(EndpointName); + } +} diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs index 00f4dfb4..e1c2ed2d 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs @@ -1,31 +1,17 @@ using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Abstractions; using ModelContextProtocol.Configuration; -using ModelContextProtocol.Logging; -using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Utils; -using ModelContextProtocol.Utils.Json; -using System.Diagnostics; -using System.Text; -using System.Text.Json; namespace ModelContextProtocol.Protocol.Transport; /// /// Implements the MCP transport protocol over standard input/output streams. /// -public sealed class StdioClientTransport : TransportBase, IClientTransport +public sealed class StdioClientTransport : IClientTransport { private readonly StdioClientTransportOptions _options; private readonly McpServerConfig _serverConfig; - private readonly ILogger _logger; - private readonly JsonSerializerOptions _jsonOptions; - private Process? _process; - private Task? _readTask; - private CancellationTokenSource? _shutdownCts; - private bool _processStarted; - - private string EndpointName => $"Client (stdio) for ({_serverConfig.Id}: {_serverConfig.Name})"; + private readonly ILoggerFactory? _loggerFactory; /// /// Initializes a new instance of the StdioTransport class. @@ -34,289 +20,29 @@ public sealed class StdioClientTransport : TransportBase, IClientTransport /// The server configuration for the transport. /// A logger factory for creating loggers. public StdioClientTransport(StdioClientTransportOptions options, McpServerConfig serverConfig, ILoggerFactory? loggerFactory = null) - : base(loggerFactory) { Throw.IfNull(options); Throw.IfNull(serverConfig); _options = options; _serverConfig = serverConfig; - _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; - _jsonOptions = McpJsonUtilities.DefaultOptions; + _loggerFactory = loggerFactory; } - /// - public async Task ConnectAsync(CancellationToken cancellationToken = default) + /// + public async Task ConnectAsync(CancellationToken cancellationToken = default) { - if (IsConnected) - { - _logger.TransportAlreadyConnected(EndpointName); - throw new McpTransportException("Transport is already connected"); - } + var streamTransport = new StdioClientStreamTransport(_options, _serverConfig, _loggerFactory); try { - _logger.TransportConnecting(EndpointName); - - _shutdownCts = new CancellationTokenSource(); - - UTF8Encoding noBomUTF8 = new(encoderShouldEmitUTF8Identifier: false); - - var startInfo = new ProcessStartInfo - { - FileName = _options.Command, - RedirectStandardInput = true, - RedirectStandardOutput = true, - RedirectStandardError = true, - UseShellExecute = false, - CreateNoWindow = true, - WorkingDirectory = _options.WorkingDirectory ?? Environment.CurrentDirectory, - StandardOutputEncoding = noBomUTF8, - StandardErrorEncoding = noBomUTF8, -#if NET - StandardInputEncoding = noBomUTF8, -#endif - }; - - if (!string.IsNullOrWhiteSpace(_options.Arguments)) - { - startInfo.Arguments = _options.Arguments; - } - - if (_options.EnvironmentVariables != null) - { - foreach (var entry in _options.EnvironmentVariables) - { - startInfo.Environment[entry.Key] = entry.Value; - } - } - - _logger.CreateProcessForTransport(EndpointName, _options.Command, - startInfo.Arguments, string.Join(", ", startInfo.Environment.Select(kvp => kvp.Key + "=" + kvp.Value)), - startInfo.WorkingDirectory, _options.ShutdownTimeout.ToString()); - - _process = new Process { StartInfo = startInfo }; - - // Set up error logging - _process.ErrorDataReceived += (sender, args) => _logger.TransportError(EndpointName, args.Data ?? "(no data)"); - - // We need both stdin and stdout to use a no-BOM UTF-8 encoding. On .NET Core, - // we can use ProcessStartInfo.StandardOutputEncoding/StandardInputEncoding, but - // StandardInputEncoding doesn't exist on .NET Framework; instead, it always picks - // up the encoding from Console.InputEncoding. As such, when not targeting .NET Core, - // we temporarily change Console.InputEncoding to no-BOM UTF-8 around the Process.Start - // call, to ensure it picks up the correct encoding. -#if NET - _processStarted = _process.Start(); -#else - Encoding originalInputEncoding = Console.InputEncoding; - try - { - Console.InputEncoding = noBomUTF8; - _processStarted = _process.Start(); - } - finally - { - Console.InputEncoding = originalInputEncoding; - } -#endif - - if (!_processStarted) - { - _logger.TransportProcessStartFailed(EndpointName); - throw new McpTransportException("Failed to start MCP server process"); - } - - _logger.TransportProcessStarted(EndpointName, _process.Id); - - _process.BeginErrorReadLine(); - - // Start reading messages in the background - _readTask = Task.Run(() => ReadMessagesAsync(_shutdownCts.Token), CancellationToken.None); - _logger.TransportReadingMessages(EndpointName); - - SetConnected(true); + await streamTransport.ConnectAsync(cancellationToken).ConfigureAwait(false); + return streamTransport; } - catch (Exception ex) + catch { - _logger.TransportConnectFailed(EndpointName, ex); - await CleanupAsync(cancellationToken).ConfigureAwait(false); - throw new McpTransportException("Failed to connect transport", ex); + await streamTransport.DisposeAsync().ConfigureAwait(false); + throw; } } - - /// - public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) - { - if (!IsConnected || _process?.HasExited == true) - { - _logger.TransportNotConnected(EndpointName); - throw new McpTransportException("Transport is not connected"); - } - - string id = "(no id)"; - if (message is IJsonRpcMessageWithId messageWithId) - { - id = messageWithId.Id.ToString(); - } - - try - { - var json = JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo()); - _logger.TransportSendingMessage(EndpointName, id, json); - _logger.TransportMessageBytesUtf8(EndpointName, json); - - // Write the message followed by a newline using our UTF-8 writer - await _process!.StandardInput.WriteLineAsync(json).ConfigureAwait(false); - await _process.StandardInput.FlushAsync(cancellationToken).ConfigureAwait(false); - - _logger.TransportSentMessage(EndpointName, id); - } - catch (Exception ex) - { - _logger.TransportSendFailed(EndpointName, id, ex); - throw new McpTransportException("Failed to send message", ex); - } - } - - /// - public override async ValueTask DisposeAsync() - { - await CleanupAsync(CancellationToken.None).ConfigureAwait(false); - } - - private async Task ReadMessagesAsync(CancellationToken cancellationToken) - { - try - { - _logger.TransportEnteringReadMessagesLoop(EndpointName); - - while (!cancellationToken.IsCancellationRequested && !_process!.HasExited) - { - _logger.TransportWaitingForMessage(EndpointName); - var line = await _process.StandardOutput.ReadLineAsync(cancellationToken).ConfigureAwait(false); - if (line == null) - { - _logger.TransportEndOfStream(EndpointName); - break; - } - - if (string.IsNullOrWhiteSpace(line)) - { - continue; - } - - _logger.TransportReceivedMessage(EndpointName, line); - _logger.TransportMessageBytesUtf8(EndpointName, line); - - await ProcessMessageAsync(line, cancellationToken).ConfigureAwait(false); - } - _logger.TransportExitingReadMessagesLoop(EndpointName); - } - catch (OperationCanceledException) - { - _logger.TransportReadMessagesCancelled(EndpointName); - // Normal shutdown - } - catch (Exception ex) - { - _logger.TransportReadMessagesFailed(EndpointName, ex); - } - finally - { - await CleanupAsync(cancellationToken).ConfigureAwait(false); - } - } - - private async Task ProcessMessageAsync(string line, CancellationToken cancellationToken) - { - try - { - line=line.Trim();//Fixes an error when the service prefixes nonprintable characters - var message = JsonSerializer.Deserialize(line, _jsonOptions.GetTypeInfo()); - if (message != null) - { - string messageId = "(no id)"; - if (message is IJsonRpcMessageWithId messageWithId) - { - messageId = messageWithId.Id.ToString(); - } - _logger.TransportReceivedMessageParsed(EndpointName, messageId); - await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false); - _logger.TransportMessageWritten(EndpointName, messageId); - } - else - { - _logger.TransportMessageParseUnexpectedType(EndpointName, line); - } - } - catch (JsonException ex) - { - _logger.TransportMessageParseFailed(EndpointName, line, ex); - } - } - - private async Task CleanupAsync(CancellationToken cancellationToken) - { - _logger.TransportCleaningUp(EndpointName); - - if (_process is Process process && _processStarted && !process.HasExited) - { - try - { - // Wait for the process to exit - _logger.TransportWaitingForShutdown(EndpointName); - - // Kill the while process tree because the process may spawn child processes - // and Node.js does not kill its children when it exits properly - process.KillTree(_options.ShutdownTimeout); - } - catch (Exception ex) - { - _logger.TransportShutdownFailed(EndpointName, ex); - } - finally - { - process.Dispose(); - _process = null; - } - } - - if (_shutdownCts is { } shutdownCts) - { - await shutdownCts.CancelAsync().ConfigureAwait(false); - shutdownCts.Dispose(); - _shutdownCts = null; - } - - if (_readTask is Task readTask) - { - try - { - _logger.TransportWaitingForReadTask(EndpointName); - await readTask.WaitAsync(TimeSpan.FromSeconds(5), cancellationToken).ConfigureAwait(false); - } - catch (TimeoutException) - { - _logger.TransportCleanupReadTaskTimeout(EndpointName); - } - catch (OperationCanceledException) - { - _logger.TransportCleanupReadTaskCancelled(EndpointName); - } - catch (Exception ex) - { - _logger.TransportCleanupReadTaskFailed(EndpointName, ex); - } - finally - { - _logger.TransportReadTaskCleanedUp(EndpointName); - _readTask = null; - } - } - - SetConnected(false); - _logger.TransportCleanedUp(EndpointName); - } - } diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioServerTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StdioServerTransport.cs index 05cd5a28..7779edc9 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StdioServerTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StdioServerTransport.cs @@ -1,4 +1,4 @@ -using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; using Microsoft.Extensions.Options; using ModelContextProtocol.Logging; @@ -14,7 +14,7 @@ namespace ModelContextProtocol.Protocol.Transport; /// /// Provides an implementation of the MCP transport protocol over standard input/output streams. /// -public sealed class StdioServerTransport : TransportBase, IServerTransport +public sealed class StdioServerTransport : TransportBase, ITransport { private static readonly byte[] s_newlineBytes = "\n"u8.ToArray(); @@ -28,8 +28,7 @@ public sealed class StdioServerTransport : TransportBase, IServerTransport private readonly SemaphoreSlim _sendLock = new(1, 1); private readonly CancellationTokenSource _shutdownCts = new(); - private Task _readLoopCompleted = Task.CompletedTask; - private TaskCompletionSource _serverCompleted = new(); + private readonly Task _readLoopCompleted; private int _disposed = 0; private string EndpointName => $"Server (stdio) ({_serverName})"; @@ -119,24 +118,11 @@ public StdioServerTransport(string serverName, Stream? stdinStream = null, Strea _stdInReader = new StreamReader(stdinStream ?? Console.OpenStandardInput(), Encoding.UTF8); _stdOutStream = stdoutStream ?? new BufferedStream(Console.OpenStandardOutput()); - } - - /// - public Task StartListeningAsync(CancellationToken cancellationToken = default) - { - _logger.LogDebug("Starting StdioServerTransport listener for {EndpointName}", EndpointName); SetConnected(true); _readLoopCompleted = Task.Run(ReadMessagesAsync, _shutdownCts.Token); - - _logger.LogDebug("StdioServerTransport now connected for {EndpointName}", EndpointName); - - return Task.CompletedTask; } - /// Gets a that completes when the server transport has finished its work. - public Task Completion => _serverCompleted.Task; - /// public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) { @@ -225,19 +211,17 @@ private async Task ReadMessagesAsync() _logger.TransportExitingReadMessagesLoop(EndpointName); } - catch (OperationCanceledException oce) + catch (OperationCanceledException) { _logger.TransportReadMessagesCancelled(EndpointName); - _serverCompleted.TrySetCanceled(oce.CancellationToken); } catch (Exception ex) { _logger.TransportReadMessagesFailed(EndpointName, ex); - _serverCompleted.TrySetException(ex); } finally { - _serverCompleted.TrySetResult(true); + SetConnected(false); } } @@ -288,4 +272,4 @@ public override async ValueTask DisposeAsync() _logger.TransportCleanedUp(EndpointName); } } -} \ No newline at end of file +} diff --git a/src/ModelContextProtocol/Protocol/Transport/TransportBase.cs b/src/ModelContextProtocol/Protocol/Transport/TransportBase.cs index 349b16c2..35576168 100644 --- a/src/ModelContextProtocol/Protocol/Transport/TransportBase.cs +++ b/src/ModelContextProtocol/Protocol/Transport/TransportBase.cs @@ -26,7 +26,7 @@ protected TransportBase(ILoggerFactory? loggerFactory) SingleReader = true, SingleWriter = true, }); - _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; + _logger = loggerFactory?.CreateLogger(GetType()) ?? NullLogger.Instance; } /// diff --git a/src/ModelContextProtocol/Server/IMcpServer.cs b/src/ModelContextProtocol/Server/IMcpServer.cs index e311b04a..e8dffaf1 100644 --- a/src/ModelContextProtocol/Server/IMcpServer.cs +++ b/src/ModelContextProtocol/Server/IMcpServer.cs @@ -45,21 +45,21 @@ public interface IMcpServer : IAsyncDisposable /// /// Runs the server, listening for and handling client requests. /// - Task RunAsync(Func? onInitialized = null, CancellationToken cancellationToken = default); + Task RunAsync(CancellationToken cancellationToken = default); /// /// Sends a generic JSON-RPC request to the client. - /// NB! This is a temporary method that is available to send not yet implemented feature messages. + /// NB! This is a temporary method that is available to send not yet implemented feature messages. /// Once all MCP features are implemented this will be made private, as it is purely a convenience for those who wish to implement features ahead of the library. /// - /// The expected response type. + /// The expected response type. /// The JSON-RPC request to send. /// A token to cancel the operation. /// A task containing the client's response. - Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken) where T : class; + Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) where TResult : class; /// - /// Sends a message to the server. + /// Sends a message to the client. /// /// The message. /// A token to cancel the operation. diff --git a/src/ModelContextProtocol/Server/McpServer.cs b/src/ModelContextProtocol/Server/McpServer.cs index 5a8f50f6..bcfa0a0e 100644 --- a/src/ModelContextProtocol/Server/McpServer.cs +++ b/src/ModelContextProtocol/Server/McpServer.cs @@ -12,31 +12,67 @@ namespace ModelContextProtocol.Server; /// internal sealed class McpServer : McpJsonRpcEndpoint, IMcpServer { - private readonly string _serverDescription; + private readonly IServerTransport? _serverTransport; private readonly EventHandler? _toolsChangedDelegate; - private volatile bool _ran; + private ITransport? _sessionTransport; + private string _endpointName; /// /// Creates a new instance of . /// - /// Transport to use for the server - /// Configuration options for this server, including capabilities. - /// Make sure to accurately reflect exactly what capabilities the server supports and does not support. + /// Transport to use for the server that is ready to accept new sessions asynchronously. + /// Configuration options for this server, including capabilities. + /// Make sure to accurately reflect exactly what capabilities the server supports and does not support. + /// Logger factory to use for logging + /// Optional service provider to use for dependency injection + /// + public McpServer(IServerTransport serverTransport, McpServerOptions options, ILoggerFactory? loggerFactory, IServiceProvider? serviceProvider) + : this(options, loggerFactory, serviceProvider) + { + Throw.IfNull(serverTransport); + + _serverTransport = serverTransport; + } + + /// + /// Creates a new instance of . + /// + /// Transport to use for the server representing an already-established session. + /// Configuration options for this server, including capabilities. + /// Make sure to accurately reflect exactly what capabilities the server supports and does not support. /// Logger factory to use for logging /// Optional service provider to use for dependency injection /// public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? loggerFactory, IServiceProvider? serviceProvider) - : base(transport, loggerFactory) + : this(options, loggerFactory, serviceProvider) + { + Throw.IfNull(transport); + + _sessionTransport = transport; + InitializeSession(transport); + } + + /// + /// Creates a new instance of . + /// + /// Configuration options for this server, including capabilities. + /// Make sure to accurately reflect exactly what capabilities the server supports and does not support. + /// Logger factory to use for logging + /// Optional service provider to use for dependency injection + /// + private McpServer(McpServerOptions options, ILoggerFactory? loggerFactory, IServiceProvider? serviceProvider) + : base(loggerFactory) { Throw.IfNull(options); ServerOptions = options; Services = serviceProvider; - _serverDescription = $"{options.ServerInfo.Name} {options.ServerInfo.Version}"; + _endpointName = $"Server ({options.ServerInfo.Name} {options.ServerInfo.Version})"; + _toolsChangedDelegate = delegate { - _ = SendMessageAsync(new JsonRpcNotification() + _ = _session?.SendMessageAsync(new JsonRpcNotification() { Method = NotificationMethods.ToolListChangedNotification, }); @@ -77,60 +113,63 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? public IServiceProvider? Services { get; } /// - public override string EndpointName => - $"Server ({_serverDescription}), Client ({ClientInfo?.Name} {ClientInfo?.Version})"; + public override string EndpointName => _endpointName; - /// - public async Task RunAsync(Func? onInitialized = null, CancellationToken cancellationToken = default) + public async Task AcceptSessionAsync(CancellationToken cancellationToken = default) { - if (_ran) - { - _logger.ServerAlreadyInitializing(EndpointName); - throw new InvalidOperationException("Server is already running."); - } - _ran = true; + // Below is effectively an assertion. The McpServerFactory should only use this with the IServerTransport constructor. + Throw.IfNull(_serverTransport); try { - CancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + _sessionTransport = await _serverTransport.AcceptAsync(cancellationToken).ConfigureAwait(false); - IServerTransport? serverTransport = Transport as IServerTransport; - if (serverTransport is not null) + if (_sessionTransport is null) { - // Start listening for messages - await serverTransport.StartListeningAsync(CancellationTokenSource.Token).ConfigureAwait(false); + throw new McpServerException("The server transport closed before a client started a new session."); } - // Start processing messages - MessageProcessingTask = ProcessMessagesAsync(CancellationTokenSource.Token); - - // Invoke any post-initialization logic. - if (onInitialized is not null) - { - await onInitialized(CancellationTokenSource.Token).ConfigureAwait(false); - } - - // Wait for transport completion. - if (serverTransport is not null) - { - await serverTransport.Completion; - } + InitializeSession(_sessionTransport); } catch (Exception e) { _logger.ServerInitializationError(EndpointName, e); - await CleanupAsync().ConfigureAwait(false); throw; } } - protected override Task CleanupAsync() + /// + public async Task RunAsync(CancellationToken cancellationToken = default) + { + // Below is effectively an assertion. The McpServerFactory should not return before the _transport is initialized. + Throw.IfNull(_session); + + try + { + // Start processing messages + StartSession(fullSessionCancellationToken: cancellationToken); + await MessageProcessingTask.ConfigureAwait(false); + } + finally + { + await DisposeAsync().ConfigureAwait(false); + } + } + + public override async ValueTask DisposeAsync() { if (ServerOptions.Capabilities?.Tools?.ToolCollection is { } tools) { tools.Changed -= _toolsChangedDelegate; } - return base.CleanupAsync(); + + await base.DisposeAsync().ConfigureAwait(false); + + if (_serverTransport is not null && _sessionTransport is not null) + { + // We created the _sessionTransport from the _serverTransport, so we own it. + await _sessionTransport.DisposeAsync().ConfigureAwait(false); + } } private void SetPingHandler() @@ -146,6 +185,11 @@ private void SetInitializeHandler(McpServerOptions options) { ClientCapabilities = request?.Capabilities ?? new(); ClientInfo = request?.ClientInfo; + + // Use the ClientInfo to update the session EndpointName for logging. + _endpointName = $"{_endpointName}, Client ({ClientInfo?.Name} {ClientInfo?.Version})"; + GetSessionOrThrow().EndpointName = EndpointName; + return Task.FromResult(new InitializeResult() { ProtocolVersion = options.ProtocolVersion, diff --git a/src/ModelContextProtocol/Server/McpServerFactory.cs b/src/ModelContextProtocol/Server/McpServerFactory.cs index ce79a23d..953b7430 100644 --- a/src/ModelContextProtocol/Server/McpServerFactory.cs +++ b/src/ModelContextProtocol/Server/McpServerFactory.cs @@ -10,27 +10,65 @@ namespace ModelContextProtocol.Server; public static class McpServerFactory { /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// - /// Transport to use for the server + /// Transport to use for the server representing an already-established MCP session. /// /// Configuration options for this server, including capabilities. /// Make sure to accurately reflect exactly what capabilities the server supports and does not support. /// - /// Optional service provider to create new instances. /// Logger factory to use for logging - /// An that's started and ready to receive connections. - /// is . + /// Optional service provider to create new instances. + /// An . + /// is . /// is . public static IMcpServer Create( - ITransport serverTransport, + ITransport transport, McpServerOptions serverOptions, ILoggerFactory? loggerFactory = null, IServiceProvider? serviceProvider = null) + { + Throw.IfNull(transport); + Throw.IfNull(serverOptions); + + return new McpServer(transport, serverOptions, loggerFactory, serviceProvider); + } + + /// + /// Waits for the client to establish a new MCP session, then initializes a new instance of the class. + /// + /// Transport to use for the server that is ready to accept new MCP sessions asynchronously. + /// + /// Configuration options for this server, including capabilities. + /// Make sure to accurately reflect exactly what capabilities the server supports and does not support. + /// + /// Logger factory to use for logging + /// Optional service provider to create new instances. + /// Cancel waiting for a client to establish a new MCP session. + /// An . + /// is . + /// is . + public static async Task AcceptAsync( + IServerTransport serverTransport, + McpServerOptions serverOptions, + ILoggerFactory? loggerFactory = null, + IServiceProvider? serviceProvider = null, + CancellationToken cancellationToken = default) { Throw.IfNull(serverTransport); Throw.IfNull(serverOptions); - return new McpServer(serverTransport, serverOptions, loggerFactory, serviceProvider); + var mcpServer = new McpServer(serverTransport, serverOptions, loggerFactory, serviceProvider); + + try + { + await mcpServer.AcceptSessionAsync(cancellationToken).ConfigureAwait(false); + return mcpServer; + } + catch + { + await mcpServer.DisposeAsync().ConfigureAwait(false); + throw; + } } } diff --git a/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs b/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs index 7052953d..971f7dce 100644 --- a/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs +++ b/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs @@ -1,15 +1,9 @@ -using ModelContextProtocol.Client; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; using ModelContextProtocol.Logging; using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; -using ModelContextProtocol.Utils; -using ModelContextProtocol.Utils.Json; - -using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Abstractions; - -using System.Collections.Concurrent; -using System.Text.Json; +using System.Diagnostics.CodeAnalysis; namespace ModelContextProtocol.Shared; @@ -22,352 +16,84 @@ namespace ModelContextProtocol.Shared; /// internal abstract class McpJsonRpcEndpoint : IAsyncDisposable { - private readonly ITransport _transport; - private readonly ConcurrentDictionary> _pendingRequests; - private readonly ConcurrentDictionary>> _notificationHandlers; - private readonly Dictionary>> _requestHandlers = []; - private int _nextRequestId; - private readonly JsonSerializerOptions _jsonOptions; - private bool _isDisposed; + private readonly RequestHandlers _requestHandlers = []; + private readonly NotificationHandlers _notificationHandlers = []; + + private CancellationTokenSource? _sessionCts; + private int _started; + private int _disposed; protected readonly ILogger _logger; + protected McpSession? _session; /// /// Initializes a new instance of the class. /// - /// An MCP transport implementation. /// The logger factory. - protected McpJsonRpcEndpoint(ITransport transport, ILoggerFactory? loggerFactory = null) + protected McpJsonRpcEndpoint(ILoggerFactory? loggerFactory = null) { - Throw.IfNull(transport); - - _transport = transport; - _pendingRequests = new(); - _notificationHandlers = new(); - _nextRequestId = 1; - _jsonOptions = McpJsonUtilities.DefaultOptions; _logger = loggerFactory?.CreateLogger(GetType()) ?? NullLogger.Instance; } - /// - /// Gets the name of the endpoint for logging and debug purposes. - /// - public abstract string EndpointName { get; } - - /// - /// Gets the transport implementation for the endpoint. Should generally not be needed outside of tests. - /// Sub-classes should store IClientTransport or IServerTransport injected during construction instead of casting this field. - /// - internal ITransport Transport => _transport; - - /// - /// Starts processing messages from the transport. This method will block until the transport is disconnected. - /// This is generally started in a background task or thread from the initialization logic of the derived class. - /// - internal async Task ProcessMessagesAsync(CancellationToken cancellationToken) - { - try - { - await foreach (var message in _transport.MessageReader.ReadAllAsync(cancellationToken).ConfigureAwait(false)) - { - _logger.TransportMessageRead(EndpointName, message.GetType().Name); - - // Fire and forget the message handling task to avoid blocking the transport - // If awaiting the task, the transport will not be able to read more messages, - // which could lead to a deadlock if the handler sends a message back - _ = ProcessMessageAsync(); - async Task ProcessMessageAsync() - { -#if NET - await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding); -#else - await default(ForceYielding); -#endif - try - { - await HandleMessageAsync(message, cancellationToken).ConfigureAwait(false); - } - catch (Exception ex) - { - var payload = JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo()); - _logger.MessageHandlerError(EndpointName, message.GetType().Name, payload, ex); - } - } - } - } - catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) - { - // Normal shutdown - _logger.EndpointMessageProcessingCancelled(EndpointName); - } - catch (NullReferenceException) - { - // Ignore reader disposal and mocked transport - } - } - - private async Task HandleMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken) - { - switch (message) - { - case JsonRpcRequest request: - await HandleRequest(request, cancellationToken).ConfigureAwait(false); - break; - - case IJsonRpcMessageWithId messageWithId: - HandleMessageWithId(message, messageWithId); - break; - - case JsonRpcNotification notification: - await HandleNotification(notification).ConfigureAwait(false); - break; - - default: - _logger.EndpointHandlerUnexpectedMessageType(EndpointName, message.GetType().Name); - break; - } - } + protected void SetRequestHandler(string method, Func> handler) + => _requestHandlers.Set(method, handler); - private async Task HandleNotification(JsonRpcNotification notification) - { - if (_notificationHandlers.TryGetValue(notification.Method, out var handlers)) - { - foreach (var notificationHandler in handlers) - { - try - { - await notificationHandler(notification).ConfigureAwait(false); - } - catch (Exception ex) - { - // Log handler error but continue processing - _logger.NotificationHandlerError(EndpointName, notification.Method, ex); - } - } - } - } + public void AddNotificationHandler(string method, Func handler) + => _notificationHandlers.Add(method, handler); - private void HandleMessageWithId(IJsonRpcMessage message, IJsonRpcMessageWithId messageWithId) - { - if (!messageWithId.Id.IsValid) - { - _logger.RequestHasInvalidId(EndpointName); - } - else if (_pendingRequests.TryRemove(messageWithId.Id, out var tcs)) - { - _logger.ResponseMatchedPendingRequest(EndpointName, messageWithId.Id.ToString()); - tcs.TrySetResult(message); - } - else - { - _logger.NoRequestFoundForMessageWithId(EndpointName, messageWithId.Id.ToString()); - } - } + public Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) where TResult : class + => GetSessionOrThrow().SendRequestAsync(request, cancellationToken); - private async Task HandleRequest(JsonRpcRequest request, CancellationToken cancellationToken) - { - if (_requestHandlers.TryGetValue(request.Method, out var handler)) - { - try - { - _logger.RequestHandlerCalled(EndpointName, request.Method); - var result = await handler(request, cancellationToken).ConfigureAwait(false); - _logger.RequestHandlerCompleted(EndpointName, request.Method); - await _transport.SendMessageAsync(new JsonRpcResponse - { - Id = request.Id, - JsonRpc = "2.0", - Result = result - }, cancellationToken).ConfigureAwait(false); - } - catch (Exception ex) - { - _logger.RequestHandlerError(EndpointName, request.Method, ex); - // Send error response - await _transport.SendMessageAsync(new JsonRpcError - { - Id = request.Id, - JsonRpc = "2.0", - Error = new JsonRpcErrorDetail - { - Code = -32000, // Implementation defined error - Message = ex.Message - } - }, cancellationToken).ConfigureAwait(false); - } - } - else - { - _logger.NoHandlerFoundForRequest(EndpointName, request.Method); - } - } + public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) + => GetSessionOrThrow().SendMessageAsync(message, cancellationToken); /// - /// Sends a generic JSON-RPC request to the server. - /// It is strongly recommended use the capability-specific methods instead of this one. - /// Use this method for custom requests or those not yet covered explicitly by the endpoint implementation. + /// Gets the name of the endpoint for logging and debug purposes. /// - /// The expected response type. - /// The JSON-RPC request to send. - /// A token to cancel the operation. - /// A task containing the server's response. - public async Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken) where TResult : class - { - if (!_transport.IsConnected) - { - _logger.EndpointNotConnected(EndpointName); - throw new McpClientException("Transport is not connected"); - } - - // Set request ID - request.Id = RequestId.FromNumber(Interlocked.Increment(ref _nextRequestId)); - - var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - _pendingRequests[request.Id] = tcs; - - try - { - // Expensive logging, use the logging framework to check if the logger is enabled - if (_logger.IsEnabled(LogLevel.Debug)) - { - _logger.SendingRequestPayload(EndpointName, JsonSerializer.Serialize(request, _jsonOptions.GetTypeInfo())); - } - - // Less expensive information logging - _logger.SendingRequest(EndpointName, request.Method); - - await _transport.SendMessageAsync(request, cancellationToken).ConfigureAwait(false); - - _logger.RequestSentAwaitingResponse(EndpointName, request.Method, request.Id.ToString()); - var response = await tcs.Task.WaitAsync(cancellationToken).ConfigureAwait(false); - - if (response is JsonRpcError error) - { - _logger.RequestFailed(EndpointName, request.Method, error.Error.Message, error.Error.Code); - throw new McpClientException($"Request failed (server side): {error.Error.Message}", error.Error.Code); - } - - if (response is JsonRpcResponse success) - { - // Convert the Result object to JSON and back to get our strongly-typed result - var resultJson = JsonSerializer.Serialize(success.Result, _jsonOptions.GetTypeInfo()); - var resultObject = JsonSerializer.Deserialize(resultJson, _jsonOptions.GetTypeInfo()); - - // Not expensive logging because we're already converting to JSON in order to get the result object - _logger.RequestResponseReceivedPayload(EndpointName, resultJson); - _logger.RequestResponseReceived(EndpointName, request.Method); - - if (resultObject != null) - { - return resultObject; - } - - // Result object was null, this is unexpected - _logger.RequestResponseTypeConversionError(EndpointName, request.Method, typeof(TResult)); - throw new McpClientException($"Unexpected response type {JsonSerializer.Serialize(success.Result, _jsonOptions.GetTypeInfo())}, expected {typeof(TResult)}"); - } - - // Unexpected response type - _logger.RequestInvalidResponseType(EndpointName, request.Method); - throw new McpClientException("Invalid response type"); - } - finally - { - _pendingRequests.TryRemove(request.Id, out _); - } - } - - public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) - { - Throw.IfNull(message); - - if (!_transport.IsConnected) - { - _logger.ClientNotConnected(EndpointName); - throw new McpClientException("Transport is not connected"); - } - - if (_logger.IsEnabled(LogLevel.Debug)) - { - _logger.SendingMessage(EndpointName, JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo())); - } - - return _transport.SendMessageAsync(message, cancellationToken); - } + public abstract string EndpointName { get; } /// - /// Registers a handler for incoming notifications of a specific method. - /// - /// Constants for common notification methods + /// Task that processes incoming messages from the transport. /// - /// The notification method to handle. - /// The async handler function to process notifications. - public void AddNotificationHandler(string method, Func handler) - { - var handlers = _notificationHandlers.GetOrAdd(method, _ => []); - lock (handlers) - { - handlers.Add(handler); - } - } + protected Task? MessageProcessingTask { get; set; } - /// - public async ValueTask DisposeAsync() + protected void InitializeSession(ITransport sessionTransport) { - await CleanupAsync().ConfigureAwait(false); + _session = new McpSession(sessionTransport, EndpointName, _requestHandlers, _notificationHandlers, _logger); } - /// - /// Registers a handler for incoming requests of a specific method. - /// - /// Type of request payload - /// Type of response payload (not full RPC response - /// Method identifier to register for - /// Handler to be called when a request with specified method identifier is received - protected void SetRequestHandler(string method, Func> handler) + [MemberNotNull(nameof(MessageProcessingTask))] + protected void StartSession(CancellationToken fullSessionCancellationToken = default) { - Throw.IfNull(method); - Throw.IfNull(handler); - - _requestHandlers[method] = async (request, cancellationToken) => + if (Interlocked.Exchange(ref _started, 1) != 0) { - // Convert the params JsonElement to our type using the same options - var jsonString = JsonSerializer.Serialize(request.Params, _jsonOptions.GetTypeInfo()); - var typedRequest = JsonSerializer.Deserialize(jsonString, _jsonOptions.GetTypeInfo()); + throw new InvalidOperationException("The MCP session has already stared."); + } - return await handler(typedRequest, cancellationToken).ConfigureAwait(false); - }; + _sessionCts = CancellationTokenSource.CreateLinkedTokenSource(fullSessionCancellationToken); + MessageProcessingTask = GetSessionOrThrow().ProcessMessagesAsync(_sessionCts.Token); } - /// - /// Task that processes incoming messages from the transport. - /// - protected Task? MessageProcessingTask { get; set; } - - /// - /// CancellationTokenSource used to cancel the message processing task. - /// - protected CancellationTokenSource? CancellationTokenSource { get; set; } - /// /// Cleans up the endpoint and releases resources. /// /// - protected virtual async Task CleanupAsync() + public virtual async ValueTask DisposeAsync() { - if (_isDisposed) + if (Interlocked.Exchange(ref _disposed, 1) != 0) + { + // TODO: It's more correct to await the last DisposeAsync before returning if it's still ongoing. return; - - _isDisposed = true; + } _logger.CleaningUpEndpoint(EndpointName); - if (CancellationTokenSource != null) + if (_sessionCts is not null) { - await CancellationTokenSource.CancelAsync().ConfigureAwait(false); + await _sessionCts.CancelAsync().ConfigureAwait(false); } - if (MessageProcessingTask != null) + if (MessageProcessingTask is not null) { try { @@ -379,16 +105,12 @@ protected virtual async Task CleanupAsync() } } - // Complete all pending requests with cancellation - foreach (var entry in _pendingRequests) - { - entry.Value.TrySetCanceled(); - } - _pendingRequests.Clear(); - - await _transport.DisposeAsync().ConfigureAwait(false); - CancellationTokenSource?.Dispose(); + _session?.Dispose(); + _sessionCts?.Dispose(); _logger.EndpointCleanedUp(EndpointName); } + + protected McpSession GetSessionOrThrow() => + _session ?? throw new InvalidOperationException($"This should be unreachable from public API! Call {nameof(InitializeSession)} before sending messages."); } diff --git a/src/ModelContextProtocol/Shared/McpSession.cs b/src/ModelContextProtocol/Shared/McpSession.cs new file mode 100644 index 00000000..1d9c8d0c --- /dev/null +++ b/src/ModelContextProtocol/Shared/McpSession.cs @@ -0,0 +1,299 @@ +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using ModelContextProtocol.Client; +using ModelContextProtocol.Logging; +using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Protocol.Transport; +using ModelContextProtocol.Utils; +using ModelContextProtocol.Utils.Json; +using System.Collections.Concurrent; +using System.Text.Json; + +namespace ModelContextProtocol.Shared; + +/// +/// Class for managing an MCP JSON-RPC session. This covers both MCP clients and servers. +/// +internal sealed class McpSession : IDisposable +{ + private readonly ITransport _transport; + private readonly RequestHandlers _requestHandlers; + private readonly NotificationHandlers _notificationHandlers = []; + + private readonly ConcurrentDictionary> _pendingRequests = []; + private readonly JsonSerializerOptions _jsonOptions; + private readonly ILogger _logger; + + private int _nextRequestId; + + /// + /// Initializes a new instance of the class. + /// + /// An MCP transport implementation. + /// The name of the endpoint for logging and debug purposes. + /// A collection of request handlers. + /// A collection of notification handlers. + /// The logger. + public McpSession( + ITransport transport, + string endpointName, + RequestHandlers requestHandlers, + NotificationHandlers notificationHandlers, + ILogger logger) + { + Throw.IfNull(transport); + + _transport = transport; + EndpointName = endpointName; + _requestHandlers = requestHandlers; + _notificationHandlers = notificationHandlers; + _jsonOptions = McpJsonUtilities.DefaultOptions; + _logger = logger ?? NullLogger.Instance; + } + + /// + /// Gets and sets the name of the endpoint for logging and debug purposes. + /// + public string EndpointName { get; set; } + + /// + /// Starts processing messages from the transport. This method will block until the transport is disconnected. + /// This is generally started in a background task or thread from the initialization logic of the derived class. + /// + public async Task ProcessMessagesAsync(CancellationToken cancellationToken) + { + try + { + await foreach (var message in _transport.MessageReader.ReadAllAsync(cancellationToken).ConfigureAwait(false)) + { + _logger.TransportMessageRead(EndpointName, message.GetType().Name); + + // Fire and forget the message handling task to avoid blocking the transport + // If awaiting the task, the transport will not be able to read more messages, + // which could lead to a deadlock if the handler sends a message back + _ = ProcessMessageAsync(); + async Task ProcessMessageAsync() + { +#if NET + await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding); +#else + await default(ForceYielding); +#endif + try + { + await HandleMessageAsync(message, cancellationToken).ConfigureAwait(false); + } + catch (Exception ex) + { + var payload = JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo()); + _logger.MessageHandlerError(EndpointName, message.GetType().Name, payload, ex); + } + } + } + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + // Normal shutdown + _logger.EndpointMessageProcessingCancelled(EndpointName); + } + } + + private async Task HandleMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken) + { + switch (message) + { + case JsonRpcRequest request: + await HandleRequest(request, cancellationToken).ConfigureAwait(false); + break; + + case IJsonRpcMessageWithId messageWithId: + HandleMessageWithId(message, messageWithId); + break; + + case JsonRpcNotification notification: + await HandleNotification(notification).ConfigureAwait(false); + break; + + default: + _logger.EndpointHandlerUnexpectedMessageType(EndpointName, message.GetType().Name); + break; + } + } + + private async Task HandleNotification(JsonRpcNotification notification) + { + if (_notificationHandlers.TryGetValue(notification.Method, out var handlers)) + { + foreach (var notificationHandler in handlers) + { + try + { + await notificationHandler(notification).ConfigureAwait(false); + } + catch (Exception ex) + { + // Log handler error but continue processing + _logger.NotificationHandlerError(EndpointName, notification.Method, ex); + } + } + } + } + + private void HandleMessageWithId(IJsonRpcMessage message, IJsonRpcMessageWithId messageWithId) + { + if (!messageWithId.Id.IsValid) + { + _logger.RequestHasInvalidId(EndpointName); + } + else if (_pendingRequests.TryRemove(messageWithId.Id, out var tcs)) + { + _logger.ResponseMatchedPendingRequest(EndpointName, messageWithId.Id.ToString()); + tcs.TrySetResult(message); + } + else + { + _logger.NoRequestFoundForMessageWithId(EndpointName, messageWithId.Id.ToString()); + } + } + + private async Task HandleRequest(JsonRpcRequest request, CancellationToken cancellationToken) + { + if (_requestHandlers.TryGetValue(request.Method, out var handler)) + { + try + { + _logger.RequestHandlerCalled(EndpointName, request.Method); + var result = await handler(request, cancellationToken).ConfigureAwait(false); + _logger.RequestHandlerCompleted(EndpointName, request.Method); + await _transport.SendMessageAsync(new JsonRpcResponse + { + Id = request.Id, + JsonRpc = "2.0", + Result = result + }, cancellationToken).ConfigureAwait(false); + } + catch (Exception ex) + { + _logger.RequestHandlerError(EndpointName, request.Method, ex); + // Send error response + await _transport.SendMessageAsync(new JsonRpcError + { + Id = request.Id, + JsonRpc = "2.0", + Error = new JsonRpcErrorDetail + { + Code = -32000, // Implementation defined error + Message = ex.Message + } + }, cancellationToken).ConfigureAwait(false); + } + } + else + { + _logger.NoHandlerFoundForRequest(EndpointName, request.Method); + } + } + + /// + /// Sends a generic JSON-RPC request to the server. + /// It is strongly recommended use the capability-specific methods instead of this one. + /// Use this method for custom requests or those not yet covered explicitly by the endpoint implementation. + /// + /// The expected response type. + /// The JSON-RPC request to send. + /// A token to cancel the operation. + /// A task containing the server's response. + public async Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken) where TResult : class + { + if (!_transport.IsConnected) + { + _logger.EndpointNotConnected(EndpointName); + throw new McpClientException("Transport is not connected"); + } + + // Set request ID + request.Id = RequestId.FromNumber(Interlocked.Increment(ref _nextRequestId)); + + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + _pendingRequests[request.Id] = tcs; + + try + { + // Expensive logging, use the logging framework to check if the logger is enabled + if (_logger.IsEnabled(LogLevel.Debug)) + { + _logger.SendingRequestPayload(EndpointName, JsonSerializer.Serialize(request, _jsonOptions.GetTypeInfo())); + } + + // Less expensive information logging + _logger.SendingRequest(EndpointName, request.Method); + + await _transport.SendMessageAsync(request, cancellationToken).ConfigureAwait(false); + + _logger.RequestSentAwaitingResponse(EndpointName, request.Method, request.Id.ToString()); + var response = await tcs.Task.WaitAsync(cancellationToken).ConfigureAwait(false); + + if (response is JsonRpcError error) + { + _logger.RequestFailed(EndpointName, request.Method, error.Error.Message, error.Error.Code); + throw new McpClientException($"Request failed (server side): {error.Error.Message}", error.Error.Code); + } + + if (response is JsonRpcResponse success) + { + // Convert the Result object to JSON and back to get our strongly-typed result + var resultJson = JsonSerializer.Serialize(success.Result, _jsonOptions.GetTypeInfo()); + var resultObject = JsonSerializer.Deserialize(resultJson, _jsonOptions.GetTypeInfo()); + + // Not expensive logging because we're already converting to JSON in order to get the result object + _logger.RequestResponseReceivedPayload(EndpointName, resultJson); + _logger.RequestResponseReceived(EndpointName, request.Method); + + if (resultObject != null) + { + return resultObject; + } + + // Result object was null, this is unexpected + _logger.RequestResponseTypeConversionError(EndpointName, request.Method, typeof(TResult)); + throw new McpClientException($"Unexpected response type {JsonSerializer.Serialize(success.Result, _jsonOptions.GetTypeInfo())}, expected {typeof(TResult)}"); + } + + // Unexpected response type + _logger.RequestInvalidResponseType(EndpointName, request.Method); + throw new McpClientException("Invalid response type"); + } + finally + { + _pendingRequests.TryRemove(request.Id, out _); + } + } + + public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) + { + Throw.IfNull(message); + + if (!_transport.IsConnected) + { + _logger.ClientNotConnected(EndpointName); + throw new McpClientException("Transport is not connected"); + } + + if (_logger.IsEnabled(LogLevel.Debug)) + { + _logger.SendingMessage(EndpointName, JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo())); + } + + return _transport.SendMessageAsync(message, cancellationToken); + } + + public void Dispose() + { + // Complete all pending requests with cancellation + foreach (var entry in _pendingRequests) + { + entry.Value.TrySetCanceled(); + } + _pendingRequests.Clear(); + } +} diff --git a/src/ModelContextProtocol/Shared/NotificationHandlers.cs b/src/ModelContextProtocol/Shared/NotificationHandlers.cs new file mode 100644 index 00000000..1fdb0578 --- /dev/null +++ b/src/ModelContextProtocol/Shared/NotificationHandlers.cs @@ -0,0 +1,16 @@ +using ModelContextProtocol.Protocol.Messages; +using System.Collections.Concurrent; + +namespace ModelContextProtocol.Shared; + +internal sealed class NotificationHandlers : ConcurrentDictionary>> +{ + public void Add(string method, Func handler) + { + var handlers = GetOrAdd(method, _ => []); + lock (handlers) + { + handlers.Add(handler); + } + } +} diff --git a/src/ModelContextProtocol/Shared/RequestHandlers.cs b/src/ModelContextProtocol/Shared/RequestHandlers.cs new file mode 100644 index 00000000..be1f80c9 --- /dev/null +++ b/src/ModelContextProtocol/Shared/RequestHandlers.cs @@ -0,0 +1,31 @@ +using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Utils; +using ModelContextProtocol.Utils.Json; +using System.Text.Json; + +namespace ModelContextProtocol.Shared; + +internal sealed class RequestHandlers : Dictionary>> +{ + /// + /// Registers a handler for incoming requests of a specific method. + /// + /// Type of request payload + /// Type of response payload (not full RPC response + /// Method identifier to register for + /// Handler to be called when a request with specified method identifier is received + public void Set(string method, Func> handler) + { + Throw.IfNull(method); + Throw.IfNull(handler); + + this[method] = async (request, cancellationToken) => + { + // Convert the params JsonElement to our type using the same options + var jsonString = JsonSerializer.Serialize(request.Params, McpJsonUtilities.DefaultOptions.GetTypeInfo()); + var typedRequest = JsonSerializer.Deserialize(jsonString, McpJsonUtilities.DefaultOptions.GetTypeInfo()); + + return await handler(typedRequest, cancellationToken).ConfigureAwait(false); + }; + } +} diff --git a/tests/ModelContextProtocol.TestServer/Program.cs b/tests/ModelContextProtocol.TestServer/Program.cs index 923a2b9d..6cb30a1e 100644 --- a/tests/ModelContextProtocol.TestServer/Program.cs +++ b/tests/ModelContextProtocol.TestServer/Program.cs @@ -48,52 +48,58 @@ private static async Task Main(string[] args) }; using var loggerFactory = CreateLoggerFactory(); - await using IMcpServer server = McpServerFactory.Create(new StdioServerTransport("TestServer", loggerFactory), options, loggerFactory); + await using var stdioTransport = new StdioServerTransport("TestServer", loggerFactory); + await using IMcpServer server = McpServerFactory.Create(stdioTransport, options, loggerFactory); Log.Logger.Information("Server running..."); - await server.RunAsync(async cancellationToken => - { - var loggingLevels = Enum.GetValues(); - // Run until process is stopped by the client (parent process) - while (true) + // Run until process is stopped by the client (parent process) + _ = RunBackgroundLoop(server); + + await server.RunAsync(); + } + + private static async Task RunBackgroundLoop(IMcpServer server, CancellationToken cancellationToken = default) + { + var loggingLevels = Enum.GetValues(); + + while (true) + { + await Task.Delay(1000, cancellationToken); + try { - await Task.Delay(1000, cancellationToken); - try + // Send random log messages every few seconds + if (_minimumLoggingLevel is not null) { - // Send random log messages every few seconds - if (_minimumLoggingLevel is not null) + var logLevel = loggingLevels[Random.Shared.Next(loggingLevels.Length)]; + await server.SendMessageAsync(new JsonRpcNotification() { - var logLevel = loggingLevels[Random.Shared.Next(loggingLevels.Length)]; - await server.SendMessageAsync(new JsonRpcNotification() + Method = NotificationMethods.LoggingMessageNotification, + Params = new LoggingMessageNotificationParams { - Method = NotificationMethods.LoggingMessageNotification, - Params = new LoggingMessageNotificationParams - { - Level = logLevel, - Data = JsonSerializer.Deserialize("\"Random log message\"") - } - }, cancellationToken); - } - - // Snapshot the subscribed resources, rather than locking while sending notifications - foreach (var resource in _subscribedResources) - { - ResourceUpdatedNotificationParams notificationParams = new() { Uri = resource.Key }; - await server.SendMessageAsync(new JsonRpcNotification() - { - Method = NotificationMethods.ResourceUpdatedNotification, - Params = notificationParams - }, cancellationToken); - } + Level = logLevel, + Data = JsonSerializer.Deserialize("\"Random log message\"") + } + }, cancellationToken); } - catch (Exception ex) + + // Snapshot the subscribed resources, rather than locking while sending notifications + foreach (var resource in _subscribedResources) { - Log.Logger.Error(ex, "Error sending log message"); - break; + ResourceUpdatedNotificationParams notificationParams = new() { Uri = resource.Key }; + await server.SendMessageAsync(new JsonRpcNotification() + { + Method = NotificationMethods.ResourceUpdatedNotification, + Params = notificationParams + }, cancellationToken); } } - }); + catch (Exception ex) + { + Log.Logger.Error(ex, "Error sending log message"); + break; + } + } } private static ToolsCapability ConfigureTools() diff --git a/tests/ModelContextProtocol.TestSseServer/Program.cs b/tests/ModelContextProtocol.TestSseServer/Program.cs index 98d5dfaf..df722a33 100644 --- a/tests/ModelContextProtocol.TestSseServer/Program.cs +++ b/tests/ModelContextProtocol.TestSseServer/Program.cs @@ -382,10 +382,15 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st }; loggerFactory ??= CreateLoggerFactory(); - await using IMcpServer server = McpServerFactory.Create(new HttpListenerSseServerTransport("TestServer", 3001, loggerFactory), options, loggerFactory); - + await using var httpListenerSseTransport = new HttpListenerSseServerTransport("TestServer", 3001, loggerFactory); Console.WriteLine("Server running..."); - await server.RunAsync(cancellationToken: cancellationToken); + + // Each IMcpServer represents a new SSE session. + while (true) + { + var server = await McpServerFactory.AcceptAsync(httpListenerSseTransport, options, loggerFactory, cancellationToken: cancellationToken); + _ = server.RunAsync(cancellationToken: cancellationToken); + } } const string MCP_TINY_IMAGE = diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs index ada15e41..591e3182 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs @@ -1,26 +1,31 @@ using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; using ModelContextProtocol.Client; using ModelContextProtocol.Configuration; using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Server; using ModelContextProtocol.Tests.Transport; +using ModelContextProtocol.Tests.Utils; using System.IO.Pipelines; namespace ModelContextProtocol.Tests.Client; -public class McpClientExtensionsTests +public class McpClientExtensionsTests : LoggedTest { private readonly Pipe _clientToServerPipe = new(); private readonly Pipe _serverToClientPipe = new(); - private readonly IMcpServer _server; + private readonly ServiceProvider _serviceProvider; private readonly CancellationTokenSource _cts; private readonly Task _serverTask; - public McpClientExtensionsTests() + public McpClientExtensionsTests(ITestOutputHelper outputHelper) + : base(outputHelper) { ServiceCollection sc = new(); - sc.AddSingleton(new StdioServerTransport("TestServer", _clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream())); - sc.AddMcpServer(); + sc.AddSingleton(LoggerFactory); + sc.AddMcpServer().WithStdioServerTransport(); + // Call WithStdioServerTransport to get the IMcpServer registration, then overwrite default transport with a pipe transport. + sc.AddSingleton(new StdioServerTransport("TestServer", _clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream())); for (int f = 0; f < 10; f++) { string name = $"Method{f}"; @@ -28,19 +33,24 @@ public McpClientExtensionsTests() } sc.AddSingleton(McpServerTool.Create([McpServerTool(Destructive = false, OpenWorld = true)](string i) => $"{i} Result", new() { Name = "ValuesSetViaAttr" })); sc.AddSingleton(McpServerTool.Create([McpServerTool(Destructive = false, OpenWorld = true)](string i) => $"{i} Result", new() { Name = "ValuesSetViaOptions", Destructive = true, OpenWorld = false, ReadOnly = true })); - _server = sc.BuildServiceProvider().GetRequiredService(); + _serviceProvider = sc.BuildServiceProvider(); + var server = _serviceProvider.GetRequiredService(); _cts = CancellationTokenSource.CreateLinkedTokenSource(TestContext.Current.CancellationToken); - _serverTask = _server.RunAsync(cancellationToken: _cts.Token); + _serverTask = server.RunAsync(cancellationToken: _cts.Token); } public async ValueTask DisposeAsync() { - _cts.Cancel(); + await _cts.CancelAsync(); + _clientToServerPipe.Writer.Complete(); _serverToClientPipe.Writer.Complete(); + await _serverTask; - await _server.DisposeAsync(); + + await _serviceProvider.DisposeAsync(); + _cts.Dispose(); } private async Task CreateMcpClientForServer() @@ -58,6 +68,7 @@ private async Task CreateMcpClientForServer() return await McpClientFactory.CreateAsync( serverConfig, createTransportFunc: (_, _) => new StreamClientTransport(serverStdinWriter, serverStdoutReader), + loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); } diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs index 0053e8da..f208d615 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs @@ -187,7 +187,7 @@ public async Task McpFactory_WithInvalidTransportOptions_ThrowsFormatException(s await Assert.ThrowsAsync(() => McpClientFactory.CreateAsync(config, _defaultOptions, cancellationToken: TestContext.Current.CancellationToken)); } - private sealed class NopTransport : IClientTransport + private sealed class NopTransport : ITransport, IClientTransport { private readonly Channel _channel = Channel.CreateUnbounded(); @@ -195,8 +195,7 @@ private sealed class NopTransport : IClientTransport public ChannelReader MessageReader => _channel.Reader; - public Task ConnectAsync(CancellationToken cancellationToken = default) => - Task.CompletedTask; + public Task ConnectAsync(CancellationToken cancellationToken = default) => Task.FromResult(this); public ValueTask DisposeAsync() => default; diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index 2d8181e8..edab8725 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -1,19 +1,19 @@ -using System.ComponentModel; -using System.Text.Json; -using ModelContextProtocol.Server; +using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; -using ModelContextProtocol.Protocol.Transport; -using System.IO.Pipelines; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; using ModelContextProtocol.Client; using ModelContextProtocol.Configuration; +using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Protocol.Transport; +using ModelContextProtocol.Server; using ModelContextProtocol.Tests.Transport; +using ModelContextProtocol.Tests.Utils; +using System.ComponentModel; +using System.IO.Pipelines; +using System.Text.Json; using System.Text.RegularExpressions; -using Microsoft.Extensions.AI; using System.Threading.Channels; -using ModelContextProtocol.Protocol.Messages; -using Microsoft.Extensions.Options; -using ModelContextProtocol.Tests.Utils; -using Microsoft.Extensions.Logging; namespace ModelContextProtocol.Tests.Configuration; @@ -23,7 +23,6 @@ public class McpServerBuilderExtensionsToolsTests : LoggedTest, IAsyncDisposable private readonly Pipe _serverToClientPipe = new(); private readonly ServiceProvider _serviceProvider; private readonly IMcpServerBuilder _builder; - private readonly IMcpServer _server; private readonly CancellationTokenSource _cts; private readonly Task _serverTask; @@ -32,25 +31,28 @@ public McpServerBuilderExtensionsToolsTests(ITestOutputHelper testOutputHelper) { ServiceCollection sc = new(); sc.AddSingleton(LoggerFactory); - sc.AddSingleton(new StdioServerTransport("TestServer", _clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream())); + _builder = sc.AddMcpServer().WithStdioServerTransport().WithTools(); + // Call WithStdioServerTransport to get the IMcpServer registration, then overwrite default transport with a pipe transport. + sc.AddSingleton(new StdioServerTransport("TestServer", _clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream(), LoggerFactory)); sc.AddSingleton(new ObjectWithId()); - _builder = sc.AddMcpServer().WithTools(); _serviceProvider = sc.BuildServiceProvider(); - _server = _serviceProvider.GetRequiredService(); + var server = _serviceProvider.GetRequiredService(); _cts = CancellationTokenSource.CreateLinkedTokenSource(TestContext.Current.CancellationToken); - _serverTask = _server.RunAsync(cancellationToken: _cts.Token); + _serverTask = server.RunAsync(cancellationToken: _cts.Token); } public async ValueTask DisposeAsync() { await _cts.CancelAsync(); - _cts.Dispose(); _clientToServerPipe.Writer.Complete(); _serverToClientPipe.Writer.Complete(); - + + await _serverTask; + await _serviceProvider.DisposeAsync(); + _cts.Dispose(); } private async Task CreateMcpClientForServer() @@ -68,13 +70,15 @@ private async Task CreateMcpClientForServer() return await McpClientFactory.CreateAsync( serverConfig, createTransportFunc: (_, _) => new StreamClientTransport(serverStdinWriter, serverStdoutReader), + loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); } [Fact] public void Adds_Tools_To_Server() { - var tools = _server.ServerOptions?.Capabilities?.Tools?.ToolCollection; + var serverOptions = _serviceProvider.GetRequiredService>().Value; + var tools = serverOptions.Capabilities?.Tools?.ToolCollection; Assert.NotNull(tools); Assert.NotEmpty(tools); } @@ -112,27 +116,27 @@ public async Task Can_Create_Multiple_Servers_From_Options_And_List_Registered_T var stdinPipe = new Pipe(); var stdoutPipe = new Pipe(); - var transport = new StdioServerTransport($"TestServer_{i}", stdinPipe.Reader.AsStream(), stdoutPipe.Writer.AsStream()); - var server = McpServerFactory.Create(transport, options, loggerFactory, _serviceProvider); + await using var transport = new StdioServerTransport($"TestServer_{i}", stdinPipe.Reader.AsStream(), stdoutPipe.Writer.AsStream()); + await using var server = McpServerFactory.Create(transport, options, loggerFactory, _serviceProvider); + var serverRunTask = server.RunAsync(TestContext.Current.CancellationToken); - await server.RunAsync(async cancellationToken => + using var serverStdinWriter = new StreamWriter(stdinPipe.Writer.AsStream()); + using var serverStdoutReader = new StreamReader(stdoutPipe.Reader.AsStream()); + + var serverConfig = new McpServerConfig() + { + Id = $"TestServer_{i}", + Name = $"TestServer_{i}", + TransportType = "ignored", + }; + + await using (var client = await McpClientFactory.CreateAsync( + serverConfig, + createTransportFunc: (_, _) => new StreamClientTransport(serverStdinWriter, serverStdoutReader), + loggerFactory: LoggerFactory, + cancellationToken: TestContext.Current.CancellationToken)) { - using var serverStdinWriter = new StreamWriter(stdinPipe.Writer.AsStream()); - using var serverStdoutReader = new StreamReader(stdoutPipe.Reader.AsStream()); - - var serverConfig = new McpServerConfig() - { - Id = $"TestServer_{i}", - Name = $"TestServer_{i}", - TransportType = "ignored", - }; - - var client = await McpClientFactory.CreateAsync( - serverConfig, - createTransportFunc: (_, _) => new StreamClientTransport(serverStdinWriter, serverStdoutReader), - cancellationToken: cancellationToken); - - var tools = await client.ListToolsAsync(cancellationToken); + var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); Assert.Equal(11, tools.Count); McpClientTool echoTool = tools.First(t => t.Name == "Echo"); @@ -146,7 +150,11 @@ await server.RunAsync(async cancellationToken => McpClientTool doubleEchoTool = tools.First(t => t.Name == "double_echo"); Assert.Equal("double_echo", doubleEchoTool.Name); Assert.Equal("Echoes the input back to the client.", doubleEchoTool.Description); - }, TestContext.Current.CancellationToken); + } + + stdinPipe.Writer.Complete(); + await serverRunTask; + stdoutPipe.Writer.Complete(); } } @@ -168,7 +176,8 @@ public async Task Can_Be_Notified_Of_Tool_Changes() var notificationRead = listChanged.Reader.ReadAsync(TestContext.Current.CancellationToken); Assert.False(notificationRead.IsCompleted); - var serverTools = _server.ServerOptions.Capabilities?.Tools?.ToolCollection; + var serverOptions = _serviceProvider.GetRequiredService>().Value; + var serverTools = serverOptions.Capabilities?.Tools?.ToolCollection; Assert.NotNull(serverTools); var newTool = McpServerTool.Create([McpServerTool(Name = "NewTool")] () => "42"); diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsTransportsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsTransportsTests.cs index 853ea933..3bd09c0e 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsTransportsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsTransportsTests.cs @@ -16,7 +16,7 @@ public void WithStdioServerTransport_Sets_Transport() builder.Object.WithStdioServerTransport(); - var transportType = services.FirstOrDefault(s => s.ServiceType == typeof(IServerTransport)); + var transportType = services.FirstOrDefault(s => s.ServiceType == typeof(ITransport)); Assert.NotNull(transportType); Assert.Equal(typeof(StdioServerTransport), transportType.ImplementationType); } diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerFactoryTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerFactoryTests.cs index 8ae723ca..8eae35af 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerFactoryTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerFactoryTests.cs @@ -8,13 +8,11 @@ namespace ModelContextProtocol.Tests.Server; public class McpServerFactoryTests : LoggedTest { - private readonly Mock _serverTransport; private readonly McpServerOptions _options; public McpServerFactoryTests(ITestOutputHelper testOutputHelper) : base(testOutputHelper) { - _serverTransport = new Mock(); _options = new McpServerOptions { ServerInfo = new Implementation { Name = "TestServer", Version = "1.0" }, @@ -27,23 +25,29 @@ public McpServerFactoryTests(ITestOutputHelper testOutputHelper) public async Task Create_Should_Initialize_With_Valid_Parameters() { // Arrange & Act - await using IMcpServer server = McpServerFactory.Create(_serverTransport.Object, _options, LoggerFactory); + await using IMcpServer server = McpServerFactory.Create(Mock.Of(), _options, LoggerFactory); // Assert Assert.NotNull(server); } [Fact] - public void Constructor_Throws_For_Null_ServerTransport() + public async Task Create_Throws_For_Null_ServerTransport() { // Arrange, Act & Assert - Assert.Throws("serverTransport", () => McpServerFactory.Create(null!, _options, LoggerFactory)); + Assert.Throws("transport", () => McpServerFactory.Create(null!, _options, LoggerFactory)); + + await Assert.ThrowsAsync("serverTransport", () => + McpServerFactory.AcceptAsync(null!, _options, LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); } [Fact] - public void Constructor_Throws_For_Null_Options() + public async Task Create_Throws_For_Null_Options() { // Arrange, Act & Assert - Assert.Throws("serverOptions", () => McpServerFactory.Create(_serverTransport.Object, null!, LoggerFactory)); + Assert.Throws("serverOptions", () => McpServerFactory.Create(Mock.Of(), null!, LoggerFactory)); + + await Assert.ThrowsAsync("serverOptions", () => + McpServerFactory.AcceptAsync(Mock.Of(), null!, LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); } } diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs index 469d0afc..b29e0bd0 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs @@ -1,6 +1,5 @@ using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Logging; using ModelContextProtocol.Client; using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; @@ -14,16 +13,14 @@ namespace ModelContextProtocol.Tests.Server; public class McpServerTests : LoggedTest { - private readonly Mock _serverTransport; - private readonly Mock _logger; + private readonly Mock _serverTransport; private readonly McpServerOptions _options; private readonly IServiceProvider _serviceProvider; public McpServerTests(ITestOutputHelper testOutputHelper) : base(testOutputHelper) { - _serverTransport = new Mock(); - _logger = new Mock(); + _serverTransport = new Mock(); _options = CreateOptions(); _serviceProvider = new ServiceCollection().BuildServiceProvider(); } @@ -88,42 +85,19 @@ public async Task RunAsync_Should_Throw_InvalidOperationException_If_Already_Run { // Arrange await using var server = McpServerFactory.Create(_serverTransport.Object, _options, LoggerFactory, _serviceProvider); - var task = server.RunAsync(cancellationToken: TestContext.Current.CancellationToken); + var runTask = server.RunAsync(TestContext.Current.CancellationToken); // Act & Assert - await Assert.ThrowsAsync(() => server.RunAsync(cancellationToken: TestContext.Current.CancellationToken)); + await Assert.ThrowsAsync(() => server.RunAsync(TestContext.Current.CancellationToken)); - await task; - } - - [Fact] - public async Task RunAsync_ShouldStartListening() - { - // Arrange - await using var server = McpServerFactory.Create(_serverTransport.Object, _options, LoggerFactory, _serviceProvider); - - // Act - await server.RunAsync(cancellationToken: TestContext.Current.CancellationToken); - - // Assert - _serverTransport.Verify(t => t.StartListeningAsync(It.IsAny()), Times.Once); - } - - [Fact] - public async Task RunAsync_Sets_Initialized_After_Transport_Responses_Initialized_Notification() - { - await using var transport = new TestServerTransport(); - await using var server = McpServerFactory.Create(transport, _options, LoggerFactory, _serviceProvider); - - await server.RunAsync(cancellationToken: TestContext.Current.CancellationToken); - - // Send initialized notification - await transport.SendMessageAsync(new JsonRpcNotification - { - Method = "notifications/initialized" - }, TestContext.Current.CancellationToken); - - await Task.Delay(50, TestContext.Current.CancellationToken); + try + { + await runTask; + } + catch (NullReferenceException) + { + // _serverTransport.Object returns a null MessageReader + } } [Fact] @@ -147,7 +121,7 @@ public async Task RequestSamplingAsync_Should_SendRequest() await using var server = McpServerFactory.Create(transport, _options, LoggerFactory, _serviceProvider); SetClientCapabilities(server, new ClientCapabilities { Sampling = new SamplingCapability() }); - await server.RunAsync(cancellationToken: TestContext.Current.CancellationToken); + var runTask = server.RunAsync(TestContext.Current.CancellationToken); // Act var result = await server.RequestSamplingAsync(new CreateMessageRequestParams { Messages = [] }, CancellationToken.None); @@ -156,6 +130,9 @@ public async Task RequestSamplingAsync_Should_SendRequest() Assert.NotEmpty(transport.SentMessages); Assert.IsType(transport.SentMessages[0]); Assert.Equal("sampling/createMessage", ((JsonRpcRequest)transport.SentMessages[0]).Method); + + await transport.DisposeAsync(); + await runTask; } [Fact] @@ -176,7 +153,7 @@ public async Task RequestRootsAsync_Should_SendRequest() await using var transport = new TestServerTransport(); await using var server = McpServerFactory.Create(transport, _options, LoggerFactory, _serviceProvider); SetClientCapabilities(server, new ClientCapabilities { Roots = new RootsCapability() }); - await server.RunAsync(cancellationToken: TestContext.Current.CancellationToken); + var runTask = server.RunAsync(TestContext.Current.CancellationToken); // Act var result = await server.RequestRootsAsync(new ListRootsRequestParams(), CancellationToken.None); @@ -186,6 +163,9 @@ public async Task RequestRootsAsync_Should_SendRequest() Assert.NotEmpty(transport.SentMessages); Assert.IsType(transport.SentMessages[0]); Assert.Equal("roots/list", ((JsonRpcRequest)transport.SentMessages[0]).Method); + + await transport.DisposeAsync(); + await runTask; } [Fact] @@ -540,7 +520,7 @@ private async Task Can_Handle_Requests(ServerCapabilities? serverCapabilities, s await using var server = McpServerFactory.Create(transport, options, LoggerFactory, _serviceProvider); - await server.RunAsync(); + var runTask = server.RunAsync(TestContext.Current.CancellationToken); var receivedMessage = new TaskCompletionSource(); @@ -563,6 +543,9 @@ await transport.SendMessageAsync( Assert.NotNull(response.Result); assertResult(response.Result); + + await transport.DisposeAsync(); + await runTask; } private async Task Throws_Exception_If_No_Handler_Assigned(ServerCapabilities serverCapabilities, string method, string expectedError) @@ -660,7 +643,7 @@ public void AddNotificationHandler(string method, Func throw new NotImplementedException(); - public Task RunAsync(Func? onInitialized = null, CancellationToken cancellationToken = default) => + public Task RunAsync(CancellationToken cancellationToken = default) => throw new NotImplementedException(); } } diff --git a/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs b/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs index a2baf5d8..a21e3d61 100644 --- a/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs @@ -9,20 +9,15 @@ namespace ModelContextProtocol.Tests; -public class SseIntegrationTests +public class SseIntegrationTests(ITestOutputHelper outputHelper) : LoggedTest(outputHelper) { [Fact] public async Task ConnectAndReceiveMessage_InMemoryServer() { // Arrange - using var loggerFactory = Microsoft.Extensions.Logging.LoggerFactory.Create(builder => - builder.AddConsole() - .SetMinimumLevel(LogLevel.Debug)); - - await using InMemoryTestSseServer server = new(logger: loggerFactory.CreateLogger()); + await using InMemoryTestSseServer server = new(logger: LoggerFactory.CreateLogger()); await server.StartAsync(); - var defaultOptions = new McpClientOptions { ClientInfo = new() { Name = "IntegrationTestClient", Version = "1.0.0" } @@ -41,7 +36,7 @@ public async Task ConnectAndReceiveMessage_InMemoryServer() await using var client = await McpClientFactory.CreateAsync( defaultConfig, defaultOptions, - loggerFactory: loggerFactory, + loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); // Wait for SSE connection to be established @@ -60,10 +55,6 @@ public async Task ConnectAndReceiveMessage_EverythingServerWithSse() { Assert.SkipWhen(!EverythingSseServerFixture.IsDockerAvailable, "docker is not available"); - using var loggerFactory = LoggerFactory.Create(builder => - builder.AddConsole() - .SetMinimumLevel(LogLevel.Debug)); - int port = 3001; await using var fixture = new EverythingSseServerFixture(port); @@ -87,7 +78,7 @@ public async Task ConnectAndReceiveMessage_EverythingServerWithSse() await using var client = await McpClientFactory.CreateAsync( defaultConfig, defaultOptions, - loggerFactory: loggerFactory, + loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); @@ -101,11 +92,6 @@ public async Task Sampling_Sse_EverythingServer() { Assert.SkipWhen(!EverythingSseServerFixture.IsDockerAvailable, "docker is not available"); - // arrange - using var loggerFactory = Microsoft.Extensions.Logging.LoggerFactory.Create(builder => - builder.AddConsole() - .SetMinimumLevel(LogLevel.Debug)); - int port = 3002; await using var fixture = new EverythingSseServerFixture(port); @@ -153,7 +139,7 @@ public async Task Sampling_Sse_EverythingServer() await using var client = await McpClientFactory.CreateAsync( defaultConfig, defaultOptions, - loggerFactory: loggerFactory, + loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); // Call the server's sampleLLM tool which should trigger our sampling handler @@ -161,8 +147,7 @@ public async Task Sampling_Sse_EverythingServer() { ["prompt"] = "Test prompt", ["maxTokens"] = 100 - } -, TestContext.Current.CancellationToken); + }, TestContext.Current.CancellationToken); // assert Assert.NotNull(result); @@ -175,11 +160,7 @@ public async Task Sampling_Sse_EverythingServer() public async Task ConnectAndReceiveMessage_InMemoryServer_WithFullEndpointEventUri() { // Arrange - using var loggerFactory = Microsoft.Extensions.Logging.LoggerFactory.Create(builder => - builder.AddConsole() - .SetMinimumLevel(LogLevel.Debug)); - - await using InMemoryTestSseServer server = new(logger: loggerFactory.CreateLogger()); + await using InMemoryTestSseServer server = new(logger: LoggerFactory.CreateLogger()); server.UseFullUrlForEndpointEvent = true; await server.StartAsync(); @@ -202,7 +183,7 @@ public async Task ConnectAndReceiveMessage_InMemoryServer_WithFullEndpointEventU await using var client = await McpClientFactory.CreateAsync( defaultConfig, defaultOptions, - loggerFactory: loggerFactory, + loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); // Wait for SSE connection to be established @@ -219,11 +200,7 @@ public async Task ConnectAndReceiveMessage_InMemoryServer_WithFullEndpointEventU public async Task ConnectAndReceiveNotification_InMemoryServer() { // Arrange - using var loggerFactory = Microsoft.Extensions.Logging.LoggerFactory.Create(builder => - builder.AddConsole() - .SetMinimumLevel(LogLevel.Debug)); - - await using InMemoryTestSseServer server = new(logger: loggerFactory.CreateLogger()); + await using InMemoryTestSseServer server = new(logger: LoggerFactory.CreateLogger()); await server.StartAsync(); @@ -245,7 +222,7 @@ public async Task ConnectAndReceiveNotification_InMemoryServer() await using var client = await McpClientFactory.CreateAsync( defaultConfig, defaultOptions, - loggerFactory: loggerFactory, + loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); // Wait for SSE connection to be established @@ -267,48 +244,4 @@ public async Task ConnectAndReceiveNotification_InMemoryServer() var message = await receivedNotification.Task.WaitAsync(TimeSpan.FromSeconds(10), TestContext.Current.CancellationToken); Assert.Equal("Hello from server!", message); } - - [Fact] - public async Task ConnectTwice_Throws() - { - // Arrange - using var loggerFactory = Microsoft.Extensions.Logging.LoggerFactory.Create(builder => - builder.AddConsole() - .SetMinimumLevel(LogLevel.Debug)); - - await using InMemoryTestSseServer server = new(logger: loggerFactory.CreateLogger()); - await server.StartAsync(); - - - var defaultOptions = new McpClientOptions - { - ClientInfo = new() { Name = "IntegrationTestClient", Version = "1.0.0" } - }; - - var defaultConfig = new McpServerConfig - { - Id = "test_server", - Name = "In-memory Test Server", - TransportType = TransportTypes.Sse, - TransportOptions = [], - Location = "http://localhost:5000/sse" - }; - - // Act - await using var client = await McpClientFactory.CreateAsync( - defaultConfig, - defaultOptions, - loggerFactory: loggerFactory, - cancellationToken: TestContext.Current.CancellationToken); - - PropertyInfo? transportProperty = client.GetType().GetProperty("Transport", BindingFlags.NonPublic | BindingFlags.Instance); - Assert.NotNull(transportProperty); - var transport = (SseClientTransport)transportProperty.GetValue(client)!; - - // Wait for SSE connection to be established - await server.WaitForConnectionAsync(TimeSpan.FromSeconds(10)); - - // Assert - await Assert.ThrowsAsync(async () => await transport.ConnectAsync(TestContext.Current.CancellationToken)); - } } diff --git a/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs b/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs index b834f843..976997af 100644 --- a/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs @@ -20,7 +20,8 @@ private Task GetClientAsync(McpClientOptions? options = null) return McpClientFactory.CreateAsync( _fixture.DefaultConfig, options ?? SseServerIntegrationTestFixture.CreateDefaultClientOptions(), - loggerFactory: LoggerFactory); + loggerFactory: LoggerFactory, + cancellationToken: TestContext.Current.CancellationToken); } [Fact] @@ -30,7 +31,7 @@ public async Task ConnectAndPing_Sse_TestServer() // Act var client = await GetClientAsync(); - await client.PingAsync(CancellationToken.None); + await client.PingAsync(TestContext.Current.CancellationToken); // Assert Assert.NotNull(client); @@ -75,7 +76,7 @@ public async Task CallTool_Sse_EchoServer() { ["message"] = "Hello MCP!" }, - CancellationToken.None + TestContext.Current.CancellationToken ); // assert @@ -109,7 +110,7 @@ public async Task ReadResource_Sse_TextResource() // Odd numbered resources are text in the everything server (despite the docs saying otherwise) // 1 is index 0, which is "even" in the 0-based index // We copied this oddity to the test server - var result = await client.ReadResourceAsync("test://static/resource/1", CancellationToken.None); + var result = await client.ReadResourceAsync("test://static/resource/1", TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Single(result.Contents); @@ -126,7 +127,7 @@ public async Task ReadResource_Sse_BinaryResource() // Even numbered resources are binary in the everything server (despite the docs saying otherwise) // 2 is index 1, which is "odd" in the 0-based index // We copied this oddity to the test server - var result = await client.ReadResourceAsync("test://static/resource/2", CancellationToken.None); + var result = await client.ReadResourceAsync("test://static/resource/2", TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Single(result.Contents); @@ -157,7 +158,7 @@ public async Task GetPrompt_Sse_SimplePrompt() // act var client = await GetClientAsync(); - var result = await client.GetPromptAsync("simple_prompt", null, CancellationToken.None); + var result = await client.GetPromptAsync("simple_prompt", null, TestContext.Current.CancellationToken); // assert Assert.NotNull(result); @@ -176,7 +177,7 @@ public async Task GetPrompt_Sse_ComplexPrompt() { "temperature", "0.7" }, { "style", "formal" } }; - var result = await client.GetPromptAsync("complex_prompt", arguments, CancellationToken.None); + var result = await client.GetPromptAsync("complex_prompt", arguments, TestContext.Current.CancellationToken); // assert Assert.NotNull(result); @@ -191,7 +192,7 @@ public async Task GetPrompt_Sse_NonExistent_ThrowsException() // act var client = await GetClientAsync(); await Assert.ThrowsAsync(() => - client.GetPromptAsync("non_existent_prompt", null, CancellationToken.None)); + client.GetPromptAsync("non_existent_prompt", null, TestContext.Current.CancellationToken)); } [Fact] diff --git a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs index 84f848fa..f26d31a4 100644 --- a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs @@ -2,8 +2,8 @@ using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Tests.Utils; +using System.IO.Pipelines; using System.Net; -using System.Reflection; namespace ModelContextProtocol.Tests.Transport; @@ -35,26 +35,6 @@ public SseClientTransportTests(ITestOutputHelper testOutputHelper) }; } - [Fact] - public async Task Constructor_Should_Initialize_With_Valid_Parameters() - { - // Act - await using var transport = new SseClientTransport(_transportOptions, _serverConfig, LoggerFactory); - - // Assert - Assert.NotNull(transport); - - PropertyInfo? getOptions = transport.GetType().GetProperty("Options", BindingFlags.NonPublic | BindingFlags.Instance); - Assert.NotNull(getOptions); - var options = (SseClientTransportOptions)getOptions.GetValue(transport)!; - - Assert.Equal(TimeSpan.FromSeconds(2), options.ConnectionTimeout); - Assert.Equal(3, options.MaxReconnectAttempts); - Assert.Equal(TimeSpan.FromMilliseconds(50), options.ReconnectDelay); - Assert.NotNull(options.AdditionalHeaders); - Assert.Equal("header", options.AdditionalHeaders["test"]); - } - [Fact] public void Constructor_Throws_For_Null_Options() { @@ -81,75 +61,23 @@ public async Task ConnectAsync_Should_Connect_Successfully() { using var mockHttpHandler = new MockHttpHandler(); using var httpClient = new HttpClient(mockHttpHandler); - await using var transport = new SseClientTransport(_transportOptions, _serverConfig, httpClient, LoggerFactory); + var transport = new SseClientTransport(_transportOptions, _serverConfig, httpClient, LoggerFactory); bool firstCall = true; - mockHttpHandler.RequestHandler = async (request) => + mockHttpHandler.RequestHandler = (request) => { - if (!firstCall) - { - Assert.True(transport.IsConnected); - await transport.DisposeAsync(); - } - firstCall = false; - return new HttpResponseMessage + return Task.FromResult(new HttpResponseMessage { StatusCode = HttpStatusCode.OK, Content = new StringContent("event: endpoint\r\ndata: http://localhost\r\n\r\n") - }; - }; - - await transport.ConnectAsync(TestContext.Current.CancellationToken); - } - - [Fact] - public async Task ConnectAsync_Throws_If_Already_Connected() - { - using var mockHttpHandler = new MockHttpHandler(); - using var httpClient = new HttpClient(mockHttpHandler); - await using var transport = new SseClientTransport(_transportOptions, _serverConfig, httpClient, LoggerFactory); - var tcsConnected = new TaskCompletionSource(); - var tcsDone = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - var callIndex = 0; - - mockHttpHandler.RequestHandler = async (request) => - { - switch (callIndex++) - { - case 0: - return new HttpResponseMessage - { - StatusCode = HttpStatusCode.OK, - Content = new StringContent("event: endpoint\r\ndata: http://localhost\r\n\r\n") - }; - case 1: - tcsConnected.SetResult(); - await tcsDone.Task; - return new HttpResponseMessage - { - StatusCode = HttpStatusCode.OK, - Content = new StringContent("") - }; - default: - return new HttpResponseMessage - { - StatusCode = HttpStatusCode.OK, - Content = new StringContent("") - }; - } + }); }; - var task = transport.ConnectAsync(TestContext.Current.CancellationToken); - await tcsConnected.Task; - Assert.True(transport.IsConnected); - var action = async () => await transport.ConnectAsync(); - var exception = await Assert.ThrowsAsync(action); - Assert.Equal("Transport is already connected", exception.Message); - tcsDone.SetResult(); - await transport.DisposeAsync(); - await task; + await using var session = await transport.ConnectAsync(TestContext.Current.CancellationToken); + Assert.NotNull(session); + Assert.False(firstCall); } [Fact] @@ -157,7 +85,7 @@ public async Task ConnectAsync_Throws_Exception_On_Failure() { using var mockHttpHandler = new MockHttpHandler(); using var httpClient = new HttpClient(mockHttpHandler); - await using var transport = new SseClientTransport(_transportOptions, _serverConfig, httpClient, LoggerFactory); + var transport = new SseClientTransport(_transportOptions, _serverConfig, httpClient, LoggerFactory); var retries = 0; mockHttpHandler.RequestHandler = (request) => @@ -174,21 +102,12 @@ public async Task ConnectAsync_Throws_Exception_On_Failure() Assert.Equal(_transportOptions.MaxReconnectAttempts, retries); } - [Fact] - public async Task SendMessageAsync_Throws_Exception_If_MessageEndpoint_Not_Set() - { - await using var transport = new SseClientTransport(_transportOptions, _serverConfig, LoggerFactory); - - // Assert - await Assert.ThrowsAsync(() => transport.SendMessageAsync(new JsonRpcRequest() { Method = "test" }, CancellationToken.None)); - } - [Fact] public async Task SendMessageAsync_Handles_Accepted_Response() { using var mockHttpHandler = new MockHttpHandler(); using var httpClient = new HttpClient(mockHttpHandler); - await using var transport = new SseClientTransport(_transportOptions, _serverConfig, httpClient, LoggerFactory); + var transport = new SseClientTransport(_transportOptions, _serverConfig, httpClient, LoggerFactory); var firstCall = true; mockHttpHandler.RequestHandler = (request) => @@ -216,9 +135,8 @@ public async Task SendMessageAsync_Handles_Accepted_Response() } }; - await transport.ConnectAsync(TestContext.Current.CancellationToken); - - await transport.SendMessageAsync(new JsonRpcRequest() { Method = "initialize", Id = RequestId.FromNumber(44) }, CancellationToken.None); + await using var session = await transport.ConnectAsync(TestContext.Current.CancellationToken); + await session.SendMessageAsync(new JsonRpcRequest() { Method = "initialize", Id = RequestId.FromNumber(44) }, CancellationToken.None); Assert.True(true); } @@ -227,7 +145,13 @@ public async Task SendMessageAsync_Handles_Accepted_Json_RPC_Response() { using var mockHttpHandler = new MockHttpHandler(); using var httpClient = new HttpClient(mockHttpHandler); - await using var transport = new SseClientTransport(_transportOptions, _serverConfig, httpClient, LoggerFactory); + var transport = new SseClientTransport(_transportOptions, _serverConfig, httpClient, LoggerFactory); + + var eventSourcePipe = new Pipe(); + var eventSourceData = "event: endpoint\r\ndata: /sseendpoint\r\n\r\n"u8; + Assert.True(eventSourceData.TryCopyTo(eventSourcePipe.Writer.GetSpan())); + eventSourcePipe.Writer.Advance(eventSourceData.Length); + await eventSourcePipe.Writer.FlushAsync(TestContext.Current.CancellationToken); var firstCall = true; mockHttpHandler.RequestHandler = (request) => @@ -250,15 +174,16 @@ public async Task SendMessageAsync_Handles_Accepted_Json_RPC_Response() return Task.FromResult(new HttpResponseMessage { StatusCode = HttpStatusCode.OK, - Content = new StringContent("event: endpoint\r\ndata: /sseendpoint\r\n\r\n") + Content = new StreamContent(eventSourcePipe.Reader.AsStream()), }); } }; - await transport.ConnectAsync(TestContext.Current.CancellationToken); + await using var session = await transport.ConnectAsync(TestContext.Current.CancellationToken); - await transport.SendMessageAsync(new JsonRpcRequest() { Method = "initialize", Id = RequestId.FromNumber(44) }, CancellationToken.None); + await session.SendMessageAsync(new JsonRpcRequest() { Method = "initialize", Id = RequestId.FromNumber(44) }, CancellationToken.None); Assert.True(true); + eventSourcePipe.Writer.Complete(); } [Fact] @@ -266,7 +191,7 @@ public async Task ReceiveMessagesAsync_Handles_Messages() { using var mockHttpHandler = new MockHttpHandler(); using var httpClient = new HttpClient(mockHttpHandler); - await using var transport = new SseClientTransport(_transportOptions, _serverConfig, httpClient, LoggerFactory); + var transport = new SseClientTransport(_transportOptions, _serverConfig, httpClient, LoggerFactory); var callIndex = 0; mockHttpHandler.RequestHandler = (request) => @@ -278,45 +203,39 @@ public async Task ReceiveMessagesAsync_Handles_Messages() return Task.FromResult(new HttpResponseMessage { StatusCode = HttpStatusCode.OK, - Content = new StringContent("event: endpoint\r\ndata: /sseendpoint\r\n\r\n") - }); - } - else if (callIndex == 2) - { - return Task.FromResult(new HttpResponseMessage - { - StatusCode = HttpStatusCode.OK, - Content = new StringContent("event: message\r\ndata: {\"jsonrpc\":\"2.0\", \"id\": \"44\", \"method\": \"test\", \"params\": null}\r\n\r\n") + Content = new StringContent("event: endpoint\r\ndata: /sseendpoint\r\n\r\nevent: message\r\ndata: {\"jsonrpc\":\"2.0\", \"id\": \"44\", \"method\": \"test\", \"params\": null}\r\n\r\n") }); } throw new IOException("Abort"); }; - await transport.ConnectAsync(TestContext.Current.CancellationToken); - Assert.True(transport.MessageReader.TryRead(out var message)); + await using var session = await transport.ConnectAsync(TestContext.Current.CancellationToken); + Assert.True(session.MessageReader.TryRead(out var message)); Assert.NotNull(message); Assert.IsType(message); Assert.Equal("44", ((JsonRpcRequest)message).Id.AsString); } - [Fact] - public async Task CloseAsync_Should_Dispose_Resources() - { - await using var transport = new SseClientTransport(_transportOptions, _serverConfig, LoggerFactory); - - await transport.DisposeAsync(); - - Assert.False(transport.IsConnected); - } - [Fact] public async Task DisposeAsync_Should_Dispose_Resources() { - await using var transport = new SseClientTransport(_transportOptions, _serverConfig, LoggerFactory); + using var mockHttpHandler = new MockHttpHandler(); + using var httpClient = new HttpClient(mockHttpHandler); + mockHttpHandler.RequestHandler = request => + { + return Task.FromResult(new HttpResponseMessage + { + StatusCode = HttpStatusCode.OK, + Content = new StringContent("event: endpoint\r\ndata: http://localhost\r\n\r\n") + }); + }; + + var transport = new SseClientTransport(_transportOptions, _serverConfig, httpClient, LoggerFactory); + await using var session = await transport.ConnectAsync(TestContext.Current.CancellationToken); - await transport.DisposeAsync(); + await session.DisposeAsync(); - Assert.False(transport.IsConnected); + Assert.False(session.IsConnected); } } \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs index 03994ee7..8daf82cd 100644 --- a/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs @@ -53,11 +53,9 @@ public void Constructor_Throws_For_Null_Options() } [Fact] - public async Task StartListeningAsync_Should_Set_Connected_State() + public async Task Should_Start_In_Connected_State() { - await using var transport = new StdioServerTransport(_serverOptions); - - await transport.StartListeningAsync(TestContext.Current.CancellationToken); + await using var transport = new StdioServerTransport(_serverOptions.ServerInfo.Name, Stream.Null, Stream.Null, LoggerFactory); Assert.True(transport.IsConnected); } @@ -72,12 +70,7 @@ public async Task SendMessageAsync_Should_Send_Message() new Pipe().Reader.AsStream(), output, LoggerFactory); - - await transport.StartListeningAsync(TestContext.Current.CancellationToken); - - // Ensure transport is fully initialized - await Task.Delay(100, TestContext.Current.CancellationToken); - + // Verify transport is connected Assert.True(transport.IsConnected, "Transport should be connected after StartListeningAsync"); @@ -91,20 +84,10 @@ public async Task SendMessageAsync_Should_Send_Message() Assert.Equal(expected, result); } - [Fact] - public async Task SendMessageAsync_Throws_Exception_If_Not_Connected() - { - await using var transport = new StdioServerTransport(_serverOptions); - - var message = new JsonRpcRequest { Method = "test" }; - - await Assert.ThrowsAsync(() => transport.SendMessageAsync(message, TestContext.Current.CancellationToken)); - } - [Fact] public async Task DisposeAsync_Should_Dispose_Resources() { - await using var transport = new StdioServerTransport(_serverOptions); + await using var transport = new StdioServerTransport(_serverOptions.ServerInfo.Name, Stream.Null, Stream.Null, LoggerFactory); await transport.DisposeAsync(); @@ -126,12 +109,7 @@ public async Task ReadMessagesAsync_Should_Read_Messages() input, Stream.Null, LoggerFactory); - - await transport.StartListeningAsync(TestContext.Current.CancellationToken); - - // Ensure transport is fully initialized - await Task.Delay(100, TestContext.Current.CancellationToken); - + // Verify transport is connected Assert.True(transport.IsConnected, "Transport should be connected after StartListeningAsync"); @@ -150,8 +128,7 @@ public async Task ReadMessagesAsync_Should_Read_Messages() [Fact] public async Task CleanupAsync_Should_Cleanup_Resources() { - var transport = new StdioServerTransport(_serverOptions); - await transport.StartListeningAsync(TestContext.Current.CancellationToken); + var transport = new StdioServerTransport(_serverOptions.ServerInfo.Name, Stream.Null, Stream.Null, LoggerFactory); await transport.DisposeAsync(); @@ -165,24 +142,19 @@ public async Task SendMessageAsync_Should_Preserve_Unicode_Characters() using var output = new MemoryStream(); await using var transport = new StdioServerTransport( - _serverOptions.ServerInfo.Name, + _serverOptions.ServerInfo.Name, new Pipe().Reader.AsStream(), - output, + output, LoggerFactory); - - await transport.StartListeningAsync(TestContext.Current.CancellationToken); - - // Ensure transport is fully initialized - await Task.Delay(100, TestContext.Current.CancellationToken); - + // Verify transport is connected Assert.True(transport.IsConnected, "Transport should be connected after StartListeningAsync"); // Test 1: Chinese characters (BMP Unicode) var chineseText = "上下文伺服器"; // "Context Server" in Chinese - var chineseMessage = new JsonRpcRequest - { - Method = "test", + var chineseMessage = new JsonRpcRequest + { + Method = "test", Id = RequestId.FromNumber(44), Params = new Dictionary { @@ -193,18 +165,18 @@ public async Task SendMessageAsync_Should_Preserve_Unicode_Characters() // Clear output and send message output.SetLength(0); await transport.SendMessageAsync(chineseMessage, TestContext.Current.CancellationToken); - + // Verify Chinese characters preserved but encoded var chineseResult = Encoding.UTF8.GetString(output.ToArray()).Trim(); var expectedChinese = JsonSerializer.Serialize(chineseMessage, McpJsonUtilities.DefaultOptions); Assert.Equal(expectedChinese, chineseResult); Assert.Contains(JsonSerializer.Serialize(chineseText), chineseResult); - + // Test 2: Emoji (non-BMP Unicode using surrogate pairs) var emojiText = "🔍 🚀 👍"; // Magnifying glass, rocket, thumbs up - var emojiMessage = new JsonRpcRequest - { - Method = "test", + var emojiMessage = new JsonRpcRequest + { + Method = "test", Id = RequestId.FromNumber(45), Params = new Dictionary { @@ -215,23 +187,23 @@ public async Task SendMessageAsync_Should_Preserve_Unicode_Characters() // Clear output and send message output.SetLength(0); await transport.SendMessageAsync(emojiMessage, TestContext.Current.CancellationToken); - + // Verify emoji preserved - might be as either direct characters or escape sequences var emojiResult = Encoding.UTF8.GetString(output.ToArray()).Trim(); var expectedEmoji = JsonSerializer.Serialize(emojiMessage, McpJsonUtilities.DefaultOptions); Assert.Equal(expectedEmoji, emojiResult); - + // Verify surrogate pairs in different possible formats // Magnifying glass emoji: 🔍 (U+1F50D) - bool magnifyingGlassFound = - emojiResult.Contains("🔍") || + bool magnifyingGlassFound = + emojiResult.Contains("🔍") || emojiResult.Contains("\\ud83d\\udd0d", StringComparison.OrdinalIgnoreCase); - + // Rocket emoji: 🚀 (U+1F680) - bool rocketFound = - emojiResult.Contains("🚀") || + bool rocketFound = + emojiResult.Contains("🚀") || emojiResult.Contains("\\ud83d\\ude80", StringComparison.OrdinalIgnoreCase); - + Assert.True(magnifyingGlassFound, "Magnifying glass emoji not found in result"); Assert.True(rocketFound, "Rocket emoji not found in result"); } diff --git a/tests/ModelContextProtocol.Tests/Transport/StreamClientTransport.cs b/tests/ModelContextProtocol.Tests/Transport/StreamClientTransport.cs index 49c7ca7a..6966f7d8 100644 --- a/tests/ModelContextProtocol.Tests/Transport/StreamClientTransport.cs +++ b/tests/ModelContextProtocol.Tests/Transport/StreamClientTransport.cs @@ -23,7 +23,7 @@ public StreamClientTransport(TextWriter serverStdinWriter, TextReader serverStdo SetConnected(true); } - public Task ConnectAsync(CancellationToken cancellationToken = default) => Task.CompletedTask; + public Task ConnectAsync(CancellationToken cancellationToken = default) => Task.FromResult(this); public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) { @@ -37,22 +37,28 @@ public override async Task SendMessageAsync(IJsonRpcMessage message, Cancellatio private async Task ReadMessagesAsync(CancellationToken cancellationToken) { - while (await _serverStdoutReader.ReadLineAsync(cancellationToken).ConfigureAwait(false) is string line) + try { - if (!string.IsNullOrWhiteSpace(line)) + while (await _serverStdoutReader.ReadLineAsync(cancellationToken).ConfigureAwait(false) is string line) { - try + if (!string.IsNullOrWhiteSpace(line)) { - if (JsonSerializer.Deserialize(line.Trim(), _jsonOptions) is { } message) + try + { + if (JsonSerializer.Deserialize(line.Trim(), _jsonOptions) is { } message) + { + await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false); + } + } + catch (JsonException) { - await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false); } - } - catch (JsonException) - { } } } + catch (OperationCanceledException) + { + } } public override async ValueTask DisposeAsync() diff --git a/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs b/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs index 4fca42a3..316e0561 100644 --- a/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs +++ b/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs @@ -5,12 +5,11 @@ namespace ModelContextProtocol.Tests.Utils; -public class TestServerTransport : IServerTransport +public class TestServerTransport : ITransport { private readonly Channel _messageChannel; - private bool _isStarted; - public bool IsConnected => _isStarted; + public bool IsConnected { get; set; } public Task Completion => Task.CompletedTask; @@ -27,9 +26,15 @@ public TestServerTransport() SingleReader = true, SingleWriter = true, }); + IsConnected = true; } - public ValueTask DisposeAsync() => ValueTask.CompletedTask; + public ValueTask DisposeAsync() + { + _messageChannel.Writer.TryComplete(); + IsConnected = false; + return ValueTask.CompletedTask; + } public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) { @@ -51,12 +56,6 @@ public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken ca OnMessageSent?.Invoke(message); } - public Task StartListeningAsync(CancellationToken cancellationToken = default) - { - _isStarted = true; - return Task.CompletedTask; - } - private async Task ListRoots(JsonRpcRequest request, CancellationToken cancellationToken) { await WriteMessageAsync(new JsonRpcResponse @@ -78,7 +77,7 @@ await WriteMessageAsync(new JsonRpcResponse }, cancellationToken); } - protected async Task WriteMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) + private async Task WriteMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) { await _messageChannel.Writer.WriteAsync(message, cancellationToken); } From 94d5f18a6f5ba62030a97ab98f1cc560c82d1feb Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Sat, 29 Mar 2025 15:35:39 -0700 Subject: [PATCH 04/10] Make McpSession field private --- src/ModelContextProtocol/Server/McpServer.cs | 5 +---- src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/ModelContextProtocol/Server/McpServer.cs b/src/ModelContextProtocol/Server/McpServer.cs index bcfa0a0e..9d4d0b3b 100644 --- a/src/ModelContextProtocol/Server/McpServer.cs +++ b/src/ModelContextProtocol/Server/McpServer.cs @@ -72,7 +72,7 @@ private McpServer(McpServerOptions options, ILoggerFactory? loggerFactory, IServ _toolsChangedDelegate = delegate { - _ = _session?.SendMessageAsync(new JsonRpcNotification() + _ = SendMessageAsync(new JsonRpcNotification() { Method = NotificationMethods.ToolListChangedNotification, }); @@ -141,9 +141,6 @@ public async Task AcceptSessionAsync(CancellationToken cancellationToken = defau /// public async Task RunAsync(CancellationToken cancellationToken = default) { - // Below is effectively an assertion. The McpServerFactory should not return before the _transport is initialized. - Throw.IfNull(_session); - try { // Start processing messages diff --git a/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs b/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs index 971f7dce..acf8b78a 100644 --- a/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs +++ b/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs @@ -19,12 +19,12 @@ internal abstract class McpJsonRpcEndpoint : IAsyncDisposable private readonly RequestHandlers _requestHandlers = []; private readonly NotificationHandlers _notificationHandlers = []; + private McpSession? _session; private CancellationTokenSource? _sessionCts; private int _started; private int _disposed; protected readonly ILogger _logger; - protected McpSession? _session; /// /// Initializes a new instance of the class. From f2d033d920e3428831eb42e2c02f9b195598561b Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Sat, 29 Mar 2025 16:02:23 -0700 Subject: [PATCH 05/10] Try to reduce test flakiness --- .../McpEndpointRouteBuilderExtensions.cs | 6 ++-- .../Transport/StdioClientStreamTransport.cs | 5 +++- .../ClientIntegrationTests.cs | 28 +++++++++---------- .../Server/McpServerTests.cs | 2 +- .../SseServerIntegrationTestFixture.cs | 20 +++++-------- .../SseServerIntegrationTests.cs | 6 ++++ .../Utils/DelegatingTestOutputHelper.cs | 13 +++++++++ .../Utils/LoggedTest.cs | 23 +++++++++++---- 8 files changed, 65 insertions(+), 38 deletions(-) create mode 100644 tests/ModelContextProtocol.Tests/Utils/DelegatingTestOutputHelper.cs diff --git a/samples/AspNetCoreSseServer/McpEndpointRouteBuilderExtensions.cs b/samples/AspNetCoreSseServer/McpEndpointRouteBuilderExtensions.cs index 18ce5d03..22306014 100644 --- a/samples/AspNetCoreSseServer/McpEndpointRouteBuilderExtensions.cs +++ b/samples/AspNetCoreSseServer/McpEndpointRouteBuilderExtensions.cs @@ -26,9 +26,9 @@ public static IEndpointConventionBuilder MapMcpSse(this IEndpointRouteBuilder en try { - var serverTask = server.RunAsync(cancellationToken: requestAborted); - await transport.RunAsync(cancellationToken: requestAborted); - await serverTask; + var transportTask = transport.RunAsync(cancellationToken: requestAborted); + await server.RunAsync(cancellationToken: requestAborted); + await transportTask; } catch (OperationCanceledException) when (requestAborted.IsCancellationRequested) { diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioClientStreamTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StdioClientStreamTransport.cs index 8de9aa76..fc213474 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StdioClientStreamTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StdioClientStreamTransport.cs @@ -20,6 +20,7 @@ internal sealed class StdioClientStreamTransport : TransportBase private readonly McpServerConfig _serverConfig; private readonly ILogger _logger; private readonly JsonSerializerOptions _jsonOptions; + private readonly DataReceivedEventHandler _logProcessErrors; private Process? _process; private Task? _readTask; private CancellationTokenSource? _shutdownCts; @@ -42,6 +43,7 @@ public StdioClientStreamTransport(StdioClientTransportOptions options, McpServer _options = options; _serverConfig = serverConfig; _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; + _logProcessErrors = (sender, args) => _logger.TransportError(EndpointName, args.Data ?? "(no data)"); _jsonOptions = McpJsonUtilities.DefaultOptions; } @@ -98,7 +100,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default) _process = new Process { StartInfo = startInfo }; // Set up error logging - _process.ErrorDataReceived += (sender, args) => _logger.TransportError(EndpointName, args.Data ?? "(no data)"); + _process.ErrorDataReceived += _logProcessErrors; // We need both stdin and stdout to use a no-BOM UTF-8 encoding. On .NET Core, // we can use ProcessStartInfo.StandardOutputEncoding/StandardInputEncoding, but @@ -277,6 +279,7 @@ private async Task CleanupAsync(CancellationToken cancellationToken) } finally { + process.ErrorDataReceived -= _logProcessErrors; process.Dispose(); _process = null; } diff --git a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs index 9ecfdd57..8b8be804 100644 --- a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs @@ -39,7 +39,7 @@ public async Task ConnectAndPing_Stdio(string clientId) // Act await using var client = await _fixture.CreateClientAsync(clientId); - await client.PingAsync(CancellationToken.None); + await client.PingAsync(TestContext.Current.CancellationToken); // Assert Assert.NotNull(client); @@ -89,7 +89,7 @@ public async Task CallTool_Stdio_EchoServer(string clientId) { ["message"] = "Hello MCP!" }, - CancellationToken.None + TestContext.Current.CancellationToken ); // assert @@ -141,7 +141,7 @@ public async Task GetPrompt_Stdio_SimplePrompt(string clientId) // act await using var client = await _fixture.CreateClientAsync(clientId); - var result = await client.GetPromptAsync("simple_prompt", null, CancellationToken.None); + var result = await client.GetPromptAsync("simple_prompt", null, TestContext.Current.CancellationToken); // assert Assert.NotNull(result); @@ -161,7 +161,7 @@ public async Task GetPrompt_Stdio_ComplexPrompt(string clientId) { "temperature", "0.7" }, { "style", "formal" } }; - var result = await client.GetPromptAsync("complex_prompt", arguments, CancellationToken.None); + var result = await client.GetPromptAsync("complex_prompt", arguments, TestContext.Current.CancellationToken); // assert Assert.NotNull(result); @@ -177,7 +177,7 @@ public async Task GetPrompt_NonExistent_ThrowsException(string clientId) // act await using var client = await _fixture.CreateClientAsync(clientId); await Assert.ThrowsAsync(() => - client.GetPromptAsync("non_existent_prompt", null, CancellationToken.None)); + client.GetPromptAsync("non_existent_prompt", null, TestContext.Current.CancellationToken)); } [Theory] @@ -220,7 +220,7 @@ public async Task ReadResource_Stdio_TextResource(string clientId) await using var client = await _fixture.CreateClientAsync(clientId); // Odd numbered resources are text in the everything server (despite the docs saying otherwise) // 1 is index 0, which is "even" in the 0-based index - var result = await client.ReadResourceAsync("test://static/resource/1", CancellationToken.None); + var result = await client.ReadResourceAsync("test://static/resource/1", TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Single(result.Contents); @@ -237,7 +237,7 @@ public async Task ReadResource_Stdio_BinaryResource(string clientId) await using var client = await _fixture.CreateClientAsync(clientId); // Even numbered resources are binary in the everything server (despite the docs saying otherwise) // 2 is index 1, which is "odd" in the 0-based index - var result = await client.ReadResourceAsync("test://static/resource/2", CancellationToken.None); + var result = await client.ReadResourceAsync("test://static/resource/2", TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Single(result.Contents); @@ -260,7 +260,7 @@ public async Task SubscribeResource_Stdio() tcs.TrySetResult(true); return Task.CompletedTask; }); - await client.SubscribeToResourceAsync("test://static/resource/1", CancellationToken.None); + await client.SubscribeToResourceAsync("test://static/resource/1", TestContext.Current.CancellationToken); await tcs.Task; } @@ -281,13 +281,13 @@ public async Task UnsubscribeResource_Stdio() receivedNotification.TrySetResult(true); return Task.CompletedTask; }); - await client.SubscribeToResourceAsync("test://static/resource/1", CancellationToken.None); + await client.SubscribeToResourceAsync("test://static/resource/1", TestContext.Current.CancellationToken); // wait until we received a notification await receivedNotification.Task; // unsubscribe - await client.UnsubscribeFromResourceAsync("test://static/resource/1", CancellationToken.None); + await client.UnsubscribeFromResourceAsync("test://static/resource/1", TestContext.Current.CancellationToken); receivedNotification = new(); // wait a bit to validate we don't receive another. this is best effort only; @@ -309,7 +309,7 @@ public async Task GetCompletion_Stdio_ResourceReference(string clientId) Uri = "test://static/resource/1" }, "argument_name", "1", - CancellationToken.None + TestContext.Current.CancellationToken ); Assert.NotNull(result); @@ -331,7 +331,7 @@ public async Task GetCompletion_Stdio_PromptReference(string clientId) Name = "irrelevant" }, argumentName: "style", argumentValue: "fo", - CancellationToken.None + TestContext.Current.CancellationToken ); Assert.NotNull(result); @@ -411,7 +411,7 @@ public async Task Sampling_Stdio(string clientId) // }); // // Connect - // await client.ConnectAsync(CancellationToken.None); + // await client.ConnectAsync(TestContext.Current.CancellationToken); // // assert // // nothing to assert, no servers implement roots, so we if no exception is thrown, it's a success @@ -560,7 +560,7 @@ public async Task SetLoggingLevel_ReceivesLoggingMessages(string clientId) }); // act - await client.SetLoggingLevel(LoggingLevel.Debug, CancellationToken.None); + await client.SetLoggingLevel(LoggingLevel.Debug, TestContext.Current.CancellationToken); // assert await receivedNotification.Task; diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs index b29e0bd0..12b8f525 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs @@ -538,7 +538,7 @@ await transport.SendMessageAsync( } ); - var response = await receivedMessage.Task.WaitAsync(TimeSpan.FromSeconds(1)); + var response = await receivedMessage.Task.WaitAsync(TimeSpan.FromSeconds(5)); Assert.NotNull(response); Assert.NotNull(response.Result); diff --git a/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs b/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs index 75c727d1..16df076a 100644 --- a/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs +++ b/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs @@ -3,6 +3,7 @@ using ModelContextProtocol.Configuration; using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Test.Utils; +using ModelContextProtocol.Tests.Utils; using ModelContextProtocol.TestSseServer; namespace ModelContextProtocol.Tests; @@ -47,6 +48,11 @@ public void Initialize(ITestOutputHelper output) _delegatingTestOutputHelper.CurrentTestOutputHelper = output; } + public void TestCompleted() + { + _delegatingTestOutputHelper.CurrentTestOutputHelper = null; + } + public async ValueTask DisposeAsync() { _delegatingTestOutputHelper.CurrentTestOutputHelper = null; @@ -61,16 +67,4 @@ public async ValueTask DisposeAsync() _redirectingLoggerFactory.Dispose(); _stopCts.Dispose(); } - - private class DelegatingTestOutputHelper() : ITestOutputHelper - { - public ITestOutputHelper? CurrentTestOutputHelper { get; set; } - - public string Output => CurrentTestOutputHelper?.Output ?? string.Empty; - - public void Write(string message) => CurrentTestOutputHelper?.Write(message); - public void Write(string format, params object[] args) => CurrentTestOutputHelper?.Write(format, args); - public void WriteLine(string message) => CurrentTestOutputHelper?.WriteLine(message); - public void WriteLine(string format, params object[] args) => CurrentTestOutputHelper?.WriteLine(format, args); - } -} \ No newline at end of file +} diff --git a/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs b/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs index 976997af..82b1e637 100644 --- a/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs @@ -15,6 +15,12 @@ public SseServerIntegrationTests(SseServerIntegrationTestFixture fixture, ITestO _fixture.Initialize(testOutputHelper); } + public override void Dispose() + { + _fixture.TestCompleted(); + base.Dispose(); + } + private Task GetClientAsync(McpClientOptions? options = null) { return McpClientFactory.CreateAsync( diff --git a/tests/ModelContextProtocol.Tests/Utils/DelegatingTestOutputHelper.cs b/tests/ModelContextProtocol.Tests/Utils/DelegatingTestOutputHelper.cs new file mode 100644 index 00000000..667119af --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Utils/DelegatingTestOutputHelper.cs @@ -0,0 +1,13 @@ +namespace ModelContextProtocol.Tests.Utils; + +public class DelegatingTestOutputHelper() : ITestOutputHelper +{ + public ITestOutputHelper? CurrentTestOutputHelper { get; set; } + + public string Output => CurrentTestOutputHelper?.Output ?? string.Empty; + + public void Write(string message) => CurrentTestOutputHelper?.Write(message); + public void Write(string format, params object[] args) => CurrentTestOutputHelper?.Write(format, args); + public void WriteLine(string message) => CurrentTestOutputHelper?.WriteLine(message); + public void WriteLine(string format, params object[] args) => CurrentTestOutputHelper?.WriteLine(format, args); +} diff --git a/tests/ModelContextProtocol.Tests/Utils/LoggedTest.cs b/tests/ModelContextProtocol.Tests/Utils/LoggedTest.cs index 3cf83ac4..f2e61f76 100644 --- a/tests/ModelContextProtocol.Tests/Utils/LoggedTest.cs +++ b/tests/ModelContextProtocol.Tests/Utils/LoggedTest.cs @@ -3,16 +3,27 @@ namespace ModelContextProtocol.Tests.Utils; -public class LoggedTest(ITestOutputHelper testOutputHelper) +public class LoggedTest : IDisposable { - public ITestOutputHelper TestOutputHelper { get; } = testOutputHelper; - public ILoggerFactory LoggerFactory { get; } = CreateLoggerFactory(testOutputHelper); + private readonly DelegatingTestOutputHelper _delegatingTestOutputHelper; - private static ILoggerFactory CreateLoggerFactory(ITestOutputHelper testOutputHelper) + public LoggedTest(ITestOutputHelper testOutputHelper) { - return Microsoft.Extensions.Logging.LoggerFactory.Create(builder => + _delegatingTestOutputHelper = new() { - builder.AddProvider(new XunitLoggerProvider(testOutputHelper)); + CurrentTestOutputHelper = testOutputHelper, + }; + LoggerFactory = Microsoft.Extensions.Logging.LoggerFactory.Create(builder => + { + builder.AddProvider(new XunitLoggerProvider(_delegatingTestOutputHelper)); }); } + + public ITestOutputHelper TestOutputHelper => _delegatingTestOutputHelper; + public ILoggerFactory LoggerFactory { get; } + + public virtual void Dispose() + { + _delegatingTestOutputHelper.CurrentTestOutputHelper = null; + } } From 6b95e13511825c69e72f323c1e704408ba2638ab Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Sat, 29 Mar 2025 17:29:52 -0700 Subject: [PATCH 06/10] Fix ownsHttpClient --- .../Transport/HttpListenerServerProvider.cs | 2 -- .../Transport/SseClientSessionTransport.cs | 10 +--------- .../Protocol/Transport/SseClientTransport.cs | 15 +++++++++++++-- .../Transport/SseClientTransportTests.cs | 12 ++++++------ .../Transport/StdioServerTransportTests.cs | 3 +++ 5 files changed, 23 insertions(+), 19 deletions(-) diff --git a/src/ModelContextProtocol/Protocol/Transport/HttpListenerServerProvider.cs b/src/ModelContextProtocol/Protocol/Transport/HttpListenerServerProvider.cs index 2071a716..73a1dbd3 100644 --- a/src/ModelContextProtocol/Protocol/Transport/HttpListenerServerProvider.cs +++ b/src/ModelContextProtocol/Protocol/Transport/HttpListenerServerProvider.cs @@ -1,7 +1,5 @@ using System.Net; -#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously - namespace ModelContextProtocol.Protocol.Transport; /// diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs index c3196b08..ac96dc02 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs @@ -27,7 +27,6 @@ internal sealed class SseClientSessionTransport : TransportBase private readonly McpServerConfig _serverConfig; private readonly JsonSerializerOptions _jsonOptions; private readonly TaskCompletionSource _connectionEstablished; - private readonly bool _ownsHttpClient; private string EndpointName => $"Client (SSE) for ({_serverConfig.Id}: {_serverConfig.Name})"; @@ -39,8 +38,7 @@ internal sealed class SseClientSessionTransport : TransportBase /// The configuration object indicating which server to connect to. /// The HTTP client instance used for requests. /// Logger factory for creating loggers. - /// True to dispose HTTP client on close connection. - public SseClientSessionTransport(SseClientTransportOptions transportOptions, McpServerConfig serverConfig, HttpClient httpClient, ILoggerFactory? loggerFactory, bool ownsHttpClient = false) + public SseClientSessionTransport(SseClientTransportOptions transportOptions, McpServerConfig serverConfig, HttpClient httpClient, ILoggerFactory? loggerFactory) : base(loggerFactory) { Throw.IfNull(transportOptions); @@ -55,7 +53,6 @@ public SseClientSessionTransport(SseClientTransportOptions transportOptions, Mcp _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; _jsonOptions = McpJsonUtilities.DefaultOptions; _connectionEstablished = new TaskCompletionSource(); - _ownsHttpClient = ownsHttpClient; } /// @@ -180,11 +177,6 @@ public override async ValueTask DisposeAsync() try { await CloseAsync(); - - if (_ownsHttpClient) - { - _httpClient?.Dispose(); - } } catch (Exception) { diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs index 79db9a1d..c136e162 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs @@ -7,7 +7,7 @@ namespace ModelContextProtocol.Protocol.Transport; /// /// The ServerSideEvents client transport implementation /// -public sealed class SseClientTransport : IClientTransport +public sealed class SseClientTransport : IClientTransport, IAsyncDisposable { private readonly SseClientTransportOptions _options; private readonly McpServerConfig _serverConfig; @@ -52,7 +52,7 @@ public SseClientTransport(SseClientTransportOptions transportOptions, McpServerC /// public async Task ConnectAsync(CancellationToken cancellationToken = default) { - var sessionTransport = new SseClientSessionTransport(_options, _serverConfig, _httpClient, _loggerFactory, _ownsHttpClient); + var sessionTransport = new SseClientSessionTransport(_options, _serverConfig, _httpClient, _loggerFactory); try { @@ -65,4 +65,15 @@ public async Task ConnectAsync(CancellationToken cancellationToken = throw; } } + + /// + public ValueTask DisposeAsync() + { + if (_ownsHttpClient) + { + _httpClient.Dispose(); + } + + return default; + } } \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs index f26d31a4..ca4a5363 100644 --- a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs @@ -61,7 +61,7 @@ public async Task ConnectAsync_Should_Connect_Successfully() { using var mockHttpHandler = new MockHttpHandler(); using var httpClient = new HttpClient(mockHttpHandler); - var transport = new SseClientTransport(_transportOptions, _serverConfig, httpClient, LoggerFactory); + await using var transport = new SseClientTransport(_transportOptions, _serverConfig, httpClient, LoggerFactory); bool firstCall = true; @@ -85,7 +85,7 @@ public async Task ConnectAsync_Throws_Exception_On_Failure() { using var mockHttpHandler = new MockHttpHandler(); using var httpClient = new HttpClient(mockHttpHandler); - var transport = new SseClientTransport(_transportOptions, _serverConfig, httpClient, LoggerFactory); + await using var transport = new SseClientTransport(_transportOptions, _serverConfig, httpClient, LoggerFactory); var retries = 0; mockHttpHandler.RequestHandler = (request) => @@ -107,7 +107,7 @@ public async Task SendMessageAsync_Handles_Accepted_Response() { using var mockHttpHandler = new MockHttpHandler(); using var httpClient = new HttpClient(mockHttpHandler); - var transport = new SseClientTransport(_transportOptions, _serverConfig, httpClient, LoggerFactory); + await using var transport = new SseClientTransport(_transportOptions, _serverConfig, httpClient, LoggerFactory); var firstCall = true; mockHttpHandler.RequestHandler = (request) => @@ -145,7 +145,7 @@ public async Task SendMessageAsync_Handles_Accepted_Json_RPC_Response() { using var mockHttpHandler = new MockHttpHandler(); using var httpClient = new HttpClient(mockHttpHandler); - var transport = new SseClientTransport(_transportOptions, _serverConfig, httpClient, LoggerFactory); + await using var transport = new SseClientTransport(_transportOptions, _serverConfig, httpClient, LoggerFactory); var eventSourcePipe = new Pipe(); var eventSourceData = "event: endpoint\r\ndata: /sseendpoint\r\n\r\n"u8; @@ -191,7 +191,7 @@ public async Task ReceiveMessagesAsync_Handles_Messages() { using var mockHttpHandler = new MockHttpHandler(); using var httpClient = new HttpClient(mockHttpHandler); - var transport = new SseClientTransport(_transportOptions, _serverConfig, httpClient, LoggerFactory); + await using var transport = new SseClientTransport(_transportOptions, _serverConfig, httpClient, LoggerFactory); var callIndex = 0; mockHttpHandler.RequestHandler = (request) => @@ -231,7 +231,7 @@ public async Task DisposeAsync_Should_Dispose_Resources() }); }; - var transport = new SseClientTransport(_transportOptions, _serverConfig, httpClient, LoggerFactory); + await using var transport = new SseClientTransport(_transportOptions, _serverConfig, httpClient, LoggerFactory); await using var session = await transport.ConnectAsync(TestContext.Current.CancellationToken); await session.DisposeAsync(); diff --git a/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs index 8daf82cd..2fd92343 100644 --- a/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs @@ -6,6 +6,7 @@ using ModelContextProtocol.Tests.Utils; using ModelContextProtocol.Utils.Json; using System.IO.Pipelines; +using System.Runtime.InteropServices; using System.Text; using System.Text.Json; @@ -34,6 +35,8 @@ public StdioServerTransportTests(ITestOutputHelper testOutputHelper) [Fact] public async Task Constructor_Should_Initialize_With_Valid_Parameters() { + Assert.SkipWhen(RuntimeInformation.IsOSPlatform(OSPlatform.Linux), "https://github.com/modelcontextprotocol/csharp-sdk/issues/143"); + // Act await using var transport = new StdioServerTransport(_serverOptions); From 1d44cec34315622a8df54a5a5ec044a4be3fde71 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Sun, 30 Mar 2025 14:04:45 -0700 Subject: [PATCH 07/10] Address PR feedback - Use SemaphoreSlim in McpJsonRpcEndpoint - Dispose IClientTransport on connection failure - Test cleanup --- src/ModelContextProtocol/Client/McpClient.cs | 16 ++++------ .../Client/McpClientFactory.cs | 26 +++++++++++---- src/ModelContextProtocol/Server/McpServer.cs | 4 +-- .../Shared/McpJsonRpcEndpoint.cs | 32 ++++++++++++++----- src/ModelContextProtocol/Shared/McpSession.cs | 2 +- .../Client/McpClientExtensionsTests.cs | 2 +- .../McpServerBuilderExtensionsToolsTests.cs | 6 ++-- .../Transport/StreamClientTransport.cs | 6 ++-- .../Utils/TestServerTransport.cs | 2 -- 9 files changed, 61 insertions(+), 35 deletions(-) diff --git a/src/ModelContextProtocol/Client/McpClient.cs b/src/ModelContextProtocol/Client/McpClient.cs index 4ba44d72..b20abbbb 100644 --- a/src/ModelContextProtocol/Client/McpClient.cs +++ b/src/ModelContextProtocol/Client/McpClient.cs @@ -18,7 +18,6 @@ internal sealed class McpClient : McpJsonRpcEndpoint, IMcpClient private ITransport? _sessionTransport; private CancellationTokenSource? _connectCts; - private int _disposed; /// /// Initializes a new instance of the class. @@ -142,20 +141,19 @@ await SendMessageAsync( } /// - public override async ValueTask DisposeAsync() + public override async ValueTask DisposeUnsynchronizedAsync() { - if (Interlocked.Exchange(ref _disposed, 1) != 0) - { - // TODO: It's more correct to await the last DisposeAsync before returning if it's still ongoing. - return; - } - if (_connectCts is not null) { await _connectCts.CancelAsync().ConfigureAwait(false); } - await base.DisposeAsync().ConfigureAwait(false); + await base.DisposeUnsynchronizedAsync().ConfigureAwait(false); + + if (_sessionTransport is not null) + { + await _sessionTransport.DisposeAsync().ConfigureAwait(false); + } _connectCts?.Dispose(); } diff --git a/src/ModelContextProtocol/Client/McpClientFactory.cs b/src/ModelContextProtocol/Client/McpClientFactory.cs index 08153faa..353ef1f6 100644 --- a/src/ModelContextProtocol/Client/McpClientFactory.cs +++ b/src/ModelContextProtocol/Client/McpClientFactory.cs @@ -64,17 +64,31 @@ public static async Task CreateAsync( var transport = createTransportFunc(serverConfig, loggerFactory) ?? throw new InvalidOperationException($"{nameof(createTransportFunc)} returned a null transport."); - - McpClient client = new(transport, clientOptions, serverConfig, loggerFactory); try { - await client.ConnectAsync(cancellationToken).ConfigureAwait(false); - logger.ClientCreated(endpointName); - return client; + McpClient client = new(transport, clientOptions, serverConfig, loggerFactory); + try + { + await client.ConnectAsync(cancellationToken).ConfigureAwait(false); + logger.ClientCreated(endpointName); + return client; + } + catch + { + await client.DisposeAsync().ConfigureAwait(false); + throw; + } } catch { - await client.DisposeAsync().ConfigureAwait(false); + if (transport is IAsyncDisposable asyncDisposableTransport) + { + await asyncDisposableTransport.DisposeAsync().ConfigureAwait(false); + } + else if (transport is IDisposable disposableTransport) + { + disposableTransport.Dispose(); + } throw; } } diff --git a/src/ModelContextProtocol/Server/McpServer.cs b/src/ModelContextProtocol/Server/McpServer.cs index 9d4d0b3b..979c38d0 100644 --- a/src/ModelContextProtocol/Server/McpServer.cs +++ b/src/ModelContextProtocol/Server/McpServer.cs @@ -153,14 +153,14 @@ public async Task RunAsync(CancellationToken cancellationToken = default) } } - public override async ValueTask DisposeAsync() + public override async ValueTask DisposeUnsynchronizedAsync() { if (ServerOptions.Capabilities?.Tools?.ToolCollection is { } tools) { tools.Changed -= _toolsChangedDelegate; } - await base.DisposeAsync().ConfigureAwait(false); + await base.DisposeUnsynchronizedAsync().ConfigureAwait(false); if (_serverTransport is not null && _sessionTransport is not null) { diff --git a/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs b/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs index acf8b78a..aa06a101 100644 --- a/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs +++ b/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs @@ -22,7 +22,9 @@ internal abstract class McpJsonRpcEndpoint : IAsyncDisposable private McpSession? _session; private CancellationTokenSource? _sessionCts; private int _started; - private int _disposed; + + private readonly SemaphoreSlim _disposeLock = new(1); + private bool _disposed; protected readonly ILogger _logger; @@ -74,18 +76,32 @@ protected void StartSession(CancellationToken fullSessionCancellationToken = def MessageProcessingTask = GetSessionOrThrow().ProcessMessagesAsync(_sessionCts.Token); } + public async ValueTask DisposeAsync() + { + await _disposeLock.WaitAsync().ConfigureAwait(false); + try + { + if (_disposed) + { + return; + } + _disposed = true; + + await DisposeUnsynchronizedAsync().ConfigureAwait(false); + } + finally + { + _disposeLock.Release(); + } + } + /// /// Cleans up the endpoint and releases resources. /// /// - public virtual async ValueTask DisposeAsync() + public virtual async ValueTask DisposeUnsynchronizedAsync() { - if (Interlocked.Exchange(ref _disposed, 1) != 0) - { - // TODO: It's more correct to await the last DisposeAsync before returning if it's still ongoing. - return; - } - + // Both McpClient and McpServer guard this with a semaphore _logger.CleaningUpEndpoint(EndpointName); if (_sessionCts is not null) diff --git a/src/ModelContextProtocol/Shared/McpSession.cs b/src/ModelContextProtocol/Shared/McpSession.cs index 1d9c8d0c..187bd8db 100644 --- a/src/ModelContextProtocol/Shared/McpSession.cs +++ b/src/ModelContextProtocol/Shared/McpSession.cs @@ -18,7 +18,7 @@ internal sealed class McpSession : IDisposable { private readonly ITransport _transport; private readonly RequestHandlers _requestHandlers; - private readonly NotificationHandlers _notificationHandlers = []; + private readonly NotificationHandlers _notificationHandlers; private readonly ConcurrentDictionary> _pendingRequests = []; private readonly JsonSerializerOptions _jsonOptions; diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs index 591e3182..51583583 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs @@ -67,7 +67,7 @@ private async Task CreateMcpClientForServer() return await McpClientFactory.CreateAsync( serverConfig, - createTransportFunc: (_, _) => new StreamClientTransport(serverStdinWriter, serverStdoutReader), + createTransportFunc: (_, _) => new StreamClientTransport(serverStdinWriter, serverStdoutReader, LoggerFactory), loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); } diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index edab8725..700910e5 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -53,6 +53,7 @@ public async ValueTask DisposeAsync() await _serviceProvider.DisposeAsync(); _cts.Dispose(); + Dispose(); } private async Task CreateMcpClientForServer() @@ -69,7 +70,7 @@ private async Task CreateMcpClientForServer() return await McpClientFactory.CreateAsync( serverConfig, - createTransportFunc: (_, _) => new StreamClientTransport(serverStdinWriter, serverStdoutReader), + createTransportFunc: (_, _) => new StreamClientTransport(serverStdinWriter, serverStdoutReader, LoggerFactory), loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); } @@ -104,7 +105,6 @@ public async Task Can_List_Registered_Tools() Assert.Equal("Echoes the input back to the client.", doubleEchoTool.Description); } - [Fact] public async Task Can_Create_Multiple_Servers_From_Options_And_List_Registered_Tools() { @@ -132,7 +132,7 @@ public async Task Can_Create_Multiple_Servers_From_Options_And_List_Registered_T await using (var client = await McpClientFactory.CreateAsync( serverConfig, - createTransportFunc: (_, _) => new StreamClientTransport(serverStdinWriter, serverStdoutReader), + createTransportFunc: (_, _) => new StreamClientTransport(serverStdinWriter, serverStdoutReader, LoggerFactory), loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)) { diff --git a/tests/ModelContextProtocol.Tests/Transport/StreamClientTransport.cs b/tests/ModelContextProtocol.Tests/Transport/StreamClientTransport.cs index 6966f7d8..d41f0b97 100644 --- a/tests/ModelContextProtocol.Tests/Transport/StreamClientTransport.cs +++ b/tests/ModelContextProtocol.Tests/Transport/StreamClientTransport.cs @@ -1,4 +1,4 @@ -using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.Extensions.Logging; using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Utils.Json; @@ -14,8 +14,8 @@ internal sealed class StreamClientTransport : TransportBase, IClientTransport private readonly TextReader _serverStdoutReader; private readonly TextWriter _serverStdinWriter; - public StreamClientTransport(TextWriter serverStdinWriter, TextReader serverStdoutReader) - : base(NullLoggerFactory.Instance) + public StreamClientTransport(TextWriter serverStdinWriter, TextReader serverStdoutReader, ILoggerFactory loggerFactory) + : base(loggerFactory) { _serverStdoutReader = serverStdoutReader; _serverStdinWriter = serverStdinWriter; diff --git a/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs b/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs index 316e0561..16c9ac86 100644 --- a/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs +++ b/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs @@ -11,8 +11,6 @@ public class TestServerTransport : ITransport public bool IsConnected { get; set; } - public Task Completion => Task.CompletedTask; - public ChannelReader MessageReader => _messageChannel; public List SentMessages { get; } = []; From cba8523b2d355c4ae2331c2089e7e62fdd878203 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Sun, 30 Mar 2025 14:11:56 -0700 Subject: [PATCH 08/10] Fix flaky Should_Start_In_Connected_State test --- src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs | 1 - .../Transport/StdioServerTransportTests.cs | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs b/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs index aa06a101..e7a29f63 100644 --- a/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs +++ b/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs @@ -101,7 +101,6 @@ public async ValueTask DisposeAsync() /// public virtual async ValueTask DisposeUnsynchronizedAsync() { - // Both McpClient and McpServer guard this with a semaphore _logger.CleaningUpEndpoint(EndpointName); if (_sessionCts is not null) diff --git a/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs index 2fd92343..2801d133 100644 --- a/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs @@ -58,7 +58,7 @@ public void Constructor_Throws_For_Null_Options() [Fact] public async Task Should_Start_In_Connected_State() { - await using var transport = new StdioServerTransport(_serverOptions.ServerInfo.Name, Stream.Null, Stream.Null, LoggerFactory); + await using var transport = new StdioServerTransport(_serverOptions.ServerInfo.Name, new Pipe().Reader.AsStream(), Stream.Null, LoggerFactory); Assert.True(transport.IsConnected); } From d055517b9ee8238a418772c61e30e69f60f061a1 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Sun, 30 Mar 2025 14:43:13 -0700 Subject: [PATCH 09/10] Fix whitespace --- src/ModelContextProtocol/Client/McpClientFactory.cs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/ModelContextProtocol/Client/McpClientFactory.cs b/src/ModelContextProtocol/Client/McpClientFactory.cs index 353ef1f6..f50329f8 100644 --- a/src/ModelContextProtocol/Client/McpClientFactory.cs +++ b/src/ModelContextProtocol/Client/McpClientFactory.cs @@ -64,6 +64,7 @@ public static async Task CreateAsync( var transport = createTransportFunc(serverConfig, loggerFactory) ?? throw new InvalidOperationException($"{nameof(createTransportFunc)} returned a null transport."); + try { McpClient client = new(transport, clientOptions, serverConfig, loggerFactory); From afe100c419102307122ba8aef52b0505bbcab422 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Sun, 30 Mar 2025 17:02:38 -0700 Subject: [PATCH 10/10] Address more PR feedback --- src/ModelContextProtocol/Client/McpClient.cs | 15 +++++++---- .../McpServerMultiSessionHostedService.cs | 2 +- .../McpServerSingleSessionHostedService.cs | 2 +- .../HttpListenerSseServerTransport.cs | 12 ++++++++- src/ModelContextProtocol/Server/McpServer.cs | 15 +++++++---- .../Shared/McpJsonRpcEndpoint.cs | 25 ++++++++----------- 6 files changed, 43 insertions(+), 28 deletions(-) diff --git a/src/ModelContextProtocol/Client/McpClient.cs b/src/ModelContextProtocol/Client/McpClient.cs index b20abbbb..b774389f 100644 --- a/src/ModelContextProtocol/Client/McpClient.cs +++ b/src/ModelContextProtocol/Client/McpClient.cs @@ -148,13 +148,18 @@ public override async ValueTask DisposeUnsynchronizedAsync() await _connectCts.CancelAsync().ConfigureAwait(false); } - await base.DisposeUnsynchronizedAsync().ConfigureAwait(false); - - if (_sessionTransport is not null) + try { - await _sessionTransport.DisposeAsync().ConfigureAwait(false); + await base.DisposeUnsynchronizedAsync().ConfigureAwait(false); } + finally + { + if (_sessionTransport is not null) + { + await _sessionTransport.DisposeAsync().ConfigureAwait(false); + } - _connectCts?.Dispose(); + _connectCts?.Dispose(); + } } } diff --git a/src/ModelContextProtocol/Hosting/McpServerMultiSessionHostedService.cs b/src/ModelContextProtocol/Hosting/McpServerMultiSessionHostedService.cs index e465a725..456f8944 100644 --- a/src/ModelContextProtocol/Hosting/McpServerMultiSessionHostedService.cs +++ b/src/ModelContextProtocol/Hosting/McpServerMultiSessionHostedService.cs @@ -9,7 +9,7 @@ namespace ModelContextProtocol.Hosting; /// /// Hosted service for a multi-session (i.e. HTTP) MCP server. /// -internal class McpServerMultiSessionHostedService : BackgroundService +internal sealed class McpServerMultiSessionHostedService : BackgroundService { private readonly IServerTransport _serverTransport; private readonly McpServerOptions _serverOptions; diff --git a/src/ModelContextProtocol/Hosting/McpServerSingleSessionHostedService.cs b/src/ModelContextProtocol/Hosting/McpServerSingleSessionHostedService.cs index a59d783f..e86c288e 100644 --- a/src/ModelContextProtocol/Hosting/McpServerSingleSessionHostedService.cs +++ b/src/ModelContextProtocol/Hosting/McpServerSingleSessionHostedService.cs @@ -6,7 +6,7 @@ namespace ModelContextProtocol.Hosting; /// /// Hosted service for a single-session (i.e stdio) MCP server. /// -internal class McpServerSingleSessionHostedService(IMcpServer session) : BackgroundService +internal sealed class McpServerSingleSessionHostedService(IMcpServer session) : BackgroundService { /// protected override Task ExecuteAsync(CancellationToken stoppingToken) => session.RunAsync(stoppingToken); diff --git a/src/ModelContextProtocol/Protocol/Transport/HttpListenerSseServerTransport.cs b/src/ModelContextProtocol/Protocol/Transport/HttpListenerSseServerTransport.cs index 614be147..4e0aa7ed 100644 --- a/src/ModelContextProtocol/Protocol/Transport/HttpListenerSseServerTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/HttpListenerSseServerTransport.cs @@ -33,7 +33,7 @@ public sealed class HttpListenerSseServerTransport : IServerTransport, IAsyncDis /// The port to listen on. /// A logger factory for creating loggers. public HttpListenerSseServerTransport(McpServerOptions serverOptions, int port, ILoggerFactory loggerFactory) - : this(serverOptions?.ServerInfo?.Name!, port, loggerFactory) + : this(GetServerName(serverOptions), port, loggerFactory) { } @@ -167,4 +167,14 @@ private async Task OnMessageAsync(Stream requestStream, CancellationToken return false; } } + + /// Validates the and extracts from it the server name to use. + private static string GetServerName(McpServerOptions serverOptions) + { + Throw.IfNull(serverOptions); + Throw.IfNull(serverOptions.ServerInfo); + Throw.IfNull(serverOptions.ServerInfo.Name); + + return serverOptions.ServerInfo.Name; + } } diff --git a/src/ModelContextProtocol/Server/McpServer.cs b/src/ModelContextProtocol/Server/McpServer.cs index 979c38d0..af2c6644 100644 --- a/src/ModelContextProtocol/Server/McpServer.cs +++ b/src/ModelContextProtocol/Server/McpServer.cs @@ -160,12 +160,17 @@ public override async ValueTask DisposeUnsynchronizedAsync() tools.Changed -= _toolsChangedDelegate; } - await base.DisposeUnsynchronizedAsync().ConfigureAwait(false); - - if (_serverTransport is not null && _sessionTransport is not null) + try { - // We created the _sessionTransport from the _serverTransport, so we own it. - await _sessionTransport.DisposeAsync().ConfigureAwait(false); + await base.DisposeUnsynchronizedAsync().ConfigureAwait(false); + } + finally + { + if (_serverTransport is not null && _sessionTransport is not null) + { + // We created the _sessionTransport from the _serverTransport, so we own it. + await _sessionTransport.DisposeAsync().ConfigureAwait(false); + } } } diff --git a/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs b/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs index e7a29f63..98567d08 100644 --- a/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs +++ b/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs @@ -3,6 +3,7 @@ using ModelContextProtocol.Logging; using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; +using ModelContextProtocol.Utils; using System.Diagnostics.CodeAnalysis; namespace ModelContextProtocol.Shared; @@ -23,7 +24,7 @@ internal abstract class McpJsonRpcEndpoint : IAsyncDisposable private CancellationTokenSource? _sessionCts; private int _started; - private readonly SemaphoreSlim _disposeLock = new(1); + private readonly SemaphoreSlim _disposeLock = new(1, 1); private bool _disposed; protected readonly ILogger _logger; @@ -78,21 +79,15 @@ protected void StartSession(CancellationToken fullSessionCancellationToken = def public async ValueTask DisposeAsync() { - await _disposeLock.WaitAsync().ConfigureAwait(false); - try - { - if (_disposed) - { - return; - } - _disposed = true; + using var _ = await _disposeLock.LockAsync().ConfigureAwait(false); - await DisposeUnsynchronizedAsync().ConfigureAwait(false); - } - finally + if (_disposed) { - _disposeLock.Release(); + return; } + _disposed = true; + + await DisposeUnsynchronizedAsync().ConfigureAwait(false); } /// @@ -126,6 +121,6 @@ public virtual async ValueTask DisposeUnsynchronizedAsync() _logger.EndpointCleanedUp(EndpointName); } - protected McpSession GetSessionOrThrow() => - _session ?? throw new InvalidOperationException($"This should be unreachable from public API! Call {nameof(InitializeSession)} before sending messages."); + protected McpSession GetSessionOrThrow() + => _session ?? throw new InvalidOperationException($"This should be unreachable from public API! Call {nameof(InitializeSession)} before sending messages."); }