Skip to content

Refactor transports to help enable graceful shutdown #142

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Mar 31, 2025
6 changes: 1 addition & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -198,11 +198,7 @@ McpServerOptions options = new()
};

await using IMcpServer server = McpServerFactory.Create(new StdioServerTransport("MyServer"), options);

await server.StartAsync();

// Run until process is stopped by the client (parent process)
await Task.Delay(Timeout.Infinite);
await server.RunAsync();
```

## Acknowledgements
Expand Down
13 changes: 6 additions & 7 deletions samples/AspNetCoreSseServer/McpEndpointRouteBuilderExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ public static class McpEndpointRouteBuilderExtensions
{
public static IEndpointConventionBuilder MapMcpSse(this IEndpointRouteBuilder endpoints)
{
IMcpServer? server = null;
SseResponseStreamTransport? transport = null;
var loggerFactory = endpoints.ServiceProvider.GetRequiredService<ILoggerFactory>();
var mcpServerOptions = endpoints.ServiceProvider.GetRequiredService<IOptions<McpServerOptions>>();
Expand All @@ -19,17 +18,17 @@ public static IEndpointConventionBuilder MapMcpSse(this IEndpointRouteBuilder en

routeGroup.MapGet("/sse", async (HttpResponse response, CancellationToken requestAborted) =>
{
await using var localTransport = transport = new SseResponseStreamTransport(response.Body);
await using var localServer = server = McpServerFactory.Create(transport, mcpServerOptions.Value, loggerFactory, endpoints.ServiceProvider);

await localServer.StartAsync(requestAborted);

response.Headers.ContentType = "text/event-stream";
response.Headers.CacheControl = "no-cache";

await using var localTransport = transport = new SseResponseStreamTransport(response.Body);
await using var server = McpServerFactory.Create(transport, mcpServerOptions.Value, loggerFactory, endpoints.ServiceProvider);

try
{
await transport.RunAsync(requestAborted);
var transportTask = transport.RunAsync(cancellationToken: requestAborted);
await server.RunAsync(cancellationToken: requestAborted);
await transportTask;
}
catch (OperationCanceledException) when (requestAborted.IsCancellationRequested)
{
Expand Down
150 changes: 74 additions & 76 deletions src/ModelContextProtocol/Client/McpClient.cs
Original file line number Diff line number Diff line change
@@ -1,35 +1,37 @@
using ModelContextProtocol.Configuration;
using Microsoft.Extensions.Logging;
using ModelContextProtocol.Configuration;
using ModelContextProtocol.Logging;
using ModelContextProtocol.Protocol.Messages;
using ModelContextProtocol.Protocol.Transport;
using ModelContextProtocol.Protocol.Types;
using ModelContextProtocol.Shared;
using ModelContextProtocol.Utils.Json;
using Microsoft.Extensions.Logging;
using System.Text.Json;

namespace ModelContextProtocol.Client;

/// <inheritdoc/>
internal sealed class McpClient : McpJsonRpcEndpoint, IMcpClient
{
private readonly McpClientOptions _options;
private readonly IClientTransport _clientTransport;
private readonly McpClientOptions _options;

private volatile bool _isInitializing;
private ITransport? _sessionTransport;
private CancellationTokenSource? _connectCts;
private int _disposed;

/// <summary>
/// Initializes a new instance of the <see cref="McpClient"/> class.
/// </summary>
/// <param name="transport">The transport to use for communication with the server.</param>
/// <param name="clientTransport">The transport to use for communication with the server.</param>
/// <param name="options">Options for the client, defining protocol version and capabilities.</param>
/// <param name="serverConfig">The server configuration.</param>
/// <param name="loggerFactory">The logger factory.</param>
public McpClient(IClientTransport transport, McpClientOptions options, McpServerConfig serverConfig, ILoggerFactory? loggerFactory)
: base(transport, loggerFactory)
public McpClient(IClientTransport clientTransport, McpClientOptions options, McpServerConfig serverConfig, ILoggerFactory? loggerFactory)
: base(loggerFactory)
{
_clientTransport = clientTransport;
_options = options;
_clientTransport = transport;

EndpointName = $"Client ({serverConfig.Id}: {serverConfig.Name})";

Expand Down Expand Up @@ -70,95 +72,91 @@ public McpClient(IClientTransport transport, McpClientOptions options, McpServer
/// <inheritdoc/>
public override string EndpointName { get; }

/// <inheritdoc/>
public async Task ConnectAsync(CancellationToken cancellationToken = default)
{
if (IsInitialized)
{
_logger.ClientAlreadyInitialized(EndpointName);
return;
}

if (_isInitializing)
{
_logger.ClientAlreadyInitializing(EndpointName);
throw new InvalidOperationException("Client is already initializing");
}
_connectCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
cancellationToken = _connectCts.Token;

_isInitializing = true;
try
{
CancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);

// Connect transport
await _clientTransport.ConnectAsync(CancellationTokenSource.Token).ConfigureAwait(false);

// Start processing messages
MessageProcessingTask = ProcessMessagesAsync(CancellationTokenSource.Token);
_sessionTransport = await _clientTransport.ConnectAsync(cancellationToken).ConfigureAwait(false);
InitializeSession(_sessionTransport);
// We don't want the ConnectAsync token to cancel the session after we've successfully connected.
// The base class handles cleaning up the session in DisposeAsync without our help.
StartSession(fullSessionCancellationToken: CancellationToken.None);

// Perform initialization sequence
await InitializeAsync(CancellationTokenSource.Token).ConfigureAwait(false);
using var initializationCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
initializationCts.CancelAfter(_options.InitializationTimeout);

IsInitialized = true;
try
{
// Send initialize request
var initializeResponse = await SendRequestAsync<InitializeResult>(
new JsonRpcRequest
{
Method = "initialize",
Params = new InitializeRequestParams()
{
ProtocolVersion = _options.ProtocolVersion,
Capabilities = _options.Capabilities ?? new ClientCapabilities(),
ClientInfo = _options.ClientInfo,
}
},
initializationCts.Token).ConfigureAwait(false);

// Store server information
_logger.ServerCapabilitiesReceived(EndpointName,
capabilities: JsonSerializer.Serialize(initializeResponse.Capabilities, McpJsonUtilities.JsonContext.Default.ServerCapabilities),
serverInfo: JsonSerializer.Serialize(initializeResponse.ServerInfo, McpJsonUtilities.JsonContext.Default.Implementation));

ServerCapabilities = initializeResponse.Capabilities;
ServerInfo = initializeResponse.ServerInfo;
ServerInstructions = initializeResponse.Instructions;

// Validate protocol version
if (initializeResponse.ProtocolVersion != _options.ProtocolVersion)
{
_logger.ServerProtocolVersionMismatch(EndpointName, _options.ProtocolVersion, initializeResponse.ProtocolVersion);
throw new McpClientException($"Server protocol version mismatch. Expected {_options.ProtocolVersion}, got {initializeResponse.ProtocolVersion}");
}

// Send initialized notification
await SendMessageAsync(
new JsonRpcNotification { Method = "notifications/initialized" },
initializationCts.Token).ConfigureAwait(false);
}
catch (OperationCanceledException) when (initializationCts.IsCancellationRequested)
{
_logger.ClientInitializationTimeout(EndpointName);
throw new McpClientException("Initialization timed out");
}
}
catch (Exception e)
{
_logger.ClientInitializationError(EndpointName, e);
await CleanupAsync().ConfigureAwait(false);
await DisposeAsync().ConfigureAwait(false);
throw;
}
finally
{
_isInitializing = false;
}
}

private async Task InitializeAsync(CancellationToken cancellationToken)
/// <inheritdoc/>
public override async ValueTask DisposeAsync()
{
using var initializationCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
initializationCts.CancelAfter(_options.InitializationTimeout);

try
if (Interlocked.Exchange(ref _disposed, 1) != 0)
{
// Send initialize request
var initializeResponse = await SendRequestAsync<InitializeResult>(
new JsonRpcRequest
{
Method = "initialize",
Params = new InitializeRequestParams()
{
ProtocolVersion = _options.ProtocolVersion,
Capabilities = _options.Capabilities ?? new ClientCapabilities(),
ClientInfo = _options.ClientInfo
}
},
initializationCts.Token).ConfigureAwait(false);

// Store server information
_logger.ServerCapabilitiesReceived(EndpointName,
capabilities: JsonSerializer.Serialize(initializeResponse.Capabilities, McpJsonUtilities.JsonContext.Default.ServerCapabilities),
serverInfo: JsonSerializer.Serialize(initializeResponse.ServerInfo, McpJsonUtilities.JsonContext.Default.Implementation));

ServerCapabilities = initializeResponse.Capabilities;
ServerInfo = initializeResponse.ServerInfo;
ServerInstructions = initializeResponse.Instructions;

// Validate protocol version
if (initializeResponse.ProtocolVersion != _options.ProtocolVersion)
{
_logger.ServerProtocolVersionMismatch(EndpointName, _options.ProtocolVersion, initializeResponse.ProtocolVersion);
throw new McpClientException($"Server protocol version mismatch. Expected {_options.ProtocolVersion}, got {initializeResponse.ProtocolVersion}");
}

// Send initialized notification
await SendMessageAsync(
new JsonRpcNotification { Method = "notifications/initialized" },
initializationCts.Token).ConfigureAwait(false);
// TODO: It's more correct to await the last DisposeAsync before returning if it's still ongoing.
return;
}
catch (OperationCanceledException) when (initializationCts.IsCancellationRequested)

if (_connectCts is not null)
{
_logger.ClientInitializationTimeout(EndpointName);
throw new McpClientException("Initialization timed out");
await _connectCts.CancelAsync().ConfigureAwait(false);
}

await base.DisposeAsync().ConfigureAwait(false);

_connectCts?.Dispose();
}
}
18 changes: 5 additions & 13 deletions src/ModelContextProtocol/Client/McpClientFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,24 +65,16 @@ public static async Task<IMcpClient> CreateAsync(
createTransportFunc(serverConfig, loggerFactory) ??
throw new InvalidOperationException($"{nameof(createTransportFunc)} returned a null transport.");

McpClient client = new(transport, clientOptions, serverConfig, loggerFactory);
try
{
McpClient client = new(transport, clientOptions, serverConfig, loggerFactory);
try
{
await client.ConnectAsync(cancellationToken).ConfigureAwait(false);
logger.ClientCreated(endpointName);
return client;
}
catch
{
await client.DisposeAsync().ConfigureAwait(false);
throw;
}
await client.ConnectAsync(cancellationToken).ConfigureAwait(false);
logger.ClientCreated(endpointName);
return client;
}
catch
{
await transport.DisposeAsync().ConfigureAwait(false);
await client.DisposeAsync().ConfigureAwait(false);
throw;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
using ModelContextProtocol.Protocol.Transport;
using ModelContextProtocol.Utils;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using ModelContextProtocol.Server;

namespace ModelContextProtocol;

Expand All @@ -19,8 +22,18 @@ public static IMcpServerBuilder WithStdioServerTransport(this IMcpServerBuilder
{
Throw.IfNull(builder);

builder.Services.AddSingleton<IServerTransport, StdioServerTransport>();
builder.Services.AddHostedService<McpServerHostedService>();
builder.Services.AddSingleton<ITransport, StdioServerTransport>();
builder.Services.AddHostedService<McpServerSingleSessionHostedService>();

builder.Services.AddSingleton(services =>
{
ITransport serverTransport = services.GetRequiredService<ITransport>();
IOptions<McpServerOptions> options = services.GetRequiredService<IOptions<McpServerOptions>>();
ILoggerFactory? loggerFactory = services.GetService<ILoggerFactory>();

return McpServerFactory.Create(serverTransport, options.Value, loggerFactory, services);
});

return builder;
}

Expand All @@ -33,7 +46,7 @@ public static IMcpServerBuilder WithHttpListenerSseServerTransport(this IMcpServ
Throw.IfNull(builder);

builder.Services.AddSingleton<IServerTransport, HttpListenerSseServerTransport>();
builder.Services.AddHostedService<McpServerHostedService>();
builder.Services.AddHostedService<McpServerMultiSessionHostedService>();
return builder;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
namespace ModelContextProtocol;

/// <summary>
/// Extension to host the MCP server
/// Extension to host an MCP server
/// </summary>
public static class McpServerServiceCollectionExtension
{
Expand All @@ -20,15 +20,6 @@ public static class McpServerServiceCollectionExtension
/// <returns></returns>
public static IMcpServerBuilder AddMcpServer(this IServiceCollection services, Action<McpServerOptions>? configureOptions = null)
{
services.AddSingleton(services =>
{
IServerTransport serverTransport = services.GetRequiredService<IServerTransport>();
IOptions<McpServerOptions> options = services.GetRequiredService<IOptions<McpServerOptions>>();
ILoggerFactory? loggerFactory = services.GetService<ILoggerFactory>();

return McpServerFactory.Create(serverTransport, options.Value, loggerFactory, services);
});

services.AddOptions();
services.AddTransient<IConfigureOptions<McpServerOptions>, McpServerOptionsSetup>();
if (configureOptions is not null)
Expand Down
31 changes: 0 additions & 31 deletions src/ModelContextProtocol/Hosting/McpServerHostedService.cs

This file was deleted.

Loading