Skip to content

Commit c57d65d

Browse files
committed
Move notification handler registrations to capabilities
Currently request handlers are set on the capability objects, but notification handlers are set after construction via an AddNotificationHandler method on the IMcpEndpoint interface. This moves handler specification to be at construction as well. This makes it more consistent with request handlers, simplifies the IMcpEndpoint interface to just be about message sending, and avoids a concurrency bug that could occur if someone tried to add a handler while the endpoint was processing notifications.
1 parent 3b26bf8 commit c57d65d

16 files changed

+221
-160
lines changed

src/ModelContextProtocol/Client/McpClient.cs

+49-25
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,18 @@
55
using ModelContextProtocol.Protocol.Types;
66
using ModelContextProtocol.Shared;
77
using ModelContextProtocol.Utils.Json;
8+
using System.Diagnostics;
9+
using System.Reflection;
810
using System.Text.Json;
911

1012
namespace ModelContextProtocol.Client;
1113

1214
/// <inheritdoc/>
1315
internal sealed class McpClient : McpEndpoint, IMcpClient
1416
{
17+
/// <summary>Cached naming information used for client name/version when none is specified.</summary>
18+
private static readonly AssemblyName s_asmName = (Assembly.GetEntryAssembly() ?? Assembly.GetExecutingAssembly()).GetName();
19+
1520
private readonly IClientTransport _clientTransport;
1621
private readonly McpClientOptions _options;
1722

@@ -25,43 +30,61 @@ internal sealed class McpClient : McpEndpoint, IMcpClient
2530
/// <param name="options">Options for the client, defining protocol version and capabilities.</param>
2631
/// <param name="serverConfig">The server configuration.</param>
2732
/// <param name="loggerFactory">The logger factory.</param>
28-
public McpClient(IClientTransport clientTransport, McpClientOptions options, McpServerConfig serverConfig, ILoggerFactory? loggerFactory)
33+
public McpClient(IClientTransport clientTransport, McpClientOptions? options, McpServerConfig serverConfig, ILoggerFactory? loggerFactory)
2934
: base(loggerFactory)
3035
{
3136
_clientTransport = clientTransport;
37+
38+
if (options?.ClientInfo is null)
39+
{
40+
options = options?.Clone() ?? new();
41+
options.ClientInfo = new()
42+
{
43+
Name = s_asmName.Name ?? nameof(McpClient),
44+
Version = s_asmName.Version?.ToString() ?? "1.0.0",
45+
};
46+
}
3247
_options = options;
3348

3449
EndpointName = $"Client ({serverConfig.Id}: {serverConfig.Name})";
3550

36-
if (options.Capabilities?.Sampling is { } samplingCapability)
51+
if (options.Capabilities is { } capabilities)
3752
{
38-
if (samplingCapability.SamplingHandler is not { } samplingHandler)
53+
if (capabilities.NotificationHandlers is { } notificationHandlers)
3954
{
40-
throw new InvalidOperationException($"Sampling capability was set but it did not provide a handler.");
55+
AddNotificationHandlers(notificationHandlers);
4156
}
4257

43-
SetRequestHandler(
44-
RequestMethods.SamplingCreateMessage,
45-
(request, cancellationToken) => samplingHandler(
46-
request,
47-
request?.Meta?.ProgressToken is { } token ? new TokenProgress(this, token) : NullProgress.Instance,
48-
cancellationToken),
49-
McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams,
50-
McpJsonUtilities.JsonContext.Default.CreateMessageResult);
51-
}
52-
53-
if (options.Capabilities?.Roots is { } rootsCapability)
54-
{
55-
if (rootsCapability.RootsHandler is not { } rootsHandler)
58+
if (capabilities.Sampling is { } samplingCapability)
5659
{
57-
throw new InvalidOperationException($"Roots capability was set but it did not provide a handler.");
60+
if (samplingCapability.SamplingHandler is not { } samplingHandler)
61+
{
62+
throw new InvalidOperationException($"Sampling capability was set but it did not provide a handler.");
63+
}
64+
65+
SetRequestHandler(
66+
RequestMethods.SamplingCreateMessage,
67+
(request, cancellationToken) => samplingHandler(
68+
request,
69+
request?.Meta?.ProgressToken is { } token ? new TokenProgress(this, token) : NullProgress.Instance,
70+
cancellationToken),
71+
McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams,
72+
McpJsonUtilities.JsonContext.Default.CreateMessageResult);
5873
}
5974

60-
SetRequestHandler(
61-
RequestMethods.RootsList,
62-
rootsHandler,
63-
McpJsonUtilities.JsonContext.Default.ListRootsRequestParams,
64-
McpJsonUtilities.JsonContext.Default.ListRootsResult);
75+
if (capabilities.Roots is { } rootsCapability)
76+
{
77+
if (rootsCapability.RootsHandler is not { } rootsHandler)
78+
{
79+
throw new InvalidOperationException($"Roots capability was set but it did not provide a handler.");
80+
}
81+
82+
SetRequestHandler(
83+
RequestMethods.RootsList,
84+
rootsHandler,
85+
McpJsonUtilities.JsonContext.Default.ListRootsRequestParams,
86+
McpJsonUtilities.JsonContext.Default.ListRootsResult);
87+
}
6588
}
6689
}
6790

@@ -95,13 +118,14 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default)
95118
try
96119
{
97120
// Send initialize request
98-
var initializeResponse = await this.SendRequestAsync(
121+
Debug.Assert(_options.ClientInfo is not null, "ClientInfo should be set by the constructor");
122+
var initializeResponse = await this.SendRequestAsync(
99123
RequestMethods.Initialize,
100124
new InitializeRequestParams
101125
{
102126
ProtocolVersion = _options.ProtocolVersion,
103127
Capabilities = _options.Capabilities ?? new ClientCapabilities(),
104-
ClientInfo = _options.ClientInfo
128+
ClientInfo = _options.ClientInfo!
105129
},
106130
McpJsonUtilities.JsonContext.Default.InitializeRequestParams,
107131
McpJsonUtilities.JsonContext.Default.InitializeResult,

src/ModelContextProtocol/Client/McpClientFactory.cs

-18
Original file line numberDiff line numberDiff line change
@@ -12,23 +12,6 @@ namespace ModelContextProtocol.Client;
1212
/// <summary>Provides factory methods for creating MCP clients.</summary>
1313
public static class McpClientFactory
1414
{
15-
/// <summary>Default client options to use when none are supplied.</summary>
16-
private static readonly McpClientOptions s_defaultClientOptions = CreateDefaultClientOptions();
17-
18-
/// <summary>Creates default client options to use when no options are supplied.</summary>
19-
private static McpClientOptions CreateDefaultClientOptions()
20-
{
21-
var asmName = (Assembly.GetEntryAssembly() ?? Assembly.GetExecutingAssembly()).GetName();
22-
return new()
23-
{
24-
ClientInfo = new()
25-
{
26-
Name = asmName.Name ?? "McpClient",
27-
Version = asmName.Version?.ToString() ?? "1.0.0",
28-
},
29-
};
30-
}
31-
3215
/// <summary>Creates an <see cref="IMcpClient"/>, connecting it to the specified server.</summary>
3316
/// <param name="serverConfig">Configuration for the target server to which the client should connect.</param>
3417
/// <param name="clientOptions">
@@ -52,7 +35,6 @@ public static async Task<IMcpClient> CreateAsync(
5235
{
5336
Throw.IfNull(serverConfig);
5437

55-
clientOptions ??= s_defaultClientOptions;
5638
createTransportFunc ??= CreateTransport;
5739

5840
string endpointName = $"Client ({serverConfig.Id}: {serverConfig.Name})";

src/ModelContextProtocol/Client/McpClientOptions.cs

+11-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ public class McpClientOptions
1212
/// <summary>
1313
/// Information about this client implementation.
1414
/// </summary>
15-
public required Implementation ClientInfo { get; set; }
15+
public Implementation? ClientInfo { get; set; }
1616

1717
/// <summary>
1818
/// Client capabilities to advertise to the server.
@@ -28,4 +28,14 @@ public class McpClientOptions
2828
/// Timeout for initialization sequence.
2929
/// </summary>
3030
public TimeSpan InitializationTimeout { get; set; } = TimeSpan.FromSeconds(60);
31+
32+
/// <summary>Creates a shallow clone of the options.</summary>
33+
internal McpClientOptions Clone() =>
34+
new()
35+
{
36+
ClientInfo = ClientInfo,
37+
Capabilities = Capabilities,
38+
ProtocolVersion = ProtocolVersion,
39+
InitializationTimeout = InitializationTimeout
40+
};
3141
}

src/ModelContextProtocol/IMcpEndpoint.cs

-16
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,4 @@ public interface IMcpEndpoint : IAsyncDisposable
1515
/// <param name="message">The message.</param>
1616
/// <param name="cancellationToken">A token to cancel the operation.</param>
1717
Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default);
18-
19-
/// <summary>
20-
/// Adds a handler for server notifications of a specific method.
21-
/// </summary>
22-
/// <param name="method">The notification method to handle.</param>
23-
/// <param name="handler">The async handler function to process notifications.</param>
24-
/// <remarks>
25-
/// <para>
26-
/// Each method may have multiple handlers. Adding a handler for a method that already has one
27-
/// will not replace the existing handler.
28-
/// </para>
29-
/// <para>
30-
/// <see cref="NotificationMethods"> provides constants for common notification methods.</see>
31-
/// </para>
32-
/// </remarks>
33-
void AddNotificationHandler(string method, Func<JsonRpcNotification, Task> handler);
3418
}

src/ModelContextProtocol/Protocol/Types/Capabilities.cs

+10-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using ModelContextProtocol.Server;
1+
using ModelContextProtocol.Protocol.Messages;
2+
using ModelContextProtocol.Server;
23
using System.Text.Json.Serialization;
34

45
namespace ModelContextProtocol.Protocol.Types;
@@ -26,6 +27,14 @@ public class ClientCapabilities
2627
/// </summary>
2728
[JsonPropertyName("sampling")]
2829
public SamplingCapability? Sampling { get; set; }
30+
31+
/// <summary>Gets or sets notification handlers to register with the client.</summary>
32+
/// <remarks>
33+
/// When constructed, the client will enumerate these handlers, which may contain multiple handlers per key.
34+
/// The client will not re-enumerate the sequence.
35+
/// </remarks>
36+
[JsonIgnore]
37+
public IEnumerable<KeyValuePair<string, Func<JsonRpcNotification, Task>>>? NotificationHandlers { get; set; }
2938
}
3039

3140
/// <summary>

src/ModelContextProtocol/Protocol/Types/ServerCapabilities.cs

+10-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using System.Text.Json.Serialization;
1+
using ModelContextProtocol.Protocol.Messages;
2+
using System.Text.Json.Serialization;
23

34
namespace ModelContextProtocol.Protocol.Types;
45

@@ -37,4 +38,12 @@ public class ServerCapabilities
3738
/// </summary>
3839
[JsonPropertyName("tools")]
3940
public ToolsCapability? Tools { get; set; }
41+
42+
/// <summary>Gets or sets notification handlers to register with the server.</summary>
43+
/// <remarks>
44+
/// When constructed, the server will enumerate these handlers, which may contain multiple handlers per key.
45+
/// The server will not re-enumerate the sequence.
46+
/// </remarks>
47+
[JsonIgnore]
48+
public IEnumerable<KeyValuePair<string, Func<JsonRpcNotification, Task>>>? NotificationHandlers { get; set; }
4049
}

src/ModelContextProtocol/Server/McpServer.cs

+5
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,11 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory?
6666
return Task.CompletedTask;
6767
});
6868

69+
if (options.Capabilities?.NotificationHandlers is { } notificationHandlers)
70+
{
71+
AddNotificationHandlers(notificationHandlers);
72+
}
73+
6974
SetToolsHandler(options);
7075
SetInitializeHandler(options);
7176
SetCompletionHandler(options);

src/ModelContextProtocol/Shared/McpEndpoint.cs

+24-4
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ namespace ModelContextProtocol.Shared;
2020
internal abstract class McpEndpoint : IAsyncDisposable
2121
{
2222
private readonly RequestHandlers _requestHandlers = [];
23-
private readonly NotificationHandlers _notificationHandlers = [];
23+
private readonly Dictionary<string, List<Func<JsonRpcNotification, Task>>> _notificationHandlers = new();
2424

2525
private McpSession? _session;
2626
private CancellationTokenSource? _sessionCts;
@@ -39,16 +39,36 @@ protected McpEndpoint(ILoggerFactory? loggerFactory = null)
3939
_logger = loggerFactory?.CreateLogger(GetType()) ?? NullLogger.Instance;
4040
}
4141

42+
/// <summary>Sets a request handler as part of configuring the endpoint.</summary>
43+
/// <remarks>This method is not thread-safe and should only be used serially as part of configuring the instance.</remarks>
4244
protected void SetRequestHandler<TRequest, TResponse>(
4345
string method,
4446
Func<TRequest?, CancellationToken, Task<TResponse>> handler,
4547
JsonTypeInfo<TRequest> requestTypeInfo,
4648
JsonTypeInfo<TResponse> responseTypeInfo)
47-
4849
=> _requestHandlers.Set(method, handler, requestTypeInfo, responseTypeInfo);
4950

50-
public void AddNotificationHandler(string method, Func<JsonRpcNotification, Task> handler)
51-
=> _notificationHandlers.Add(method, handler);
51+
/// <summary>Adds a notification handler as part of configuring the endpoint.</summary>
52+
/// <remarks>This method is not thread-safe and should only be used serially as part of configuring the instance.</remarks>
53+
protected void AddNotificationHandler(string method, Func<JsonRpcNotification, Task> handler)
54+
{
55+
if (!_notificationHandlers.TryGetValue(method, out var handlers))
56+
{
57+
_notificationHandlers[method] = handlers = [];
58+
}
59+
60+
handlers.Add(handler);
61+
}
62+
63+
/// <summary>Adds notification handlers as part of configuring the endpoint.</summary>
64+
/// <remarks>This method is not thread-safe and should only be used serially as part of configuring the instance.</remarks>
65+
protected void AddNotificationHandlers(IEnumerable<KeyValuePair<string, Func<JsonRpcNotification, Task>>> handlers)
66+
{
67+
foreach (var handler in handlers)
68+
{
69+
AddNotificationHandler(handler.Key, handler.Value);
70+
}
71+
}
5272

5373
public Task<JsonRpcResponse> SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default)
5474
=> GetSessionOrThrow().SendRequestAsync(request, cancellationToken);

src/ModelContextProtocol/Shared/McpSession.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ internal sealed class McpSession : IDisposable
3434
private readonly string _transportKind;
3535
private readonly ITransport _transport;
3636
private readonly RequestHandlers _requestHandlers;
37-
private readonly NotificationHandlers _notificationHandlers;
37+
private readonly Dictionary<string, List<Func<JsonRpcNotification, Task>>> _notificationHandlers;
3838
private readonly long _sessionStartingTimestamp = Stopwatch.GetTimestamp();
3939

4040
/// <summary>Collection of requests sent on this session and waiting for responses.</summary>
@@ -63,7 +63,7 @@ public McpSession(
6363
ITransport transport,
6464
string endpointName,
6565
RequestHandlers requestHandlers,
66-
NotificationHandlers notificationHandlers,
66+
Dictionary<string, List<Func<JsonRpcNotification, Task>>> notificationHandlers,
6767
ILogger logger)
6868
{
6969
Throw.IfNull(transport);

src/ModelContextProtocol/Shared/NotificationHandlers.cs

-16
This file was deleted.

tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs

+1-7
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,13 @@ public class ClientIntegrationTestFixture
88
{
99
private ILoggerFactory? _loggerFactory;
1010

11-
public McpClientOptions DefaultOptions { get; }
1211
public McpServerConfig EverythingServerConfig { get; }
1312
public McpServerConfig TestServerConfig { get; }
1413

1514
public static IEnumerable<string> ClientIds => ["everything", "test_server"];
1615

1716
public ClientIntegrationTestFixture()
1817
{
19-
DefaultOptions = new()
20-
{
21-
ClientInfo = new() { Name = "IntegrationTestClient", Version = "1.0.0" },
22-
};
23-
2418
EverythingServerConfig = new()
2519
{
2620
Id = "everything",
@@ -63,5 +57,5 @@ public Task<IMcpClient> CreateClientAsync(string clientId, McpClientOptions? cli
6357
"everything" => EverythingServerConfig,
6458
"test_server" => TestServerConfig,
6559
_ => throw new ArgumentException($"Unknown client ID: {clientId}")
66-
}, clientOptions ?? DefaultOptions, loggerFactory: _loggerFactory);
60+
}, clientOptions, loggerFactory: _loggerFactory);
6761
}

0 commit comments

Comments
 (0)