Skip to content

Commit 2ab9db0

Browse files
Refactor transports to help enable graceful shutdown (#142)
- This also paves the way for better multi-session support - We should definitely rethink names for the transport API - For now, I kept the names similar as possible, so we can focus on the API shape Co-authored-by: Stephen Toub <[email protected]>
1 parent d47c834 commit 2ab9db0

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+2015
-1847
lines changed

README.md

+1-5
Original file line numberDiff line numberDiff line change
@@ -198,11 +198,7 @@ McpServerOptions options = new()
198198
};
199199

200200
await using IMcpServer server = McpServerFactory.Create(new StdioServerTransport("MyServer"), options);
201-
202-
await server.StartAsync();
203-
204-
// Run until process is stopped by the client (parent process)
205-
await Task.Delay(Timeout.Infinite);
201+
await server.RunAsync();
206202
```
207203

208204
## Acknowledgements

samples/AspNetCoreSseServer/McpEndpointRouteBuilderExtensions.cs

+6-7
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ public static class McpEndpointRouteBuilderExtensions
1010
{
1111
public static IEndpointConventionBuilder MapMcpSse(this IEndpointRouteBuilder endpoints)
1212
{
13-
IMcpServer? server = null;
1413
SseResponseStreamTransport? transport = null;
1514
var loggerFactory = endpoints.ServiceProvider.GetRequiredService<ILoggerFactory>();
1615
var mcpServerOptions = endpoints.ServiceProvider.GetRequiredService<IOptions<McpServerOptions>>();
@@ -19,17 +18,17 @@ public static IEndpointConventionBuilder MapMcpSse(this IEndpointRouteBuilder en
1918

2019
routeGroup.MapGet("/sse", async (HttpResponse response, CancellationToken requestAborted) =>
2120
{
22-
await using var localTransport = transport = new SseResponseStreamTransport(response.Body);
23-
await using var localServer = server = McpServerFactory.Create(transport, mcpServerOptions.Value, loggerFactory, endpoints.ServiceProvider);
24-
25-
await localServer.StartAsync(requestAborted);
26-
2721
response.Headers.ContentType = "text/event-stream";
2822
response.Headers.CacheControl = "no-cache";
2923

24+
await using var localTransport = transport = new SseResponseStreamTransport(response.Body);
25+
await using var server = McpServerFactory.Create(transport, mcpServerOptions.Value, loggerFactory, endpoints.ServiceProvider);
26+
3027
try
3128
{
32-
await transport.RunAsync(requestAborted);
29+
var transportTask = transport.RunAsync(cancellationToken: requestAborted);
30+
await server.RunAsync(cancellationToken: requestAborted);
31+
await transportTask;
3332
}
3433
catch (OperationCanceledException) when (requestAborted.IsCancellationRequested)
3534
{
+74-73
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,36 @@
1-
using ModelContextProtocol.Configuration;
1+
using Microsoft.Extensions.Logging;
2+
using ModelContextProtocol.Configuration;
23
using ModelContextProtocol.Logging;
34
using ModelContextProtocol.Protocol.Messages;
45
using ModelContextProtocol.Protocol.Transport;
56
using ModelContextProtocol.Protocol.Types;
67
using ModelContextProtocol.Shared;
78
using ModelContextProtocol.Utils.Json;
8-
using Microsoft.Extensions.Logging;
99
using System.Text.Json;
1010

1111
namespace ModelContextProtocol.Client;
1212

1313
/// <inheritdoc/>
1414
internal sealed class McpClient : McpJsonRpcEndpoint, IMcpClient
1515
{
16-
private readonly McpClientOptions _options;
1716
private readonly IClientTransport _clientTransport;
17+
private readonly McpClientOptions _options;
1818

19-
private volatile bool _isInitializing;
19+
private ITransport? _sessionTransport;
20+
private CancellationTokenSource? _connectCts;
2021

2122
/// <summary>
2223
/// Initializes a new instance of the <see cref="McpClient"/> class.
2324
/// </summary>
24-
/// <param name="transport">The transport to use for communication with the server.</param>
25+
/// <param name="clientTransport">The transport to use for communication with the server.</param>
2526
/// <param name="options">Options for the client, defining protocol version and capabilities.</param>
2627
/// <param name="serverConfig">The server configuration.</param>
2728
/// <param name="loggerFactory">The logger factory.</param>
28-
public McpClient(IClientTransport transport, McpClientOptions options, McpServerConfig serverConfig, ILoggerFactory? loggerFactory)
29-
: base(transport, loggerFactory)
29+
public McpClient(IClientTransport clientTransport, McpClientOptions options, McpServerConfig serverConfig, ILoggerFactory? loggerFactory)
30+
: base(loggerFactory)
3031
{
32+
_clientTransport = clientTransport;
3133
_options = options;
32-
_clientTransport = transport;
3334

3435
EndpointName = $"Client ({serverConfig.Id}: {serverConfig.Name})";
3536

@@ -70,95 +71,95 @@ public McpClient(IClientTransport transport, McpClientOptions options, McpServer
7071
/// <inheritdoc/>
7172
public override string EndpointName { get; }
7273

73-
/// <inheritdoc/>
7474
public async Task ConnectAsync(CancellationToken cancellationToken = default)
7575
{
76-
if (IsInitialized)
77-
{
78-
_logger.ClientAlreadyInitialized(EndpointName);
79-
return;
80-
}
81-
82-
if (_isInitializing)
83-
{
84-
_logger.ClientAlreadyInitializing(EndpointName);
85-
throw new InvalidOperationException("Client is already initializing");
86-
}
76+
_connectCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
77+
cancellationToken = _connectCts.Token;
8778

88-
_isInitializing = true;
8979
try
9080
{
91-
CancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
92-
9381
// Connect transport
94-
await _clientTransport.ConnectAsync(CancellationTokenSource.Token).ConfigureAwait(false);
95-
96-
// Start processing messages
97-
MessageProcessingTask = ProcessMessagesAsync(CancellationTokenSource.Token);
82+
_sessionTransport = await _clientTransport.ConnectAsync(cancellationToken).ConfigureAwait(false);
83+
InitializeSession(_sessionTransport);
84+
// We don't want the ConnectAsync token to cancel the session after we've successfully connected.
85+
// The base class handles cleaning up the session in DisposeAsync without our help.
86+
StartSession(fullSessionCancellationToken: CancellationToken.None);
9887

9988
// Perform initialization sequence
100-
await InitializeAsync(CancellationTokenSource.Token).ConfigureAwait(false);
89+
using var initializationCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
90+
initializationCts.CancelAfter(_options.InitializationTimeout);
10191

102-
IsInitialized = true;
92+
try
93+
{
94+
// Send initialize request
95+
var initializeResponse = await SendRequestAsync<InitializeResult>(
96+
new JsonRpcRequest
97+
{
98+
Method = "initialize",
99+
Params = new InitializeRequestParams()
100+
{
101+
ProtocolVersion = _options.ProtocolVersion,
102+
Capabilities = _options.Capabilities ?? new ClientCapabilities(),
103+
ClientInfo = _options.ClientInfo,
104+
}
105+
},
106+
initializationCts.Token).ConfigureAwait(false);
107+
108+
// Store server information
109+
_logger.ServerCapabilitiesReceived(EndpointName,
110+
capabilities: JsonSerializer.Serialize(initializeResponse.Capabilities, McpJsonUtilities.JsonContext.Default.ServerCapabilities),
111+
serverInfo: JsonSerializer.Serialize(initializeResponse.ServerInfo, McpJsonUtilities.JsonContext.Default.Implementation));
112+
113+
ServerCapabilities = initializeResponse.Capabilities;
114+
ServerInfo = initializeResponse.ServerInfo;
115+
ServerInstructions = initializeResponse.Instructions;
116+
117+
// Validate protocol version
118+
if (initializeResponse.ProtocolVersion != _options.ProtocolVersion)
119+
{
120+
_logger.ServerProtocolVersionMismatch(EndpointName, _options.ProtocolVersion, initializeResponse.ProtocolVersion);
121+
throw new McpClientException($"Server protocol version mismatch. Expected {_options.ProtocolVersion}, got {initializeResponse.ProtocolVersion}");
122+
}
123+
124+
// Send initialized notification
125+
await SendMessageAsync(
126+
new JsonRpcNotification { Method = "notifications/initialized" },
127+
initializationCts.Token).ConfigureAwait(false);
128+
}
129+
catch (OperationCanceledException) when (initializationCts.IsCancellationRequested)
130+
{
131+
_logger.ClientInitializationTimeout(EndpointName);
132+
throw new McpClientException("Initialization timed out");
133+
}
103134
}
104135
catch (Exception e)
105136
{
106137
_logger.ClientInitializationError(EndpointName, e);
107-
await CleanupAsync().ConfigureAwait(false);
138+
await DisposeAsync().ConfigureAwait(false);
108139
throw;
109140
}
110-
finally
111-
{
112-
_isInitializing = false;
113-
}
114141
}
115142

116-
private async Task InitializeAsync(CancellationToken cancellationToken)
143+
/// <inheritdoc/>
144+
public override async ValueTask DisposeUnsynchronizedAsync()
117145
{
118-
using var initializationCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
119-
initializationCts.CancelAfter(_options.InitializationTimeout);
146+
if (_connectCts is not null)
147+
{
148+
await _connectCts.CancelAsync().ConfigureAwait(false);
149+
}
120150

121151
try
122152
{
123-
// Send initialize request
124-
var initializeResponse = await SendRequestAsync<InitializeResult>(
125-
new JsonRpcRequest
126-
{
127-
Method = "initialize",
128-
Params = new InitializeRequestParams()
129-
{
130-
ProtocolVersion = _options.ProtocolVersion,
131-
Capabilities = _options.Capabilities ?? new ClientCapabilities(),
132-
ClientInfo = _options.ClientInfo
133-
}
134-
},
135-
initializationCts.Token).ConfigureAwait(false);
136-
137-
// Store server information
138-
_logger.ServerCapabilitiesReceived(EndpointName,
139-
capabilities: JsonSerializer.Serialize(initializeResponse.Capabilities, McpJsonUtilities.JsonContext.Default.ServerCapabilities),
140-
serverInfo: JsonSerializer.Serialize(initializeResponse.ServerInfo, McpJsonUtilities.JsonContext.Default.Implementation));
141-
142-
ServerCapabilities = initializeResponse.Capabilities;
143-
ServerInfo = initializeResponse.ServerInfo;
144-
ServerInstructions = initializeResponse.Instructions;
145-
146-
// Validate protocol version
147-
if (initializeResponse.ProtocolVersion != _options.ProtocolVersion)
153+
await base.DisposeUnsynchronizedAsync().ConfigureAwait(false);
154+
}
155+
finally
156+
{
157+
if (_sessionTransport is not null)
148158
{
149-
_logger.ServerProtocolVersionMismatch(EndpointName, _options.ProtocolVersion, initializeResponse.ProtocolVersion);
150-
throw new McpClientException($"Server protocol version mismatch. Expected {_options.ProtocolVersion}, got {initializeResponse.ProtocolVersion}");
159+
await _sessionTransport.DisposeAsync().ConfigureAwait(false);
151160
}
152161

153-
// Send initialized notification
154-
await SendMessageAsync(
155-
new JsonRpcNotification { Method = "notifications/initialized" },
156-
initializationCts.Token).ConfigureAwait(false);
157-
}
158-
catch (OperationCanceledException) when (initializationCts.IsCancellationRequested)
159-
{
160-
_logger.ClientInitializationTimeout(EndpointName);
161-
throw new McpClientException("Initialization timed out");
162+
_connectCts?.Dispose();
162163
}
163164
}
164165
}

src/ModelContextProtocol/Client/McpClientFactory.cs

+8-1
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,14 @@ public static async Task<IMcpClient> CreateAsync(
8282
}
8383
catch
8484
{
85-
await transport.DisposeAsync().ConfigureAwait(false);
85+
if (transport is IAsyncDisposable asyncDisposableTransport)
86+
{
87+
await asyncDisposableTransport.DisposeAsync().ConfigureAwait(false);
88+
}
89+
else if (transport is IDisposable disposableTransport)
90+
{
91+
disposableTransport.Dispose();
92+
}
8693
throw;
8794
}
8895
}

src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Transports.cs

+16-3
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
using ModelContextProtocol.Protocol.Transport;
44
using ModelContextProtocol.Utils;
55
using Microsoft.Extensions.DependencyInjection;
6+
using Microsoft.Extensions.Logging;
7+
using Microsoft.Extensions.Options;
8+
using ModelContextProtocol.Server;
69

710
namespace ModelContextProtocol;
811

@@ -19,8 +22,18 @@ public static IMcpServerBuilder WithStdioServerTransport(this IMcpServerBuilder
1922
{
2023
Throw.IfNull(builder);
2124

22-
builder.Services.AddSingleton<IServerTransport, StdioServerTransport>();
23-
builder.Services.AddHostedService<McpServerHostedService>();
25+
builder.Services.AddSingleton<ITransport, StdioServerTransport>();
26+
builder.Services.AddHostedService<McpServerSingleSessionHostedService>();
27+
28+
builder.Services.AddSingleton(services =>
29+
{
30+
ITransport serverTransport = services.GetRequiredService<ITransport>();
31+
IOptions<McpServerOptions> options = services.GetRequiredService<IOptions<McpServerOptions>>();
32+
ILoggerFactory? loggerFactory = services.GetService<ILoggerFactory>();
33+
34+
return McpServerFactory.Create(serverTransport, options.Value, loggerFactory, services);
35+
});
36+
2437
return builder;
2538
}
2639

@@ -33,7 +46,7 @@ public static IMcpServerBuilder WithHttpListenerSseServerTransport(this IMcpServ
3346
Throw.IfNull(builder);
3447

3548
builder.Services.AddSingleton<IServerTransport, HttpListenerSseServerTransport>();
36-
builder.Services.AddHostedService<McpServerHostedService>();
49+
builder.Services.AddHostedService<McpServerMultiSessionHostedService>();
3750
return builder;
3851
}
3952
}

src/ModelContextProtocol/Configuration/McpServerServiceCollectionExtension.cs

+1-10
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
namespace ModelContextProtocol;
99

1010
/// <summary>
11-
/// Extension to host the MCP server
11+
/// Extension to host an MCP server
1212
/// </summary>
1313
public static class McpServerServiceCollectionExtension
1414
{
@@ -20,15 +20,6 @@ public static class McpServerServiceCollectionExtension
2020
/// <returns></returns>
2121
public static IMcpServerBuilder AddMcpServer(this IServiceCollection services, Action<McpServerOptions>? configureOptions = null)
2222
{
23-
services.AddSingleton(services =>
24-
{
25-
IServerTransport serverTransport = services.GetRequiredService<IServerTransport>();
26-
IOptions<McpServerOptions> options = services.GetRequiredService<IOptions<McpServerOptions>>();
27-
ILoggerFactory? loggerFactory = services.GetService<ILoggerFactory>();
28-
29-
return McpServerFactory.Create(serverTransport, options.Value, loggerFactory, services);
30-
});
31-
3223
services.AddOptions();
3324
services.AddTransient<IConfigureOptions<McpServerOptions>, McpServerOptionsSetup>();
3425
if (configureOptions is not null)

src/ModelContextProtocol/Hosting/McpServerHostedService.cs

-31
This file was deleted.

0 commit comments

Comments
 (0)