Skip to content

Enable graceful shutdown of servers #122

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -198,11 +198,7 @@ McpServerOptions options = new()
};

await using IMcpServer server = McpServerFactory.Create(new StdioServerTransport("MyServer"), options);

await server.StartAsync();

// Run until process is stopped by the client (parent process)
await Task.Delay(Timeout.Infinite);
await server.RunAsync();
```

## Acknowledgements
Expand Down
11 changes: 4 additions & 7 deletions samples/AspNetCoreSseServer/McpEndpointRouteBuilderExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ public static class McpEndpointRouteBuilderExtensions
{
public static IEndpointConventionBuilder MapMcpSse(this IEndpointRouteBuilder endpoints)
{
IMcpServer? server = null;
SseResponseStreamTransport? transport = null;
var loggerFactory = endpoints.ServiceProvider.GetRequiredService<ILoggerFactory>();
var mcpServerOptions = endpoints.ServiceProvider.GetRequiredService<IOptions<McpServerOptions>>();
Expand All @@ -19,17 +18,15 @@ public static IEndpointConventionBuilder MapMcpSse(this IEndpointRouteBuilder en

routeGroup.MapGet("/sse", async (HttpResponse response, CancellationToken requestAborted) =>
{
await using var localTransport = transport = new SseResponseStreamTransport(response.Body);
await using var localServer = server = McpServerFactory.Create(transport, mcpServerOptions.Value, loggerFactory, endpoints.ServiceProvider);

await localServer.StartAsync(requestAborted);

response.Headers.ContentType = "text/event-stream";
response.Headers.CacheControl = "no-cache";

await using var localTransport = transport = new SseResponseStreamTransport(response.Body);
await using var server = McpServerFactory.Create(transport, mcpServerOptions.Value, loggerFactory, endpoints.ServiceProvider);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we need to call IMcpServer.RunAsync() somewhere for this sample to continue to work? You can test it out using npx @modelcontextprotocol/inspector


try
{
await transport.RunAsync(requestAborted);
await transport.RunAsync(cancellationToken: requestAborted);
}
catch (OperationCanceledException) when (requestAborted.IsCancellationRequested)
{
Expand Down
122 changes: 52 additions & 70 deletions src/ModelContextProtocol/Client/McpClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
using ModelContextProtocol.Shared;
using ModelContextProtocol.Utils.Json;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using System.Text.Json;

namespace ModelContextProtocol.Client;
Expand All @@ -17,7 +16,7 @@ internal sealed class McpClient : McpJsonRpcEndpoint, IMcpClient
private readonly McpClientOptions _options;
private readonly IClientTransport _clientTransport;

private volatile bool _isInitializing;
private int _connecting;

/// <summary>
/// Initializes a new instance of the <see cref="McpClient"/> class.
Expand Down Expand Up @@ -74,92 +73,75 @@ public McpClient(IClientTransport transport, McpClientOptions options, McpServer
/// <inheritdoc/>
public async Task ConnectAsync(CancellationToken cancellationToken = default)
{
if (IsInitialized)
{
_logger.ClientAlreadyInitialized(EndpointName);
return;
}

if (_isInitializing)
if (Interlocked.Exchange(ref _connecting, 1) != 0)
{
_logger.ClientAlreadyInitializing(EndpointName);
throw new InvalidOperationException("Client is already initializing");
throw new InvalidOperationException("Client is already in use.");
}

_isInitializing = true;
CancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
cancellationToken = CancellationTokenSource.Token;

try
{
CancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);

// Connect transport
await _clientTransport.ConnectAsync(CancellationTokenSource.Token).ConfigureAwait(false);
await _clientTransport.ConnectAsync(cancellationToken).ConfigureAwait(false);

// Start processing messages
MessageProcessingTask = ProcessMessagesAsync(CancellationTokenSource.Token);
MessageProcessingTask = ProcessMessagesAsync(cancellationToken);

// Perform initialization sequence
await InitializeAsync(CancellationTokenSource.Token).ConfigureAwait(false);
using var initializationCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
initializationCts.CancelAfter(_options.InitializationTimeout);

IsInitialized = true;
try
{
// Send initialize request
var initializeResponse = await SendRequestAsync<InitializeResult>(
new JsonRpcRequest
{
Method = "initialize",
Params = new
{
protocolVersion = _options.ProtocolVersion,
capabilities = _options.Capabilities ?? new ClientCapabilities(),
clientInfo = _options.ClientInfo
}
},
initializationCts.Token).ConfigureAwait(false);

// Store server information
_logger.ServerCapabilitiesReceived(EndpointName,
capabilities: JsonSerializer.Serialize(initializeResponse.Capabilities, McpJsonUtilities.JsonContext.Default.ServerCapabilities),
serverInfo: JsonSerializer.Serialize(initializeResponse.ServerInfo, McpJsonUtilities.JsonContext.Default.Implementation));

ServerCapabilities = initializeResponse.Capabilities;
ServerInfo = initializeResponse.ServerInfo;
ServerInstructions = initializeResponse.Instructions;

// Validate protocol version
if (initializeResponse.ProtocolVersion != _options.ProtocolVersion)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not for this PR - but we need to figure out what "supports" means exactly in https://spec.modelcontextprotocol.io/specification/2025-03-26/basic/lifecycle/

We might risk disconnecting from servers we are compatible with.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you ensure we have an open issue for that? Thanks.

{
_logger.ServerProtocolVersionMismatch(EndpointName, _options.ProtocolVersion, initializeResponse.ProtocolVersion);
throw new McpClientException($"Server protocol version mismatch. Expected {_options.ProtocolVersion}, got {initializeResponse.ProtocolVersion}");
}

// Send initialized notification
await SendMessageAsync(
new JsonRpcNotification { Method = "notifications/initialized" },
initializationCts.Token).ConfigureAwait(false);
}
catch (OperationCanceledException) when (initializationCts.IsCancellationRequested)
{
_logger.ClientInitializationTimeout(EndpointName);
throw new McpClientException("Initialization timed out");
}
}
catch (Exception e)
{
_logger.ClientInitializationError(EndpointName, e);
await CleanupAsync().ConfigureAwait(false);
throw;
}
finally
{
_isInitializing = false;
}
}

private async Task InitializeAsync(CancellationToken cancellationToken)
{
using var initializationCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
initializationCts.CancelAfter(_options.InitializationTimeout);

try
{
// Send initialize request
var initializeResponse = await SendRequestAsync<InitializeResult>(
new JsonRpcRequest
{
Method = "initialize",
Params = new
{
protocolVersion = _options.ProtocolVersion,
capabilities = _options.Capabilities ?? new ClientCapabilities(),
clientInfo = _options.ClientInfo
}
},
initializationCts.Token).ConfigureAwait(false);

// Store server information
_logger.ServerCapabilitiesReceived(EndpointName,
capabilities: JsonSerializer.Serialize(initializeResponse.Capabilities, McpJsonUtilities.JsonContext.Default.ServerCapabilities),
serverInfo: JsonSerializer.Serialize(initializeResponse.ServerInfo, McpJsonUtilities.JsonContext.Default.Implementation));

ServerCapabilities = initializeResponse.Capabilities;
ServerInfo = initializeResponse.ServerInfo;
ServerInstructions = initializeResponse.Instructions;

// Validate protocol version
if (initializeResponse.ProtocolVersion != _options.ProtocolVersion)
{
_logger.ServerProtocolVersionMismatch(EndpointName, _options.ProtocolVersion, initializeResponse.ProtocolVersion);
throw new McpClientException($"Server protocol version mismatch. Expected {_options.ProtocolVersion}, got {initializeResponse.ProtocolVersion}");
}

// Send initialized notification
await SendMessageAsync(
new JsonRpcNotification { Method = "notifications/initialized" },
initializationCts.Token).ConfigureAwait(false);
}
catch (OperationCanceledException) when (initializationCts.IsCancellationRequested)
{
_logger.ClientInitializationTimeout(EndpointName);
throw new McpClientException("Initialization timed out");
}
}
}
2 changes: 1 addition & 1 deletion src/ModelContextProtocol/Hosting/McpServerHostedService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,6 @@ public McpServerHostedService(IMcpServer server)
/// <inheritdoc />
protected override async Task ExecuteAsync(CancellationToken stoppingToken)
{
await _server.StartAsync(stoppingToken).ConfigureAwait(false);
await _server.RunAsync(cancellationToken: stoppingToken).ConfigureAwait(false);
}
}
Loading
Loading