From 4ae26c04817d271f657c004b0c1ccb6c943cd370 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Thu, 27 Mar 2025 18:07:20 -0400 Subject: [PATCH] Enable graceful shutdown of servers --- README.md | 6 +- .../McpEndpointRouteBuilderExtensions.cs | 11 +- src/ModelContextProtocol/Client/McpClient.cs | 122 ++++++------- .../Hosting/McpServerHostedService.cs | 2 +- .../Transport/HttpListenerServerProvider.cs | 169 +++++++++--------- .../HttpListenerSseServerTransport.cs | 33 ++-- .../Protocol/Transport/IServerTransport.cs | 3 + .../Protocol/Transport/SseClientTransport.cs | 44 ++--- .../Transport/StdioClientTransport.cs | 1 - .../Transport/StdioServerTransport.cs | 157 ++++++++-------- src/ModelContextProtocol/Server/IMcpServer.cs | 9 +- src/ModelContextProtocol/Server/McpServer.cs | 38 ++-- .../Shared/McpJsonRpcEndpoint.cs | 10 +- .../Program.cs | 95 +++++----- .../Program.cs | 23 +-- .../Client/McpClientExtensionsTests.cs | 13 +- .../ClientIntegrationTests.cs | 39 ++-- .../McpServerBuilderExtensionsToolsTests.cs | 40 ++--- .../EverythingSseServerFixture.cs | 2 - .../Server/McpServerTests.cs | 48 ++--- .../Server/McpServerToolTests.cs | 3 +- .../TestAttributes.cs | 2 + .../Transport/SseClientTransportTests.cs | 6 +- .../Transport/StdioServerTransportTests.cs | 6 +- .../Utils/TestServerTransport.cs | 2 + 25 files changed, 391 insertions(+), 493 deletions(-) create mode 100644 tests/ModelContextProtocol.Tests/TestAttributes.cs diff --git a/README.md b/README.md index 63784d6e..502dd6b5 100644 --- a/README.md +++ b/README.md @@ -198,11 +198,7 @@ McpServerOptions options = new() }; await using IMcpServer server = McpServerFactory.Create(new StdioServerTransport("MyServer"), options); - -await server.StartAsync(); - -// Run until process is stopped by the client (parent process) -await Task.Delay(Timeout.Infinite); +await server.RunAsync(); ``` ## Acknowledgements diff --git a/samples/AspNetCoreSseServer/McpEndpointRouteBuilderExtensions.cs b/samples/AspNetCoreSseServer/McpEndpointRouteBuilderExtensions.cs index 9bf11dd8..8a37afba 100644 --- a/samples/AspNetCoreSseServer/McpEndpointRouteBuilderExtensions.cs +++ b/samples/AspNetCoreSseServer/McpEndpointRouteBuilderExtensions.cs @@ -10,7 +10,6 @@ public static class McpEndpointRouteBuilderExtensions { public static IEndpointConventionBuilder MapMcpSse(this IEndpointRouteBuilder endpoints) { - IMcpServer? server = null; SseResponseStreamTransport? transport = null; var loggerFactory = endpoints.ServiceProvider.GetRequiredService(); var mcpServerOptions = endpoints.ServiceProvider.GetRequiredService>(); @@ -19,17 +18,15 @@ public static IEndpointConventionBuilder MapMcpSse(this IEndpointRouteBuilder en routeGroup.MapGet("/sse", async (HttpResponse response, CancellationToken requestAborted) => { - await using var localTransport = transport = new SseResponseStreamTransport(response.Body); - await using var localServer = server = McpServerFactory.Create(transport, mcpServerOptions.Value, loggerFactory, endpoints.ServiceProvider); - - await localServer.StartAsync(requestAborted); - response.Headers.ContentType = "text/event-stream"; response.Headers.CacheControl = "no-cache"; + await using var localTransport = transport = new SseResponseStreamTransport(response.Body); + await using var server = McpServerFactory.Create(transport, mcpServerOptions.Value, loggerFactory, endpoints.ServiceProvider); + try { - await transport.RunAsync(requestAborted); + await transport.RunAsync(cancellationToken: requestAborted); } catch (OperationCanceledException) when (requestAborted.IsCancellationRequested) { diff --git a/src/ModelContextProtocol/Client/McpClient.cs b/src/ModelContextProtocol/Client/McpClient.cs index 1ae358ef..80126375 100644 --- a/src/ModelContextProtocol/Client/McpClient.cs +++ b/src/ModelContextProtocol/Client/McpClient.cs @@ -6,7 +6,6 @@ using ModelContextProtocol.Shared; using ModelContextProtocol.Utils.Json; using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Abstractions; using System.Text.Json; namespace ModelContextProtocol.Client; @@ -17,7 +16,7 @@ internal sealed class McpClient : McpJsonRpcEndpoint, IMcpClient private readonly McpClientOptions _options; private readonly IClientTransport _clientTransport; - private volatile bool _isInitializing; + private int _connecting; /// /// Initializes a new instance of the class. @@ -74,33 +73,69 @@ public McpClient(IClientTransport transport, McpClientOptions options, McpServer /// public async Task ConnectAsync(CancellationToken cancellationToken = default) { - if (IsInitialized) - { - _logger.ClientAlreadyInitialized(EndpointName); - return; - } - - if (_isInitializing) + if (Interlocked.Exchange(ref _connecting, 1) != 0) { _logger.ClientAlreadyInitializing(EndpointName); - throw new InvalidOperationException("Client is already initializing"); + throw new InvalidOperationException("Client is already in use."); } - _isInitializing = true; + CancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + cancellationToken = CancellationTokenSource.Token; + try { - CancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - // Connect transport - await _clientTransport.ConnectAsync(CancellationTokenSource.Token).ConfigureAwait(false); + await _clientTransport.ConnectAsync(cancellationToken).ConfigureAwait(false); // Start processing messages - MessageProcessingTask = ProcessMessagesAsync(CancellationTokenSource.Token); + MessageProcessingTask = ProcessMessagesAsync(cancellationToken); // Perform initialization sequence - await InitializeAsync(CancellationTokenSource.Token).ConfigureAwait(false); + using var initializationCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + initializationCts.CancelAfter(_options.InitializationTimeout); - IsInitialized = true; + try + { + // Send initialize request + var initializeResponse = await SendRequestAsync( + new JsonRpcRequest + { + Method = "initialize", + Params = new + { + protocolVersion = _options.ProtocolVersion, + capabilities = _options.Capabilities ?? new ClientCapabilities(), + clientInfo = _options.ClientInfo + } + }, + initializationCts.Token).ConfigureAwait(false); + + // Store server information + _logger.ServerCapabilitiesReceived(EndpointName, + capabilities: JsonSerializer.Serialize(initializeResponse.Capabilities, McpJsonUtilities.JsonContext.Default.ServerCapabilities), + serverInfo: JsonSerializer.Serialize(initializeResponse.ServerInfo, McpJsonUtilities.JsonContext.Default.Implementation)); + + ServerCapabilities = initializeResponse.Capabilities; + ServerInfo = initializeResponse.ServerInfo; + ServerInstructions = initializeResponse.Instructions; + + // Validate protocol version + if (initializeResponse.ProtocolVersion != _options.ProtocolVersion) + { + _logger.ServerProtocolVersionMismatch(EndpointName, _options.ProtocolVersion, initializeResponse.ProtocolVersion); + throw new McpClientException($"Server protocol version mismatch. Expected {_options.ProtocolVersion}, got {initializeResponse.ProtocolVersion}"); + } + + // Send initialized notification + await SendMessageAsync( + new JsonRpcNotification { Method = "notifications/initialized" }, + initializationCts.Token).ConfigureAwait(false); + } + catch (OperationCanceledException) when (initializationCts.IsCancellationRequested) + { + _logger.ClientInitializationTimeout(EndpointName); + throw new McpClientException("Initialization timed out"); + } } catch (Exception e) { @@ -108,58 +143,5 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default) await CleanupAsync().ConfigureAwait(false); throw; } - finally - { - _isInitializing = false; - } - } - - private async Task InitializeAsync(CancellationToken cancellationToken) - { - using var initializationCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - initializationCts.CancelAfter(_options.InitializationTimeout); - - try - { - // Send initialize request - var initializeResponse = await SendRequestAsync( - new JsonRpcRequest - { - Method = "initialize", - Params = new - { - protocolVersion = _options.ProtocolVersion, - capabilities = _options.Capabilities ?? new ClientCapabilities(), - clientInfo = _options.ClientInfo - } - }, - initializationCts.Token).ConfigureAwait(false); - - // Store server information - _logger.ServerCapabilitiesReceived(EndpointName, - capabilities: JsonSerializer.Serialize(initializeResponse.Capabilities, McpJsonUtilities.JsonContext.Default.ServerCapabilities), - serverInfo: JsonSerializer.Serialize(initializeResponse.ServerInfo, McpJsonUtilities.JsonContext.Default.Implementation)); - - ServerCapabilities = initializeResponse.Capabilities; - ServerInfo = initializeResponse.ServerInfo; - ServerInstructions = initializeResponse.Instructions; - - // Validate protocol version - if (initializeResponse.ProtocolVersion != _options.ProtocolVersion) - { - _logger.ServerProtocolVersionMismatch(EndpointName, _options.ProtocolVersion, initializeResponse.ProtocolVersion); - throw new McpClientException($"Server protocol version mismatch. Expected {_options.ProtocolVersion}, got {initializeResponse.ProtocolVersion}"); - } - - // Send initialized notification - await SendMessageAsync( - new JsonRpcNotification { Method = "notifications/initialized" }, - initializationCts.Token).ConfigureAwait(false); - } - catch (OperationCanceledException) when (initializationCts.IsCancellationRequested) - { - _logger.ClientInitializationTimeout(EndpointName); - throw new McpClientException("Initialization timed out"); - } } } diff --git a/src/ModelContextProtocol/Hosting/McpServerHostedService.cs b/src/ModelContextProtocol/Hosting/McpServerHostedService.cs index 8aa6218f..51c0d997 100644 --- a/src/ModelContextProtocol/Hosting/McpServerHostedService.cs +++ b/src/ModelContextProtocol/Hosting/McpServerHostedService.cs @@ -26,6 +26,6 @@ public McpServerHostedService(IMcpServer server) /// protected override async Task ExecuteAsync(CancellationToken stoppingToken) { - await _server.StartAsync(stoppingToken).ConfigureAwait(false); + await _server.RunAsync(cancellationToken: stoppingToken).ConfigureAwait(false); } } diff --git a/src/ModelContextProtocol/Protocol/Transport/HttpListenerServerProvider.cs b/src/ModelContextProtocol/Protocol/Transport/HttpListenerServerProvider.cs index 7b923c59..dbeb4e6c 100644 --- a/src/ModelContextProtocol/Protocol/Transport/HttpListenerServerProvider.cs +++ b/src/ModelContextProtocol/Protocol/Transport/HttpListenerServerProvider.cs @@ -1,22 +1,30 @@ using System.Net; -using ModelContextProtocol.Server; + +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously namespace ModelContextProtocol.Protocol.Transport; /// /// HTTP server provider using HttpListener. /// -internal class HttpListenerServerProvider : IDisposable +internal sealed class HttpListenerServerProvider : IAsyncDisposable { private static readonly byte[] s_accepted = "Accepted"u8.ToArray(); private const string SseEndpoint = "/sse"; private const string MessageEndpoint = "/message"; - private readonly int _port; - private HttpListener? _listener; - private CancellationTokenSource? _cts; - private bool _isRunning; + private readonly HttpListener _listener; + private readonly CancellationTokenSource _shutdownTokenSource = new(); + private Task _listeningTask = Task.CompletedTask; + + private readonly TaskCompletionSource _completed = new(); + private int _outstandingOperations; + + private int _state; + private const int StateNotStarted = 0; + private const int StateRunning = 1; + private const int StateStopped = 2; /// /// Creates a new instance of the HTTP server provider. @@ -24,84 +32,100 @@ internal class HttpListenerServerProvider : IDisposable /// The port to listen on public HttpListenerServerProvider(int port) { - _port = port; + if (port < 0) + { + throw new ArgumentOutOfRangeException(nameof(port)); + } + + _listener = new(); + _listener.Prefixes.Add($"http://localhost:{port}/"); } public required Func OnSseConnectionAsync { get; set; } public required Func> OnMessageAsync { get; set; } /// - public Task StartAsync(CancellationToken cancellationToken = default) + public async Task StartAsync(CancellationToken cancellationToken = default) { - if (_isRunning) - return Task.CompletedTask; + if (Interlocked.CompareExchange(ref _state, StateRunning, StateNotStarted) != StateNotStarted) + { + throw new ObjectDisposedException("Server may not be started twice."); + } - _cts = new CancellationTokenSource(); - _listener = new HttpListener(); - _listener.Prefixes.Add($"http://localhost:{_port}/"); + // Start listening for connections _listener.Start(); - _isRunning = true; - // Start listening for connections - _ = Task.Run(() => ListenForConnectionsAsync(_cts.Token), cancellationToken).ConfigureAwait(false); - return Task.CompletedTask; + OperationAdded(); // for the listening task + _listeningTask = Task.Run(async () => + { + try + { + using var cts = CancellationTokenSource.CreateLinkedTokenSource(_shutdownTokenSource.Token, cancellationToken); + cts.Token.Register(_listener.Stop); + while (!cts.IsCancellationRequested) + { + try + { + var context = await _listener.GetContextAsync().ConfigureAwait(false); + + // Process the request in a separate task + OperationAdded(); // for the processing task; decremented in ProcessRequestAsync + _ = Task.Run(() => ProcessRequestAsync(context, cts.Token), CancellationToken.None); + } + catch (Exception) + { + if (cts.IsCancellationRequested) + { + // Shutdown requested, exit gracefully + break; + } + } + } + } + finally + { + OperationCompleted(); // for the listening task + } + }, CancellationToken.None); } /// - public Task StopAsync(CancellationToken cancellationToken = default) + public async ValueTask DisposeAsync() { - if (!_isRunning) - return Task.CompletedTask; - - _cts?.Cancel(); - _listener?.Stop(); + if (Interlocked.CompareExchange(ref _state, StateStopped, StateRunning) != StateRunning) + { + return; + } - _isRunning = false; - return Task.CompletedTask; + await _shutdownTokenSource.CancelAsync().ConfigureAwait(false); + _listener.Stop(); + await _listeningTask.ConfigureAwait(false); } - private async Task ListenForConnectionsAsync(CancellationToken cancellationToken) - { - if (_listener == null) - { - throw new McpServerException("Listener not initialized"); - } + /// Gets a that completes when the server has finished its work. + public Task Completed => _completed.Task; - while (!cancellationToken.IsCancellationRequested) - { - try - { - var context = await _listener.GetContextAsync().ConfigureAwait(false); + private void OperationAdded() => Interlocked.Increment(ref _outstandingOperations); - // Process the request in a separate task - _ = Task.Run(() => ProcessRequestAsync(context, cancellationToken), cancellationToken); - } - catch (Exception) when (cancellationToken.IsCancellationRequested) - { - // Shutdown requested, exit gracefully - break; - } - catch (Exception) - { - // Log error but continue listening - if (!cancellationToken.IsCancellationRequested) - { - // Continue listening if not shutting down - continue; - } - } + private void OperationCompleted() + { + if (Interlocked.Decrement(ref _outstandingOperations) == 0) + { + // All operations completed + _completed.TrySetResult(true); } } private async Task ProcessRequestAsync(HttpListenerContext context, CancellationToken cancellationToken) { + var request = context.Request; + var response = context.Response; try { - var request = context.Request; - var response = context.Response; - - if (request == null) - throw new McpServerException("Request is null"); + if (request is null || response is null) + { + return; + } // Handle SSE connection if (request.HttpMethod == "GET" && request.Url?.LocalPath == SseEndpoint) @@ -124,11 +148,15 @@ private async Task ProcessRequestAsync(HttpListenerContext context, Cancellation { try { - context.Response.StatusCode = 500; - context.Response.Close(); + response.StatusCode = 500; + response.Close(); } catch { /* Ignore errors during error handling */ } } + finally + { + OperationCompleted(); + } } private async Task HandleSseConnectionAsync(HttpListenerContext context, CancellationToken cancellationToken) @@ -145,13 +173,8 @@ private async Task HandleSseConnectionAsync(HttpListenerContext context, Cancell { await OnSseConnectionAsync(response.OutputStream, cancellationToken).ConfigureAwait(false); } - catch (TaskCanceledException) - { - // Normal shutdown - } catch (Exception) { - // Client disconnected or other error } finally { @@ -186,20 +209,4 @@ private async Task HandleMessageAsync(HttpListenerContext context, CancellationT response.Close(); } - - - /// - public void Dispose() - { - Dispose(true); - GC.SuppressFinalize(this); - } - - /// - protected virtual void Dispose(bool disposing) - { - StopAsync().GetAwaiter().GetResult(); - _cts?.Dispose(); - _listener?.Close(); - } } diff --git a/src/ModelContextProtocol/Protocol/Transport/HttpListenerSseServerTransport.cs b/src/ModelContextProtocol/Protocol/Transport/HttpListenerSseServerTransport.cs index c10b751d..44280608 100644 --- a/src/ModelContextProtocol/Protocol/Transport/HttpListenerSseServerTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/HttpListenerSseServerTransport.cs @@ -7,6 +7,8 @@ using ModelContextProtocol.Server; using ModelContextProtocol.Utils; +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + namespace ModelContextProtocol.Protocol.Transport; /// @@ -28,7 +30,7 @@ public sealed class HttpListenerSseServerTransport : TransportBase, IServerTrans /// The port to listen on. /// A logger factory for creating loggers. public HttpListenerSseServerTransport(McpServerOptions serverOptions, int port, ILoggerFactory loggerFactory) - : this(GetServerName(serverOptions), port, loggerFactory) + : this(serverOptions?.ServerInfo?.Name!, port, loggerFactory) { } @@ -41,6 +43,8 @@ public HttpListenerSseServerTransport(McpServerOptions serverOptions, int port, public HttpListenerSseServerTransport(string serverName, int port, ILoggerFactory loggerFactory) : base(loggerFactory) { + Throw.IfNull(serverName); + _serverName = serverName; _logger = loggerFactory.CreateLogger(); _httpServerProvider = new HttpListenerServerProvider(port) @@ -51,10 +55,11 @@ public HttpListenerSseServerTransport(string serverName, int port, ILoggerFactor } /// - public Task StartListeningAsync(CancellationToken cancellationToken = default) - { - return _httpServerProvider.StartAsync(cancellationToken); - } + public Task StartListeningAsync(CancellationToken cancellationToken = default) => + _httpServerProvider.StartAsync(cancellationToken); + + /// + public Task Completion => _httpServerProvider.Completed; /// public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) @@ -92,20 +97,13 @@ public override async Task SendMessageAsync(IJsonRpcMessage message, Cancellatio /// public override async ValueTask DisposeAsync() - { - await CleanupAsync(CancellationToken.None).ConfigureAwait(false); - GC.SuppressFinalize(this); - } - - private Task CleanupAsync(CancellationToken cancellationToken) { _logger.TransportCleaningUp(EndpointName); - _httpServerProvider.Dispose(); SetConnected(false); + await _httpServerProvider.DisposeAsync().ConfigureAwait(false); _logger.TransportCleanedUp(EndpointName); - return Task.CompletedTask; } private async Task OnSseConnectionAsync(Stream responseStream, CancellationToken cancellationToken) @@ -167,13 +165,4 @@ private async Task OnMessageAsync(Stream requestStream, CancellationToken return false; } } - - /// Validates the and extracts from it the server name to use. - private static string GetServerName(McpServerOptions serverOptions) - { - Throw.IfNull(serverOptions); - Throw.IfNull(serverOptions.ServerInfo); - - return serverOptions.ServerInfo.Name; - } } diff --git a/src/ModelContextProtocol/Protocol/Transport/IServerTransport.cs b/src/ModelContextProtocol/Protocol/Transport/IServerTransport.cs index 742c8942..e204202e 100644 --- a/src/ModelContextProtocol/Protocol/Transport/IServerTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/IServerTransport.cs @@ -10,4 +10,7 @@ public interface IServerTransport : ITransport /// /// Token to cancel the operation. Task StartListeningAsync(CancellationToken cancellationToken = default); + + /// Gets a that will complete when the server transport has completed all work. + Task Completion { get; } } diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs index 5d72d0c0..654b3a6b 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs @@ -21,7 +21,7 @@ public sealed class SseClientTransport : TransportBase, IClientTransport private readonly SseClientTransportOptions _options; private readonly Uri _sseEndpoint; private Uri? _messageEndpoint; - private CancellationTokenSource? _connectionCts; + private readonly CancellationTokenSource _connectionCts; private Task? _receiveTask; private readonly ILogger _logger; private readonly McpServerConfig _serverConfig; @@ -82,7 +82,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default) } // Start message receiving loop - _receiveTask = ReceiveMessagesAsync(_connectionCts!.Token); + _receiveTask = ReceiveMessagesAsync(_connectionCts.Token); _logger.TransportReadingMessages(EndpointName); @@ -165,19 +165,25 @@ public override async Task SendMessageAsync( } } - /// - public async Task CloseAsync() + private async Task CloseAsync() { - if (_connectionCts != null) + try { - await _connectionCts.CancelAsync().ConfigureAwait(false); - _connectionCts.Dispose(); - _connectionCts = null; - } - if (_receiveTask != null) - await _receiveTask.ConfigureAwait(false); + if (!_connectionCts.IsCancellationRequested) + { + await _connectionCts.CancelAsync().ConfigureAwait(false); + _connectionCts.Dispose(); + } - SetConnected(false); + if (_receiveTask != null) + { + await _receiveTask.ConfigureAwait(false); + } + } + finally + { + SetConnected(false); + } } /// @@ -185,19 +191,17 @@ public override async ValueTask DisposeAsync() { try { - await CloseAsync().ConfigureAwait(false); + if (_ownsHttpClient) + { + _httpClient?.Dispose(); + } + + await CloseAsync(); } catch (Exception) { // Ignore exceptions on close } - - if (_ownsHttpClient) - _httpClient?.Dispose(); - - _connectionCts?.Dispose(); - - GC.SuppressFinalize(this); } internal Uri? MessageEndpoint => _messageEndpoint; diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs index 49e43301..00f4dfb4 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs @@ -183,7 +183,6 @@ public override async Task SendMessageAsync(IJsonRpcMessage message, Cancellatio public override async ValueTask DisposeAsync() { await CleanupAsync(CancellationToken.None).ConfigureAwait(false); - GC.SuppressFinalize(this); } private async Task ReadMessagesAsync(CancellationToken cancellationToken) diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioServerTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StdioServerTransport.cs index e1f517e5..05cd5a28 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StdioServerTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StdioServerTransport.cs @@ -25,9 +25,12 @@ public sealed class StdioServerTransport : TransportBase, IServerTransport private readonly TextReader _stdInReader; private readonly Stream _stdOutStream; - private SemaphoreSlim _sendLock = new(1, 1); - private Task? _readTask; - private CancellationTokenSource? _shutdownCts; + private readonly SemaphoreSlim _sendLock = new(1, 1); + private readonly CancellationTokenSource _shutdownCts = new(); + + private Task _readLoopCompleted = Task.CompletedTask; + private TaskCompletionSource _serverCompleted = new(); + private int _disposed = 0; private string EndpointName => $"Server (stdio) ({_serverName})"; @@ -44,8 +47,8 @@ public sealed class StdioServerTransport : TransportBase, IServerTransport /// to , as that will interfere with the transport's output. /// /// - public StdioServerTransport(McpServerOptions serverOptions, ILoggerFactory? loggerFactory = null) - : this(GetServerName(serverOptions), loggerFactory) + public StdioServerTransport(IOptions serverOptions, ILoggerFactory? loggerFactory = null) + : this(serverOptions?.Value!, loggerFactory: loggerFactory) { } @@ -62,9 +65,17 @@ public StdioServerTransport(McpServerOptions serverOptions, ILoggerFactory? logg /// to , as that will interfere with the transport's output. /// /// - public StdioServerTransport(IOptions serverOptions, ILoggerFactory? loggerFactory = null) - : this(GetServerName(serverOptions.Value), loggerFactory) + public StdioServerTransport(McpServerOptions serverOptions, ILoggerFactory? loggerFactory = null) + : this(GetServerName(serverOptions), loggerFactory: loggerFactory) + { + } + + private static string GetServerName(McpServerOptions serverOptions) { + Throw.IfNull(serverOptions); + Throw.IfNull(serverOptions.ServerInfo); + Throw.IfNull(serverOptions.ServerInfo.Name); + return serverOptions.ServerInfo.Name; } /// @@ -80,25 +91,17 @@ public StdioServerTransport(IOptions serverOptions, ILoggerFac /// to , as that will interfere with the transport's output. /// /// - public StdioServerTransport(string serverName, ILoggerFactory? loggerFactory = null) - : base(loggerFactory) + public StdioServerTransport(string serverName, ILoggerFactory? loggerFactory) + : this(serverName, stdinStream: null, stdoutStream: null, loggerFactory) { - Throw.IfNull(serverName); - - _serverName = serverName; - _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; - - // Get raw console streams and wrap them with UTF-8 encoding - _stdInReader = new StreamReader(Console.OpenStandardInput(), Encoding.UTF8); - _stdOutStream = new BufferedStream(Console.OpenStandardOutput()); } /// /// Initializes a new instance of the class with explicit input/output streams. /// /// The name of the server. - /// The input TextReader to use. - /// The output TextWriter to use. + /// The input to use as standard input. If , will be used. + /// The output to use as standard output. If , will be used. /// Optional logger factory used for logging employed by the transport. /// is . /// @@ -106,35 +109,34 @@ public StdioServerTransport(string serverName, ILoggerFactory? loggerFactory = n /// This constructor is useful for testing scenarios where you want to redirect input/output. /// /// - public StdioServerTransport(string serverName, Stream stdinStream, Stream stdoutStream, ILoggerFactory? loggerFactory = null) + public StdioServerTransport(string serverName, Stream? stdinStream = null, Stream? stdoutStream = null, ILoggerFactory? loggerFactory = null) : base(loggerFactory) { Throw.IfNull(serverName); - Throw.IfNull(stdinStream); - Throw.IfNull(stdoutStream); _serverName = serverName; _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; - - _stdInReader = new StreamReader(stdinStream, Encoding.UTF8); - _stdOutStream = stdoutStream; + + _stdInReader = new StreamReader(stdinStream ?? Console.OpenStandardInput(), Encoding.UTF8); + _stdOutStream = stdoutStream ?? new BufferedStream(Console.OpenStandardOutput()); } /// public Task StartListeningAsync(CancellationToken cancellationToken = default) { _logger.LogDebug("Starting StdioServerTransport listener for {EndpointName}", EndpointName); - - _shutdownCts = new CancellationTokenSource(); - - _readTask = Task.Run(async () => await ReadMessagesAsync(_shutdownCts.Token).ConfigureAwait(false), CancellationToken.None); SetConnected(true); + _readLoopCompleted = Task.Run(ReadMessagesAsync, _shutdownCts.Token); + _logger.LogDebug("StdioServerTransport now connected for {EndpointName}", EndpointName); return Task.CompletedTask; } + /// Gets a that completes when the server transport has finished its work. + public Task Completion => _serverCompleted.Task; + /// public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) { @@ -169,33 +171,26 @@ public override async Task SendMessageAsync(IJsonRpcMessage message, Cancellatio } } - /// - public override async ValueTask DisposeAsync() - { - await CleanupAsync(CancellationToken.None).ConfigureAwait(false); - GC.SuppressFinalize(this); - } - - private async Task ReadMessagesAsync(CancellationToken cancellationToken) + private async Task ReadMessagesAsync() { + CancellationToken shutdownToken = _shutdownCts.Token; try { _logger.TransportEnteringReadMessagesLoop(EndpointName); - while (!cancellationToken.IsCancellationRequested) + while (!shutdownToken.IsCancellationRequested) { _logger.TransportWaitingForMessage(EndpointName); - var reader = _stdInReader; - var line = await reader.ReadLineAsync(cancellationToken).ConfigureAwait(false); - if (line == null) - { - _logger.TransportEndOfStream(EndpointName); - break; - } - + var line = await _stdInReader.ReadLineAsync(shutdownToken).ConfigureAwait(false); if (string.IsNullOrWhiteSpace(line)) { + if (line is null) + { + _logger.TransportEndOfStream(EndpointName); + break; + } + continue; } @@ -204,8 +199,7 @@ private async Task ReadMessagesAsync(CancellationToken cancellationToken) try { - var message = JsonSerializer.Deserialize(line, _jsonOptions.GetTypeInfo()); - if (message != null) + if (JsonSerializer.Deserialize(line, _jsonOptions.GetTypeInfo()) is { } message) { string messageId = "(no id)"; if (message is IJsonRpcMessageWithId messageWithId) @@ -213,7 +207,8 @@ private async Task ReadMessagesAsync(CancellationToken cancellationToken) messageId = messageWithId.Id.ToString(); } _logger.TransportReceivedMessageParsed(EndpointName, messageId); - await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false); + + await WriteMessageAsync(message, shutdownToken).ConfigureAwait(false); _logger.TransportMessageWritten(EndpointName, messageId); } else @@ -227,76 +222,70 @@ private async Task ReadMessagesAsync(CancellationToken cancellationToken) // Continue reading even if we fail to parse a message } } + _logger.TransportExitingReadMessagesLoop(EndpointName); } - catch (OperationCanceledException) + catch (OperationCanceledException oce) { _logger.TransportReadMessagesCancelled(EndpointName); - // Normal shutdown + _serverCompleted.TrySetCanceled(oce.CancellationToken); } catch (Exception ex) { _logger.TransportReadMessagesFailed(EndpointName, ex); + _serverCompleted.TrySetException(ex); } finally { - await CleanupAsync(cancellationToken).ConfigureAwait(false); + _serverCompleted.TrySetResult(true); } } - private async Task CleanupAsync(CancellationToken cancellationToken) + /// + public override async ValueTask DisposeAsync() { - _logger.TransportCleaningUp(EndpointName); - - if (_shutdownCts is { } shutdownCts) + if (Interlocked.Exchange(ref _disposed, 1) != 0) { - await shutdownCts.CancelAsync().ConfigureAwait(false); - shutdownCts.Dispose(); - - _shutdownCts = null; + return; } - if (_readTask is { } readTask) + try { + _logger.TransportCleaningUp(EndpointName); + + // Signal to the stdin reading loop to stop. + await _shutdownCts.CancelAsync().ConfigureAwait(false); + _shutdownCts.Dispose(); + + // Dispose of stdin/out. Cancellation may not be able to wake up operations + // synchronously blocked in a syscall; we need to forcefully close the handle / file descriptor. + _stdInReader?.Dispose(); + _stdOutStream?.Dispose(); + + // Make sure the work has quiesced. try { _logger.TransportWaitingForReadTask(EndpointName); - await readTask.WaitAsync(TimeSpan.FromSeconds(5), cancellationToken).ConfigureAwait(false); + await _readLoopCompleted.ConfigureAwait(false); + _logger.TransportReadTaskCleanedUp(EndpointName); } catch (TimeoutException) { _logger.TransportCleanupReadTaskTimeout(EndpointName); - // Continue with cleanup } catch (OperationCanceledException) { _logger.TransportCleanupReadTaskCancelled(EndpointName); - // Ignore cancellation } catch (Exception ex) { _logger.TransportCleanupReadTaskFailed(EndpointName, ex); } - finally - { - _logger.TransportReadTaskCleanedUp(EndpointName); - _readTask = null; - } } - - _stdInReader?.Dispose(); - _stdOutStream?.Dispose(); - - SetConnected(false); - _logger.TransportCleanedUp(EndpointName); - } - - /// Validates the and extracts from it the server name to use. - private static string GetServerName(McpServerOptions serverOptions) - { - Throw.IfNull(serverOptions); - Throw.IfNull(serverOptions.ServerInfo); - - return serverOptions.ServerInfo.Name; + finally + { + SetConnected(false); + _logger.TransportCleanedUp(EndpointName); + } } } \ No newline at end of file diff --git a/src/ModelContextProtocol/Server/IMcpServer.cs b/src/ModelContextProtocol/Server/IMcpServer.cs index 08d4d990..e311b04a 100644 --- a/src/ModelContextProtocol/Server/IMcpServer.cs +++ b/src/ModelContextProtocol/Server/IMcpServer.cs @@ -8,11 +8,6 @@ namespace ModelContextProtocol.Server; /// public interface IMcpServer : IAsyncDisposable { - /// - /// Gets a value indicating whether the server has been initialized. - /// - bool IsInitialized { get; } - /// /// Gets the capabilities supported by the client. /// @@ -48,9 +43,9 @@ public interface IMcpServer : IAsyncDisposable void AddNotificationHandler(string method, Func handler); /// - /// Starts the server and begins listening for client requests. + /// Runs the server, listening for and handling client requests. /// - Task StartAsync(CancellationToken cancellationToken = default); + Task RunAsync(Func? onInitialized = null, CancellationToken cancellationToken = default); /// /// Sends a generic JSON-RPC request to the client. diff --git a/src/ModelContextProtocol/Server/McpServer.cs b/src/ModelContextProtocol/Server/McpServer.cs index 7f425b1f..5a8f50f6 100644 --- a/src/ModelContextProtocol/Server/McpServer.cs +++ b/src/ModelContextProtocol/Server/McpServer.cs @@ -1,5 +1,4 @@ using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Abstractions; using ModelContextProtocol.Logging; using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; @@ -13,11 +12,10 @@ namespace ModelContextProtocol.Server; /// internal sealed class McpServer : McpJsonRpcEndpoint, IMcpServer { - private readonly IServerTransport? _serverTransport; private readonly string _serverDescription; private readonly EventHandler? _toolsChangedDelegate; - private volatile bool _isInitializing; + private volatile bool _ran; /// /// Creates a new instance of . @@ -33,7 +31,6 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? { Throw.IfNull(options); - _serverTransport = transport as IServerTransport; ServerOptions = options; Services = serviceProvider; _serverDescription = $"{options.ServerInfo.Name} {options.ServerInfo.Version}"; @@ -52,7 +49,6 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? tools.Changed += _toolsChangedDelegate; } - IsInitialized = true; return Task.CompletedTask; }); @@ -68,6 +64,7 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? public ServerCapabilities? ServerCapabilities { get; set; } + /// public ClientCapabilities? ClientCapabilities { get; set; } /// @@ -84,35 +81,40 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? $"Server ({_serverDescription}), Client ({ClientInfo?.Name} {ClientInfo?.Version})"; /// - public async Task StartAsync(CancellationToken cancellationToken = default) + public async Task RunAsync(Func? onInitialized = null, CancellationToken cancellationToken = default) { - if (_isInitializing) - { - _logger.ServerAlreadyInitializing(EndpointName); - throw new InvalidOperationException("Server is already initializing"); - } - _isInitializing = true; - - if (IsInitialized) + if (_ran) { _logger.ServerAlreadyInitializing(EndpointName); - return; + throw new InvalidOperationException("Server is already running."); } + _ran = true; try { CancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - if (_serverTransport is not null) + IServerTransport? serverTransport = Transport as IServerTransport; + if (serverTransport is not null) { // Start listening for messages - await _serverTransport.StartListeningAsync(CancellationTokenSource.Token).ConfigureAwait(false); + await serverTransport.StartListeningAsync(CancellationTokenSource.Token).ConfigureAwait(false); } // Start processing messages MessageProcessingTask = ProcessMessagesAsync(CancellationTokenSource.Token); - // Unlike McpClient, we're not done initializing until we've received a message from the client, so we don't set IsInitialized here + // Invoke any post-initialization logic. + if (onInitialized is not null) + { + await onInitialized(CancellationTokenSource.Token).ConfigureAwait(false); + } + + // Wait for transport completion. + if (serverTransport is not null) + { + await serverTransport.Completion; + } } catch (Exception e) { diff --git a/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs b/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs index 90824c58..7052953d 100644 --- a/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs +++ b/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs @@ -49,11 +49,6 @@ protected McpJsonRpcEndpoint(ITransport transport, ILoggerFactory? loggerFactory _logger = loggerFactory?.CreateLogger(GetType()) ?? NullLogger.Instance; } - /// - /// Gets whether the endpoint is initialized and ready to process messages. - /// - public bool IsInitialized { get; set; } - /// /// Gets the name of the endpoint for logging and debug purposes. /// @@ -320,7 +315,6 @@ public void AddNotificationHandler(string method, Func @@ -369,7 +363,9 @@ protected virtual async Task CleanupAsync() _logger.CleaningUpEndpoint(EndpointName); if (CancellationTokenSource != null) + { await CancellationTokenSource.CancelAsync().ConfigureAwait(false); + } if (MessageProcessingTask != null) { @@ -393,8 +389,6 @@ protected virtual async Task CleanupAsync() await _transport.DisposeAsync().ConfigureAwait(false); CancellationTokenSource?.Dispose(); - IsInitialized = false; - _logger.EndpointCleanedUp(EndpointName); } } diff --git a/tests/ModelContextProtocol.TestServer/Program.cs b/tests/ModelContextProtocol.TestServer/Program.cs index e0558447..923a2b9d 100644 --- a/tests/ModelContextProtocol.TestServer/Program.cs +++ b/tests/ModelContextProtocol.TestServer/Program.cs @@ -4,6 +4,7 @@ using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server; using Serilog; +using System.Collections.Concurrent; using System.Text; using System.Text.Json; @@ -49,59 +50,50 @@ private static async Task Main(string[] args) using var loggerFactory = CreateLoggerFactory(); await using IMcpServer server = McpServerFactory.Create(new StdioServerTransport("TestServer", loggerFactory), options, loggerFactory); - Log.Logger.Information("Server initialized."); - - await server.StartAsync(); - - Log.Logger.Information("Server started."); - - // everything server sends random log level messages every 15 seconds - int loggingSeconds = 0; - Random random = Random.Shared; - var loggingLevels = Enum.GetValues().ToList(); - - // Run until process is stopped by the client (parent process) - while (true) + Log.Logger.Information("Server running..."); + await server.RunAsync(async cancellationToken => { - await Task.Delay(5000); - if (_minimumLoggingLevel is not null) - { - loggingSeconds += 5; + var loggingLevels = Enum.GetValues(); - // Send random log messages every 15 seconds - if (loggingSeconds >= 15) + // Run until process is stopped by the client (parent process) + while (true) + { + await Task.Delay(1000, cancellationToken); + try { - var logLevelIndex = random.Next(loggingLevels.Count); - var logLevel = loggingLevels[logLevelIndex]; - await server.SendMessageAsync(new JsonRpcNotification() + // Send random log messages every few seconds + if (_minimumLoggingLevel is not null) { - Method = NotificationMethods.LoggingMessageNotification, - Params = new LoggingMessageNotificationParams + var logLevel = loggingLevels[Random.Shared.Next(loggingLevels.Length)]; + await server.SendMessageAsync(new JsonRpcNotification() { - Level = logLevel, - Data = JsonSerializer.Deserialize("\"Random log message\"") - } - }); - } - } + Method = NotificationMethods.LoggingMessageNotification, + Params = new LoggingMessageNotificationParams + { + Level = logLevel, + Data = JsonSerializer.Deserialize("\"Random log message\"") + } + }, cancellationToken); + } - // Snapshot the subscribed resources, rather than locking while sending notifications - List resources; - lock (_subscribedResourcesLock) - { - resources = _subscribedResources.ToList(); - } - - foreach (var resource in resources) - { - ResourceUpdatedNotificationParams notificationParams = new() { Uri = resource }; - await server.SendMessageAsync(new JsonRpcNotification() + // Snapshot the subscribed resources, rather than locking while sending notifications + foreach (var resource in _subscribedResources) + { + ResourceUpdatedNotificationParams notificationParams = new() { Uri = resource.Key }; + await server.SendMessageAsync(new JsonRpcNotification() + { + Method = NotificationMethods.ResourceUpdatedNotification, + Params = notificationParams + }, cancellationToken); + } + } + catch (Exception ex) { - Method = NotificationMethods.ResourceUpdatedNotification, - Params = notificationParams - }); + Log.Logger.Error(ex, "Error sending log message"); + break; + } } - } + }); } private static ToolsCapability ConfigureTools() @@ -312,8 +304,7 @@ private static LoggingCapability ConfigureLogging() }; } - private static readonly HashSet _subscribedResources = new(); - private static readonly object _subscribedResourcesLock = new(); + private static readonly ConcurrentDictionary _subscribedResources = new(); private static ResourcesCapability ConfigureResources() { @@ -451,10 +442,7 @@ private static ResourcesCapability ConfigureResources() throw new McpServerException("Invalid resource URI"); } - lock (_subscribedResourcesLock) - { - _subscribedResources.Add(request.Params.Uri); - } + _subscribedResources.TryAdd(request.Params.Uri, true); return Task.FromResult(new EmptyResult()); }, @@ -471,10 +459,7 @@ private static ResourcesCapability ConfigureResources() throw new McpServerException("Invalid resource URI"); } - lock (_subscribedResourcesLock) - { - _subscribedResources.Remove(request.Params.Uri); - } + _subscribedResources.Remove(request.Params.Uri, out _); return Task.FromResult(new EmptyResult()); }, diff --git a/tests/ModelContextProtocol.TestSseServer/Program.cs b/tests/ModelContextProtocol.TestSseServer/Program.cs index 598b4adf..98d5dfaf 100644 --- a/tests/ModelContextProtocol.TestSseServer/Program.cs +++ b/tests/ModelContextProtocol.TestSseServer/Program.cs @@ -44,8 +44,6 @@ public static async Task MainAsync(string[] args, ILoggerFactory? loggerFactory ServerInstructions = "This is a test server with only stub functionality" }; - IMcpServer? server = null; - Console.WriteLine("Registering handlers."); #region Helped method @@ -186,7 +184,7 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st { throw new McpServerException("Missing required arguments 'prompt' and 'maxTokens'"); } - var sampleResult = await server!.RequestSamplingAsync(CreateRequestSamplingParams(prompt?.ToString() ?? "", "sampleLLM", Convert.ToInt32(maxTokens?.ToString())), + var sampleResult = await request.Server.RequestSamplingAsync(CreateRequestSamplingParams(prompt?.ToString() ?? "", "sampleLLM", Convert.ToInt32(maxTokens?.ToString())), cancellationToken); return new CallToolResponse() @@ -384,23 +382,10 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st }; loggerFactory ??= CreateLoggerFactory(); - server = McpServerFactory.Create(new HttpListenerSseServerTransport("TestServer", 3001, loggerFactory), options, loggerFactory); - - Console.WriteLine("Server initialized."); - - await server.StartAsync(cancellationToken); - - Console.WriteLine("Server started."); + await using IMcpServer server = McpServerFactory.Create(new HttpListenerSseServerTransport("TestServer", 3001, loggerFactory), options, loggerFactory); - try - { - // Run until process is stopped by the client (parent process) or test - await Task.Delay(Timeout.Infinite, cancellationToken); - } - finally - { - await server.DisposeAsync(); - } + Console.WriteLine("Server running..."); + await server.RunAsync(cancellationToken: cancellationToken); } const string MCP_TINY_IMAGE = diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs index 03d0e993..ec4808ff 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs @@ -13,6 +13,8 @@ public class McpClientExtensionsTests private Pipe _clientToServerPipe = new(); private Pipe _serverToClientPipe = new(); private readonly IMcpServer _server; + private readonly CancellationTokenSource _cts; + private readonly Task _serverTask; public McpClientExtensionsTests() { @@ -25,19 +27,22 @@ public McpClientExtensionsTests() sc.AddSingleton(McpServerTool.Create((int i) => $"{name} Result {i}", name)); } _server = sc.BuildServiceProvider().GetRequiredService(); + + _cts = CancellationTokenSource.CreateLinkedTokenSource(TestContext.Current.CancellationToken); + _serverTask = _server.RunAsync(cancellationToken: _cts.Token); } - public ValueTask DisposeAsync() + public async ValueTask DisposeAsync() { + _cts.Cancel(); _clientToServerPipe.Writer.Complete(); _serverToClientPipe.Writer.Complete(); - return _server.DisposeAsync(); + await _serverTask; + await _server.DisposeAsync(); } private async Task CreateMcpClientForServer() { - await _server.StartAsync(TestContext.Current.CancellationToken); - var serverStdinWriter = new StreamWriter(_clientToServerPipe.Writer.AsStream()); var serverStdoutReader = new StreamReader(_serverToClientPipe.Reader.AsStream()); diff --git a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs index 69321675..9ecfdd57 100644 --- a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs @@ -252,21 +252,17 @@ public async Task SubscribeResource_Stdio() var clientId = "test_server"; // act - int counter = 0; + TaskCompletionSource tcs = new(); await using var client = await _fixture.CreateClientAsync(clientId); client.AddNotificationHandler(NotificationMethods.ResourceUpdatedNotification, (notification) => { var notificationParams = JsonSerializer.Deserialize(notification.Params!.ToString() ?? string.Empty); - ++counter; + tcs.TrySetResult(true); return Task.CompletedTask; }); await client.SubscribeToResourceAsync("test://static/resource/1", CancellationToken.None); - // notifications happen every 5 seconds, so we wait for 10 seconds to ensure we get at least one notification - await Task.Delay(10000, TestContext.Current.CancellationToken); - - // assert - Assert.True(counter > 0); + await tcs.Task; } // Not supported by "everything" server version on npx @@ -277,32 +273,26 @@ public async Task UnsubscribeResource_Stdio() var clientId = "test_server"; // act - int counter = 0; + TaskCompletionSource receivedNotification = new(); await using var client = await _fixture.CreateClientAsync(clientId); client.AddNotificationHandler(NotificationMethods.ResourceUpdatedNotification, (notification) => { var notificationParams = JsonSerializer.Deserialize(notification.Params!.ToString() ?? string.Empty); - ++counter; + receivedNotification.TrySetResult(true); return Task.CompletedTask; }); await client.SubscribeToResourceAsync("test://static/resource/1", CancellationToken.None); - // notifications happen every 5 seconds, so we wait for 10 seconds to ensure we get at least one notification - await Task.Delay(10000, TestContext.Current.CancellationToken); - - // reset counter - int counterAfterSubscribe = counter; + // wait until we received a notification + await receivedNotification.Task; // unsubscribe await client.UnsubscribeFromResourceAsync("test://static/resource/1", CancellationToken.None); - counter = 0; + receivedNotification = new(); - // notifications happen every 5 seconds, so we wait for 10 seconds to ensure we would've gotten at least one notification - await Task.Delay(10000, TestContext.Current.CancellationToken); - - // assert - Assert.True(counterAfterSubscribe > 0); - Assert.Equal(0, counter); + // wait a bit to validate we don't receive another. this is best effort only; + // false negatives are possible. + await Assert.ThrowsAsync(() => receivedNotification.Task.WaitAsync(TimeSpan.FromSeconds(1), TestContext.Current.CancellationToken)); } [Theory] @@ -556,7 +546,7 @@ public async Task SetLoggingLevel_ReceivesLoggingMessages(string clientId) Encoder = JavaScriptEncoder.UnsafeRelaxedJsonEscaping, }; - int logCounter = 0; + TaskCompletionSource receivedNotification = new(); await using var client = await _fixture.CreateClientAsync(clientId); client.AddNotificationHandler(NotificationMethods.LoggingMessageNotification, (notification) => { @@ -564,16 +554,15 @@ public async Task SetLoggingLevel_ReceivesLoggingMessages(string clientId) jsonSerializerOptions); if (loggingMessageNotificationParameters is not null) { - ++logCounter; + receivedNotification.TrySetResult(true); } return Task.CompletedTask; }); // act await client.SetLoggingLevel(LoggingLevel.Debug, CancellationToken.None); - await Task.Delay(16000, TestContext.Current.CancellationToken); // assert - Assert.True(logCounter > 0); + await receivedNotification.Task; } } diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index 3ae31301..18c08576 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -24,6 +24,8 @@ public class McpServerBuilderExtensionsToolsTests : LoggedTest, IAsyncDisposable private readonly ServiceProvider _serviceProvider; private readonly IMcpServerBuilder _builder; private readonly IMcpServer _server; + private readonly CancellationTokenSource _cts; + private readonly Task _serverTask; public McpServerBuilderExtensionsToolsTests(ITestOutputHelper testOutputHelper) : base(testOutputHelper) @@ -35,19 +37,24 @@ public McpServerBuilderExtensionsToolsTests(ITestOutputHelper testOutputHelper) _builder = sc.AddMcpServer().WithTools(); _serviceProvider = sc.BuildServiceProvider(); _server = _serviceProvider.GetRequiredService(); + + _cts = CancellationTokenSource.CreateLinkedTokenSource(TestContext.Current.CancellationToken); + _serverTask = _server.RunAsync(cancellationToken: _cts.Token); } - public ValueTask DisposeAsync() + public async ValueTask DisposeAsync() { + await _cts.CancelAsync(); + _cts.Dispose(); + _clientToServerPipe.Writer.Complete(); _serverToClientPipe.Writer.Complete(); - return _serviceProvider.DisposeAsync(); + + await _serviceProvider.DisposeAsync(); } private async Task CreateMcpClientForServer() { - await _server.StartAsync(TestContext.Current.CancellationToken); - var serverStdinWriter = new StreamWriter(_clientToServerPipe.Writer.AsStream()); var serverStdoutReader = new StreamReader(_serverToClientPipe.Reader.AsStream()); @@ -105,15 +112,13 @@ public async Task Can_Create_Multiple_Servers_From_Options_And_List_Registered_T var stdinPipe = new Pipe(); var stdoutPipe = new Pipe(); - try - { - var transport = new StdioServerTransport($"TestServer_{i}", stdinPipe.Reader.AsStream(), stdoutPipe.Writer.AsStream()); - var server = McpServerFactory.Create(transport, options, loggerFactory, _serviceProvider); - - await server.StartAsync(TestContext.Current.CancellationToken); + var transport = new StdioServerTransport($"TestServer_{i}", stdinPipe.Reader.AsStream(), stdoutPipe.Writer.AsStream()); + var server = McpServerFactory.Create(transport, options, loggerFactory, _serviceProvider); - var serverStdinWriter = new StreamWriter(stdinPipe.Writer.AsStream()); - var serverStdoutReader = new StreamReader(stdoutPipe.Reader.AsStream()); + await server.RunAsync(async cancellationToken => + { + using var serverStdinWriter = new StreamWriter(stdinPipe.Writer.AsStream()); + using var serverStdoutReader = new StreamReader(stdoutPipe.Reader.AsStream()); var serverConfig = new McpServerConfig() { @@ -125,9 +130,9 @@ public async Task Can_Create_Multiple_Servers_From_Options_And_List_Registered_T var client = await McpClientFactory.CreateAsync( serverConfig, createTransportFunc: (_, _) => new StreamClientTransport(serverStdinWriter, serverStdoutReader), - cancellationToken: TestContext.Current.CancellationToken); + cancellationToken: cancellationToken); - var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); + var tools = await client.ListToolsAsync(cancellationToken); Assert.Equal(11, tools.Count); McpClientTool echoTool = tools.First(t => t.Name == "Echo"); @@ -141,12 +146,7 @@ public async Task Can_Create_Multiple_Servers_From_Options_And_List_Registered_T McpClientTool doubleEchoTool = tools.First(t => t.Name == "double_echo"); Assert.Equal("double_echo", doubleEchoTool.Name); Assert.Equal("Echoes the input back to the client.", doubleEchoTool.Description); - } - finally - { - stdinPipe.Writer.Complete(); - stdoutPipe.Writer.Complete(); - } + }, TestContext.Current.CancellationToken); } } diff --git a/tests/ModelContextProtocol.Tests/EverythingSseServerFixture.cs b/tests/ModelContextProtocol.Tests/EverythingSseServerFixture.cs index fdd701d7..d5fff66f 100644 --- a/tests/ModelContextProtocol.Tests/EverythingSseServerFixture.cs +++ b/tests/ModelContextProtocol.Tests/EverythingSseServerFixture.cs @@ -56,8 +56,6 @@ public async ValueTask DisposeAsync() // Log the exception but don't throw await Console.Error.WriteLineAsync($"Error stopping Docker container: {ex.Message}"); } - - GC.SuppressFinalize(this); } private static bool CheckIsDockerAvailable() diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs index 3f4dd1d8..469d0afc 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs @@ -84,51 +84,38 @@ public async Task Constructor_Does_Not_Throw_For_Null_ServiceProvider() } [Fact] - public async Task StartAsync_Should_Throw_InvalidOperationException_If_Already_Initializing() + public async Task RunAsync_Should_Throw_InvalidOperationException_If_Already_Running() { // Arrange await using var server = McpServerFactory.Create(_serverTransport.Object, _options, LoggerFactory, _serviceProvider); - var task = server.StartAsync(TestContext.Current.CancellationToken); + var task = server.RunAsync(cancellationToken: TestContext.Current.CancellationToken); // Act & Assert - await Assert.ThrowsAsync(() => server.StartAsync(TestContext.Current.CancellationToken)); + await Assert.ThrowsAsync(() => server.RunAsync(cancellationToken: TestContext.Current.CancellationToken)); await task; } [Fact] - public async Task StartAsync_Should_Do_Nothing_If_Already_Initialized() - { - // Arrange - await using var server = McpServerFactory.Create(_serverTransport.Object, _options, LoggerFactory, _serviceProvider); - SetInitialized(server, true); - - await server.StartAsync(TestContext.Current.CancellationToken); - - // Assert - _serverTransport.Verify(t => t.StartListeningAsync(It.IsAny()), Times.Never); - } - - [Fact] - public async Task StartAsync_ShouldStartListening() + public async Task RunAsync_ShouldStartListening() { // Arrange await using var server = McpServerFactory.Create(_serverTransport.Object, _options, LoggerFactory, _serviceProvider); // Act - await server.StartAsync(TestContext.Current.CancellationToken); + await server.RunAsync(cancellationToken: TestContext.Current.CancellationToken); // Assert _serverTransport.Verify(t => t.StartListeningAsync(It.IsAny()), Times.Once); } [Fact] - public async Task StartAsync_Sets_Initialized_After_Transport_Responses_Initialized_Notification() + public async Task RunAsync_Sets_Initialized_After_Transport_Responses_Initialized_Notification() { await using var transport = new TestServerTransport(); await using var server = McpServerFactory.Create(transport, _options, LoggerFactory, _serviceProvider); - await server.StartAsync(TestContext.Current.CancellationToken); + await server.RunAsync(cancellationToken: TestContext.Current.CancellationToken); // Send initialized notification await transport.SendMessageAsync(new JsonRpcNotification @@ -137,8 +124,6 @@ await transport.SendMessageAsync(new JsonRpcNotification }, TestContext.Current.CancellationToken); await Task.Delay(50, TestContext.Current.CancellationToken); - - Assert.True(server.IsInitialized); } [Fact] @@ -162,7 +147,7 @@ public async Task RequestSamplingAsync_Should_SendRequest() await using var server = McpServerFactory.Create(transport, _options, LoggerFactory, _serviceProvider); SetClientCapabilities(server, new ClientCapabilities { Sampling = new SamplingCapability() }); - await server.StartAsync(TestContext.Current.CancellationToken); + await server.RunAsync(cancellationToken: TestContext.Current.CancellationToken); // Act var result = await server.RequestSamplingAsync(new CreateMessageRequestParams { Messages = [] }, CancellationToken.None); @@ -191,7 +176,7 @@ public async Task RequestRootsAsync_Should_SendRequest() await using var transport = new TestServerTransport(); await using var server = McpServerFactory.Create(transport, _options, LoggerFactory, _serviceProvider); SetClientCapabilities(server, new ClientCapabilities { Roots = new RootsCapability() }); - await server.StartAsync(TestContext.Current.CancellationToken); + await server.RunAsync(cancellationToken: TestContext.Current.CancellationToken); // Act var result = await server.RequestRootsAsync(new ListRootsRequestParams(), CancellationToken.None); @@ -555,7 +540,7 @@ private async Task Can_Handle_Requests(ServerCapabilities? serverCapabilities, s await using var server = McpServerFactory.Create(transport, options, LoggerFactory, _serviceProvider); - await server.StartAsync(); + await server.RunAsync(); var receivedMessage = new TaskCompletionSource(); @@ -633,13 +618,6 @@ private static void SetClientCapabilities(IMcpServer server, ClientCapabilities property.SetValue(server, capabilities); } - private static void SetInitialized(IMcpServer server, bool isInitialized) - { - PropertyInfo? property = server.GetType().GetProperty("IsInitialized", BindingFlags.Public | BindingFlags.Instance); - Assert.NotNull(property); - property.SetValue(server, isInitialized); - } - private sealed class TestServerForIChatClient(bool supportsSampling) : IMcpServer { public ClientCapabilities? ClientCapabilities => @@ -675,8 +653,6 @@ public Task SendRequestAsync(JsonRpcRequest request, CancellationToken can public ValueTask DisposeAsync() => default; - public bool IsInitialized => throw new NotImplementedException(); - public Implementation? ClientInfo => throw new NotImplementedException(); public McpServerOptions ServerOptions => throw new NotImplementedException(); public IServiceProvider? Services => throw new NotImplementedException(); @@ -684,9 +660,7 @@ public void AddNotificationHandler(string method, Func throw new NotImplementedException(); - public Task StartAsync(CancellationToken cancellationToken = default) => + public Task RunAsync(Func? onInitialized = null, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - - public object? GetService(Type serviceType, object? serviceKey = null) => null; } } diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs index 1ce4b36c..3f066dd5 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs @@ -1,5 +1,4 @@ -using Microsoft.Extensions.AI; -using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server; using Moq; diff --git a/tests/ModelContextProtocol.Tests/TestAttributes.cs b/tests/ModelContextProtocol.Tests/TestAttributes.cs new file mode 100644 index 00000000..4edbce6e --- /dev/null +++ b/tests/ModelContextProtocol.Tests/TestAttributes.cs @@ -0,0 +1,2 @@ +// Uncomment to disable parallel test execution +//[assembly: CollectionBehavior(DisableTestParallelization = true)] \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs index 761c55e7..84f848fa 100644 --- a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs @@ -90,7 +90,7 @@ public async Task ConnectAsync_Should_Connect_Successfully() if (!firstCall) { Assert.True(transport.IsConnected); - await transport.CloseAsync(); + await transport.DisposeAsync(); } firstCall = false; @@ -148,7 +148,7 @@ public async Task ConnectAsync_Throws_If_Already_Connected() var exception = await Assert.ThrowsAsync(action); Assert.Equal("Transport is already connected", exception.Message); tcsDone.SetResult(); - await transport.CloseAsync(); + await transport.DisposeAsync(); await task; } @@ -305,7 +305,7 @@ public async Task CloseAsync_Should_Dispose_Resources() { await using var transport = new SseClientTransport(_transportOptions, _serverConfig, LoggerFactory); - await transport.CloseAsync(); + await transport.DisposeAsync(); Assert.False(transport.IsConnected); } diff --git a/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs index 94ae6272..03994ee7 100644 --- a/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs @@ -1,4 +1,5 @@ -using ModelContextProtocol.Protocol.Messages; +using Microsoft.Extensions.Options; +using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server; @@ -45,9 +46,10 @@ public void Constructor_Throws_For_Null_Options() { Assert.Throws("serverName", () => new StdioServerTransport((string)null!)); + Assert.Throws("serverOptions", () => new StdioServerTransport((IOptions)null!)); Assert.Throws("serverOptions", () => new StdioServerTransport((McpServerOptions)null!)); Assert.Throws("serverOptions.ServerInfo", () => new StdioServerTransport(new McpServerOptions() { ServerInfo = null! })); - Assert.Throws("serverName", () => new StdioServerTransport(new McpServerOptions() { ServerInfo = new() { Name = null!, Version = "" } })); + Assert.Throws("serverOptions.ServerInfo.Name", () => new StdioServerTransport(new McpServerOptions() { ServerInfo = new() { Name = null!, Version = "" } })); } [Fact] diff --git a/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs b/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs index f38d933e..4fca42a3 100644 --- a/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs +++ b/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs @@ -12,6 +12,8 @@ public class TestServerTransport : IServerTransport public bool IsConnected => _isStarted; + public Task Completion => Task.CompletedTask; + public ChannelReader MessageReader => _messageChannel; public List SentMessages { get; } = [];