From 3725859f27905e3d13dacfbcc95b08f30e9b04f0 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Fri, 4 Apr 2025 12:20:28 -0400 Subject: [PATCH 1/2] Move notification handler registrations to capabilities Currently request handlers are set on the capability objects, but notification handlers are set after construction via an AddNotificationHandler method on the IMcpEndpoint interface. This moves handler specification to be at construction as well. This makes it more consistent with request handlers, simplifies the IMcpEndpoint interface to just be about message sending, and avoids a concurrency bug that could occur if someone tried to add a handler while the endpoint was processing notifications. --- src/ModelContextProtocol/Client/McpClient.cs | 98 ++++++++++++------- .../Client/McpClientFactory.cs | 18 ---- .../Client/McpClientOptions.cs | 12 ++- src/ModelContextProtocol/IMcpEndpoint.cs | 16 --- .../Transport/McpTransportException.cs | 2 +- .../Transport/StdioClientSessionTransport.cs | 18 +++- .../Transport/StdioClientTransport.cs | 17 +++- .../Protocol/Types/Capabilities.cs | 11 ++- .../Protocol/Types/ServerCapabilities.cs | 11 ++- src/ModelContextProtocol/Server/McpServer.cs | 33 ++++--- .../Shared/McpEndpoint.cs | 16 +-- .../Shared/NotificationHandlers.cs | 22 ++++- .../ClientIntegrationTestFixture.cs | 8 +- .../ClientIntegrationTests.cs | 58 +++++++---- .../McpServerBuilderExtensionsPromptsTests.cs | 24 +++-- .../McpServerBuilderExtensionsToolsTests.cs | 51 ++++++---- .../ModelContextProtocol.Tests.csproj | 7 ++ .../Server/McpServerTests.cs | 15 +-- .../SseIntegrationTests.cs | 30 +++--- 19 files changed, 279 insertions(+), 188 deletions(-) diff --git a/src/ModelContextProtocol/Client/McpClient.cs b/src/ModelContextProtocol/Client/McpClient.cs index a38edf6d..ea21055e 100644 --- a/src/ModelContextProtocol/Client/McpClient.cs +++ b/src/ModelContextProtocol/Client/McpClient.cs @@ -5,6 +5,8 @@ using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Shared; using ModelContextProtocol.Utils.Json; +using System.Diagnostics; +using System.Reflection; using System.Text.Json; namespace ModelContextProtocol.Client; @@ -12,6 +14,9 @@ namespace ModelContextProtocol.Client; /// internal sealed class McpClient : McpEndpoint, IMcpClient { + /// Cached naming information used for client name/version when none is specified. + private static readonly AssemblyName s_asmName = (Assembly.GetEntryAssembly() ?? Assembly.GetExecutingAssembly()).GetName(); + private readonly IClientTransport _clientTransport; private readonly McpClientOptions _options; @@ -29,43 +34,61 @@ internal sealed class McpClient : McpEndpoint, IMcpClient /// Options for the client, defining protocol version and capabilities. /// The server configuration. /// The logger factory. - public McpClient(IClientTransport clientTransport, McpClientOptions options, McpServerConfig serverConfig, ILoggerFactory? loggerFactory) + public McpClient(IClientTransport clientTransport, McpClientOptions? options, McpServerConfig serverConfig, ILoggerFactory? loggerFactory) : base(loggerFactory) { _clientTransport = clientTransport; + + if (options?.ClientInfo is null) + { + options = options?.Clone() ?? new(); + options.ClientInfo = new() + { + Name = s_asmName.Name ?? nameof(McpClient), + Version = s_asmName.Version?.ToString() ?? "1.0.0", + }; + } _options = options; EndpointName = $"Client ({serverConfig.Id}: {serverConfig.Name})"; - if (options.Capabilities?.Sampling is { } samplingCapability) + if (options.Capabilities is { } capabilities) { - if (samplingCapability.SamplingHandler is not { } samplingHandler) + if (capabilities.NotificationHandlers is { } notificationHandlers) { - throw new InvalidOperationException($"Sampling capability was set but it did not provide a handler."); + NotificationHandlers.AddRange(notificationHandlers); } - SetRequestHandler( - RequestMethods.SamplingCreateMessage, - (request, cancellationToken) => samplingHandler( - request, - request?.Meta?.ProgressToken is { } token ? new TokenProgress(this, token) : NullProgress.Instance, - cancellationToken), - McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams, - McpJsonUtilities.JsonContext.Default.CreateMessageResult); - } - - if (options.Capabilities?.Roots is { } rootsCapability) - { - if (rootsCapability.RootsHandler is not { } rootsHandler) + if (capabilities.Sampling is { } samplingCapability) { - throw new InvalidOperationException($"Roots capability was set but it did not provide a handler."); + if (samplingCapability.SamplingHandler is not { } samplingHandler) + { + throw new InvalidOperationException($"Sampling capability was set but it did not provide a handler."); + } + + RequestHandlers.Set( + RequestMethods.SamplingCreateMessage, + (request, cancellationToken) => samplingHandler( + request, + request?.Meta?.ProgressToken is { } token ? new TokenProgress(this, token) : NullProgress.Instance, + cancellationToken), + McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams, + McpJsonUtilities.JsonContext.Default.CreateMessageResult); } - SetRequestHandler( - RequestMethods.RootsList, - rootsHandler, - McpJsonUtilities.JsonContext.Default.ListRootsRequestParams, - McpJsonUtilities.JsonContext.Default.ListRootsResult); + if (capabilities.Roots is { } rootsCapability) + { + if (rootsCapability.RootsHandler is not { } rootsHandler) + { + throw new InvalidOperationException($"Roots capability was set but it did not provide a handler."); + } + + RequestHandlers.Set( + RequestMethods.RootsList, + rootsHandler, + McpJsonUtilities.JsonContext.Default.ListRootsRequestParams, + McpJsonUtilities.JsonContext.Default.ListRootsResult); + } } } @@ -96,20 +119,21 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default) using var initializationCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); initializationCts.CancelAfter(_options.InitializationTimeout); - try - { - // Send initialize request - var initializeResponse = await this.SendRequestAsync( - RequestMethods.Initialize, - new InitializeRequestParams - { - ProtocolVersion = _options.ProtocolVersion, - Capabilities = _options.Capabilities ?? new ClientCapabilities(), - ClientInfo = _options.ClientInfo - }, - McpJsonUtilities.JsonContext.Default.InitializeRequestParams, - McpJsonUtilities.JsonContext.Default.InitializeResult, - cancellationToken: initializationCts.Token).ConfigureAwait(false); + try + { + // Send initialize request + Debug.Assert(_options.ClientInfo is not null, "ClientInfo should be set by the constructor"); + var initializeResponse = await this.SendRequestAsync( + RequestMethods.Initialize, + new InitializeRequestParams + { + ProtocolVersion = _options.ProtocolVersion, + Capabilities = _options.Capabilities ?? new ClientCapabilities(), + ClientInfo = _options.ClientInfo! + }, + McpJsonUtilities.JsonContext.Default.InitializeRequestParams, + McpJsonUtilities.JsonContext.Default.InitializeResult, + cancellationToken: initializationCts.Token).ConfigureAwait(false); // Store server information _logger.ServerCapabilitiesReceived(EndpointName, diff --git a/src/ModelContextProtocol/Client/McpClientFactory.cs b/src/ModelContextProtocol/Client/McpClientFactory.cs index c1bb3770..219c2640 100644 --- a/src/ModelContextProtocol/Client/McpClientFactory.cs +++ b/src/ModelContextProtocol/Client/McpClientFactory.cs @@ -12,23 +12,6 @@ namespace ModelContextProtocol.Client; /// Provides factory methods for creating MCP clients. public static class McpClientFactory { - /// Default client options to use when none are supplied. - private static readonly McpClientOptions s_defaultClientOptions = CreateDefaultClientOptions(); - - /// Creates default client options to use when no options are supplied. - private static McpClientOptions CreateDefaultClientOptions() - { - var asmName = (Assembly.GetEntryAssembly() ?? Assembly.GetExecutingAssembly()).GetName(); - return new() - { - ClientInfo = new() - { - Name = asmName.Name ?? "McpClient", - Version = asmName.Version?.ToString() ?? "1.0.0", - }, - }; - } - /// Creates an , connecting it to the specified server. /// Configuration for the target server to which the client should connect. /// @@ -52,7 +35,6 @@ public static async Task CreateAsync( { Throw.IfNull(serverConfig); - clientOptions ??= s_defaultClientOptions; createTransportFunc ??= CreateTransport; string endpointName = $"Client ({serverConfig.Id}: {serverConfig.Name})"; diff --git a/src/ModelContextProtocol/Client/McpClientOptions.cs b/src/ModelContextProtocol/Client/McpClientOptions.cs index de61f18f..4ed05af4 100644 --- a/src/ModelContextProtocol/Client/McpClientOptions.cs +++ b/src/ModelContextProtocol/Client/McpClientOptions.cs @@ -12,7 +12,7 @@ public class McpClientOptions /// /// Information about this client implementation. /// - public required Implementation ClientInfo { get; set; } + public Implementation? ClientInfo { get; set; } /// /// Client capabilities to advertise to the server. @@ -28,4 +28,14 @@ public class McpClientOptions /// Timeout for initialization sequence. /// public TimeSpan InitializationTimeout { get; set; } = TimeSpan.FromSeconds(60); + + /// Creates a shallow clone of the options. + internal McpClientOptions Clone() => + new() + { + ClientInfo = ClientInfo, + Capabilities = Capabilities, + ProtocolVersion = ProtocolVersion, + InitializationTimeout = InitializationTimeout + }; } diff --git a/src/ModelContextProtocol/IMcpEndpoint.cs b/src/ModelContextProtocol/IMcpEndpoint.cs index a053ad64..95d6dbc5 100644 --- a/src/ModelContextProtocol/IMcpEndpoint.cs +++ b/src/ModelContextProtocol/IMcpEndpoint.cs @@ -15,20 +15,4 @@ public interface IMcpEndpoint : IAsyncDisposable /// The message. /// The to monitor for cancellation requests. The default is . Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default); - - /// - /// Adds a handler for server notifications of a specific method. - /// - /// The notification method to handle. - /// The async handler function to process notifications. - /// - /// - /// Each method may have multiple handlers. Adding a handler for a method that already has one - /// will not replace the existing handler. - /// - /// - /// provides constants for common notification methods. - /// - /// - void AddNotificationHandler(string method, Func handler); } diff --git a/src/ModelContextProtocol/Protocol/Transport/McpTransportException.cs b/src/ModelContextProtocol/Protocol/Transport/McpTransportException.cs index bc49376d..d84de632 100644 --- a/src/ModelContextProtocol/Protocol/Transport/McpTransportException.cs +++ b/src/ModelContextProtocol/Protocol/Transport/McpTransportException.cs @@ -30,7 +30,7 @@ public McpTransportException(string message) /// /// The message that describes the error. /// The exception that is the cause of the current exception. - public McpTransportException(string message, Exception innerException) + public McpTransportException(string message, Exception? innerException) : base(message, innerException) { } diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs index af304c89..b86c886b 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs @@ -21,10 +21,22 @@ public StdioClientSessionTransport(StdioClientTransportOptions options, Process /// public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) { - if (_process.HasExited) + Exception? processException = null; + bool hasExited = false; + try + { + hasExited = _process.HasExited; + } + catch (Exception e) + { + processException = e; + hasExited = true; + } + + if (hasExited) { Logger.TransportNotConnected(EndpointName); - throw new McpTransportException("Transport is not connected"); + throw new McpTransportException("Transport is not connected", processException); } await base.SendMessageAsync(message, cancellationToken).ConfigureAwait(false); @@ -33,7 +45,7 @@ public override async Task SendMessageAsync(IJsonRpcMessage message, Cancellatio /// protected override ValueTask CleanupAsync(CancellationToken cancellationToken) { - StdioClientTransport.DisposeProcess(_process, processStarted: true, Logger, _options.ShutdownTimeout, EndpointName); + StdioClientTransport.DisposeProcess(_process, processRunning: true, Logger, _options.ShutdownTimeout, EndpointName); return base.CleanupAsync(cancellationToken); } diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs index 774677f1..89c782c1 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs @@ -2,6 +2,7 @@ using Microsoft.Extensions.Logging.Abstractions; using ModelContextProtocol.Logging; using ModelContextProtocol.Utils; +using System.ComponentModel; using System.Diagnostics; using System.Text; @@ -129,13 +130,25 @@ public async Task ConnectAsync(CancellationToken cancellationToken = } internal static void DisposeProcess( - Process? process, bool processStarted, ILogger logger, TimeSpan shutdownTimeout, string endpointName) + Process? process, bool processRunning, ILogger logger, TimeSpan shutdownTimeout, string endpointName) { if (process is not null) { + if (processRunning) + { + try + { + processRunning = !process.HasExited; + } + catch + { + processRunning = false; + } + } + try { - if (processStarted && !process.HasExited) + if (processRunning) { // Wait for the process to exit. // Kill the while process tree because the process may spawn child processes diff --git a/src/ModelContextProtocol/Protocol/Types/Capabilities.cs b/src/ModelContextProtocol/Protocol/Types/Capabilities.cs index c0cf4197..a29a530a 100644 --- a/src/ModelContextProtocol/Protocol/Types/Capabilities.cs +++ b/src/ModelContextProtocol/Protocol/Types/Capabilities.cs @@ -1,4 +1,5 @@ -using ModelContextProtocol.Server; +using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Server; using System.Text.Json.Serialization; namespace ModelContextProtocol.Protocol.Types; @@ -26,6 +27,14 @@ public class ClientCapabilities /// [JsonPropertyName("sampling")] public SamplingCapability? Sampling { get; set; } + + /// Gets or sets notification handlers to register with the client. + /// + /// When constructed, the client will enumerate these handlers, which may contain multiple handlers per key. + /// The client will not re-enumerate the sequence. + /// + [JsonIgnore] + public IEnumerable>>? NotificationHandlers { get; set; } } /// diff --git a/src/ModelContextProtocol/Protocol/Types/ServerCapabilities.cs b/src/ModelContextProtocol/Protocol/Types/ServerCapabilities.cs index 296e4e62..9f8e3ac8 100644 --- a/src/ModelContextProtocol/Protocol/Types/ServerCapabilities.cs +++ b/src/ModelContextProtocol/Protocol/Types/ServerCapabilities.cs @@ -1,4 +1,5 @@ -using System.Text.Json.Serialization; +using ModelContextProtocol.Protocol.Messages; +using System.Text.Json.Serialization; namespace ModelContextProtocol.Protocol.Types; @@ -37,4 +38,12 @@ public class ServerCapabilities /// [JsonPropertyName("tools")] public ToolsCapability? Tools { get; set; } + + /// Gets or sets notification handlers to register with the server. + /// + /// When constructed, the server will enumerate these handlers, which may contain multiple handlers per key. + /// The server will not re-enumerate the sequence. + /// + [JsonIgnore] + public IEnumerable>>? NotificationHandlers { get; set; } } diff --git a/src/ModelContextProtocol/Server/McpServer.cs b/src/ModelContextProtocol/Server/McpServer.cs index 764736e2..f04371b2 100644 --- a/src/ModelContextProtocol/Server/McpServer.cs +++ b/src/ModelContextProtocol/Server/McpServer.cs @@ -51,7 +51,7 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? }); }; - AddNotificationHandler(NotificationMethods.InitializedNotification, _ => + NotificationHandlers.Add(NotificationMethods.InitializedNotification, _ => { if (ServerOptions.Capabilities?.Tools?.ToolCollection is { } tools) { @@ -66,6 +66,11 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? return Task.CompletedTask; }); + if (options.Capabilities?.NotificationHandlers is { } notificationHandlers) + { + NotificationHandlers.AddRange(notificationHandlers); + } + SetToolsHandler(options); SetInitializeHandler(options); SetCompletionHandler(options); @@ -131,7 +136,7 @@ public override async ValueTask DisposeUnsynchronizedAsync() private void SetPingHandler() { - SetRequestHandler(RequestMethods.Ping, + RequestHandlers.Set(RequestMethods.Ping, (request, _) => Task.FromResult(new PingResult()), McpJsonUtilities.JsonContext.Default.JsonNode, McpJsonUtilities.JsonContext.Default.PingResult); @@ -139,7 +144,7 @@ private void SetPingHandler() private void SetInitializeHandler(McpServerOptions options) { - SetRequestHandler(RequestMethods.Initialize, + RequestHandlers.Set(RequestMethods.Initialize, (request, _) => { ClientCapabilities = request?.Capabilities ?? new(); @@ -164,7 +169,7 @@ private void SetInitializeHandler(McpServerOptions options) private void SetCompletionHandler(McpServerOptions options) { // This capability is not optional, so return an empty result if there is no handler. - SetRequestHandler(RequestMethods.CompletionComplete, + RequestHandlers.Set(RequestMethods.CompletionComplete, options.GetCompletionHandler is { } handler ? (request, ct) => handler(new(this, request), ct) : (request, ct) => Task.FromResult(new CompleteResult() { Completion = new() { Values = [], Total = 0, HasMore = false } }), @@ -190,20 +195,20 @@ private void SetResourcesHandler(McpServerOptions options) listResourcesHandler ??= (static (_, _) => Task.FromResult(new ListResourcesResult())); - SetRequestHandler( + RequestHandlers.Set( RequestMethods.ResourcesList, (request, ct) => listResourcesHandler(new(this, request), ct), McpJsonUtilities.JsonContext.Default.ListResourcesRequestParams, McpJsonUtilities.JsonContext.Default.ListResourcesResult); - SetRequestHandler( + RequestHandlers.Set( RequestMethods.ResourcesRead, (request, ct) => readResourceHandler(new(this, request), ct), McpJsonUtilities.JsonContext.Default.ReadResourceRequestParams, McpJsonUtilities.JsonContext.Default.ReadResourceResult); listResourceTemplatesHandler ??= (static (_, _) => Task.FromResult(new ListResourceTemplatesResult())); - SetRequestHandler( + RequestHandlers.Set( RequestMethods.ResourcesTemplatesList, (request, ct) => listResourceTemplatesHandler(new(this, request), ct), McpJsonUtilities.JsonContext.Default.ListResourceTemplatesRequestParams, @@ -221,13 +226,13 @@ private void SetResourcesHandler(McpServerOptions options) throw new McpException("Resources capability was enabled with subscribe support, but SubscribeToResources and/or UnsubscribeFromResources handlers were not specified."); } - SetRequestHandler( + RequestHandlers.Set( RequestMethods.ResourcesSubscribe, (request, ct) => subscribeHandler(new(this, request), ct), McpJsonUtilities.JsonContext.Default.SubscribeRequestParams, McpJsonUtilities.JsonContext.Default.EmptyResult); - SetRequestHandler( + RequestHandlers.Set( RequestMethods.ResourcesUnsubscribe, (request, ct) => unsubscribeHandler(new(this, request), ct), McpJsonUtilities.JsonContext.Default.UnsubscribeRequestParams, @@ -314,13 +319,13 @@ await originalListPromptsHandler(request, cancellationToken).ConfigureAwait(fals } } - SetRequestHandler( + RequestHandlers.Set( RequestMethods.PromptsList, (request, ct) => listPromptsHandler(new(this, request), ct), McpJsonUtilities.JsonContext.Default.ListPromptsRequestParams, McpJsonUtilities.JsonContext.Default.ListPromptsResult); - SetRequestHandler( + RequestHandlers.Set( RequestMethods.PromptsGet, (request, ct) => getPromptHandler(new(this, request), ct), McpJsonUtilities.JsonContext.Default.GetPromptRequestParams, @@ -407,13 +412,13 @@ await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false) } } - SetRequestHandler( + RequestHandlers.Set( RequestMethods.ToolsList, (request, ct) => listToolsHandler(new(this, request), ct), McpJsonUtilities.JsonContext.Default.ListToolsRequestParams, McpJsonUtilities.JsonContext.Default.ListToolsResult); - SetRequestHandler( + RequestHandlers.Set( RequestMethods.ToolsCall, (request, ct) => callToolHandler(new(this, request), ct), McpJsonUtilities.JsonContext.Default.CallToolRequestParams, @@ -432,7 +437,7 @@ private void SetSetLoggingLevelHandler(McpServerOptions options) throw new McpException("Logging capability was enabled, but SetLoggingLevelHandler was not specified."); } - SetRequestHandler( + RequestHandlers.Set( RequestMethods.LoggingSetLevel, (request, ct) => setLoggingLevelHandler(new(this, request), ct), McpJsonUtilities.JsonContext.Default.SetLevelRequestParams, diff --git a/src/ModelContextProtocol/Shared/McpEndpoint.cs b/src/ModelContextProtocol/Shared/McpEndpoint.cs index 8b50a805..8587da98 100644 --- a/src/ModelContextProtocol/Shared/McpEndpoint.cs +++ b/src/ModelContextProtocol/Shared/McpEndpoint.cs @@ -19,9 +19,6 @@ namespace ModelContextProtocol.Shared; /// internal abstract class McpEndpoint : IAsyncDisposable { - private readonly RequestHandlers _requestHandlers = []; - private readonly NotificationHandlers _notificationHandlers = []; - private McpSession? _session; private CancellationTokenSource? _sessionCts; @@ -39,16 +36,9 @@ protected McpEndpoint(ILoggerFactory? loggerFactory = null) _logger = loggerFactory?.CreateLogger(GetType()) ?? NullLogger.Instance; } - protected void SetRequestHandler( - string method, - Func> handler, - JsonTypeInfo requestTypeInfo, - JsonTypeInfo responseTypeInfo) - - => _requestHandlers.Set(method, handler, requestTypeInfo, responseTypeInfo); + protected RequestHandlers RequestHandlers { get; } = []; - public void AddNotificationHandler(string method, Func handler) - => _notificationHandlers.Add(method, handler); + protected NotificationHandlers NotificationHandlers { get; } = []; public Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) => GetSessionOrThrow().SendRequestAsync(request, cancellationToken); @@ -70,7 +60,7 @@ public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancella protected void StartSession(ITransport sessionTransport) { _sessionCts = new CancellationTokenSource(); - _session = new McpSession(this is IMcpServer, sessionTransport, EndpointName, _requestHandlers, _notificationHandlers, _logger); + _session = new McpSession(this is IMcpServer, sessionTransport, EndpointName, RequestHandlers, NotificationHandlers, _logger); MessageProcessingTask = _session.ProcessMessagesAsync(_sessionCts.Token); } diff --git a/src/ModelContextProtocol/Shared/NotificationHandlers.cs b/src/ModelContextProtocol/Shared/NotificationHandlers.cs index 1fdb0578..c3b5dfe7 100644 --- a/src/ModelContextProtocol/Shared/NotificationHandlers.cs +++ b/src/ModelContextProtocol/Shared/NotificationHandlers.cs @@ -1,16 +1,28 @@ using ModelContextProtocol.Protocol.Messages; -using System.Collections.Concurrent; namespace ModelContextProtocol.Shared; -internal sealed class NotificationHandlers : ConcurrentDictionary>> +internal sealed class NotificationHandlers : Dictionary>> { + /// Adds a notification handler as part of configuring the endpoint. + /// This method is not thread-safe and should only be used serially as part of configuring the instance. public void Add(string method, Func handler) { - var handlers = GetOrAdd(method, _ => []); - lock (handlers) + if (!TryGetValue(method, out var handlers)) { - handlers.Add(handler); + this[method] = handlers = []; + } + + handlers.Add(handler); + } + + /// Adds notification handlers as part of configuring the endpoint. + /// This method is not thread-safe and should only be used serially as part of configuring the instance. + public void AddRange(IEnumerable>> handlers) + { + foreach (var handler in handlers) + { + Add(handler.Key, handler.Value); } } } diff --git a/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs b/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs index a77ae6b8..400caade 100644 --- a/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs +++ b/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs @@ -8,7 +8,6 @@ public class ClientIntegrationTestFixture { private ILoggerFactory? _loggerFactory; - public McpClientOptions DefaultOptions { get; } public McpServerConfig EverythingServerConfig { get; } public McpServerConfig TestServerConfig { get; } @@ -16,11 +15,6 @@ public class ClientIntegrationTestFixture public ClientIntegrationTestFixture() { - DefaultOptions = new() - { - ClientInfo = new() { Name = "IntegrationTestClient", Version = "1.0.0" }, - }; - EverythingServerConfig = new() { Id = "everything", @@ -63,5 +57,5 @@ public Task CreateClientAsync(string clientId, McpClientOptions? cli "everything" => EverythingServerConfig, "test_server" => TestServerConfig, _ => throw new ArgumentException($"Unknown client ID: {clientId}") - }, clientOptions ?? DefaultOptions, loggerFactory: _loggerFactory); + }, clientOptions, loggerFactory: _loggerFactory); } \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs index 705a0779..d808e40e 100644 --- a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs @@ -256,13 +256,22 @@ public async Task SubscribeResource_Stdio() // act TaskCompletionSource tcs = new(); - await using var client = await _fixture.CreateClientAsync(clientId); - client.AddNotificationHandler(NotificationMethods.ResourceUpdatedNotification, (notification) => + await using var client = await _fixture.CreateClientAsync(clientId, new() { - var notificationParams = JsonSerializer.Deserialize(notification.Params); - tcs.TrySetResult(true); - return Task.CompletedTask; + Capabilities = new() + { + NotificationHandlers = + [ + new(NotificationMethods.ResourceUpdatedNotification, notification => + { + var notificationParams = JsonSerializer.Deserialize(notification.Params); + tcs.TrySetResult(true); + return Task.CompletedTask; + }) + ] + } }); + await client.SubscribeToResourceAsync("test://static/resource/1", TestContext.Current.CancellationToken); await tcs.Task; @@ -277,12 +286,20 @@ public async Task UnsubscribeResource_Stdio() // act TaskCompletionSource receivedNotification = new(); - await using var client = await _fixture.CreateClientAsync(clientId); - client.AddNotificationHandler(NotificationMethods.ResourceUpdatedNotification, (notification) => + await using var client = await _fixture.CreateClientAsync(clientId, new() { - var notificationParams = JsonSerializer.Deserialize(notification.Params); - receivedNotification.TrySetResult(true); - return Task.CompletedTask; + Capabilities = new() + { + NotificationHandlers = + [ + new(NotificationMethods.ResourceUpdatedNotification, (notification) => + { + var notificationParams = JsonSerializer.Deserialize(notification.Params); + receivedNotification.TrySetResult(true); + return Task.CompletedTask; + }) + ] + } }); await client.SubscribeToResourceAsync("test://static/resource/1", TestContext.Current.CancellationToken); @@ -483,7 +500,6 @@ public async Task ListToolsAsync_UsingEverythingServer_ToolsAreProperlyCalled() // Get the MCP client and tools from it. await using var client = await McpClientFactory.CreateAsync( _fixture.EverythingServerConfig, - _fixture.DefaultOptions, cancellationToken: TestContext.Current.CancellationToken); var mappedTools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); @@ -543,15 +559,23 @@ public async Task SamplingViaChatClient_RequestResponseProperlyPropagated() public async Task SetLoggingLevel_ReceivesLoggingMessages(string clientId) { TaskCompletionSource receivedNotification = new(); - await using var client = await _fixture.CreateClientAsync(clientId); - client.AddNotificationHandler(NotificationMethods.LoggingMessageNotification, (notification) => + await using var client = await _fixture.CreateClientAsync(clientId, new() { - var loggingMessageNotificationParameters = JsonSerializer.Deserialize(notification.Params); - if (loggingMessageNotificationParameters is not null) + Capabilities = new() { - receivedNotification.TrySetResult(true); + NotificationHandlers = + [ + new(NotificationMethods.LoggingMessageNotification, (notification) => + { + var loggingMessageNotificationParameters = JsonSerializer.Deserialize(notification.Params); + if (loggingMessageNotificationParameters is not null) + { + receivedNotification.TrySetResult(true); + } + return Task.CompletedTask; + }) + ] } - return Task.CompletedTask; }); // act diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs index bbbbb05d..ed1be347 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs @@ -117,7 +117,7 @@ public async ValueTask DisposeAsync() Dispose(); } - private async Task CreateMcpClientForServer() + private async Task CreateMcpClientForServer(McpClientOptions? options = null) { return await McpClientFactory.CreateAsync( new McpServerConfig() @@ -126,6 +126,7 @@ private async Task CreateMcpClientForServer() Name = "TestServer", TransportType = "ignored", }, + options, createTransportFunc: (_, _) => new StreamClientTransport( serverInput: _clientToServerPipe.Writer.AsStream(), serverOutput: _serverToClientPipe.Reader.AsStream(), @@ -175,18 +176,23 @@ public async Task Can_List_And_Call_Registered_Prompts() [Fact] public async Task Can_Be_Notified_Of_Prompt_Changes() { - IMcpClient client = await CreateMcpClientForServer(); - - var prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken); - Assert.Equal(6, prompts.Count); - Channel listChanged = Channel.CreateUnbounded(); - client.AddNotificationHandler("notifications/prompts/list_changed", notification => + + IMcpClient client = await CreateMcpClientForServer(new() { - listChanged.Writer.TryWrite(notification); - return Task.CompletedTask; + Capabilities = new() + { + NotificationHandlers = [new("notifications/prompts/list_changed", notification => + { + listChanged.Writer.TryWrite(notification); + return Task.CompletedTask; + })], + }, }); + var prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken); + Assert.Equal(6, prompts.Count); + var notificationRead = listChanged.Reader.ReadAsync(TestContext.Current.CancellationToken); Assert.False(notificationRead.IsCompleted); diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index 79ae117f..827c7056 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -141,15 +141,16 @@ public async ValueTask DisposeAsync() Dispose(); } - private async Task CreateMcpClientForServer() + private async Task CreateMcpClientForServer(McpClientOptions? options = null) { return await McpClientFactory.CreateAsync( - new McpServerConfig() - { - Id = "TestServer", - Name = "TestServer", - TransportType = "ignored", - }, + new McpServerConfig() + { + Id = "TestServer", + Name = "TestServer", + TransportType = "ignored", + }, + options, createTransportFunc: (_, _) => new StreamClientTransport( serverInput: _clientToServerPipe.Writer.AsStream(), _serverToClientPipe.Reader.AsStream(), @@ -242,18 +243,23 @@ public async Task Can_Create_Multiple_Servers_From_Options_And_List_Registered_T [Fact] public async Task Can_Be_Notified_Of_Tool_Changes() { - IMcpClient client = await CreateMcpClientForServer(); - - var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); - Assert.Equal(16, tools.Count); - Channel listChanged = Channel.CreateUnbounded(); - client.AddNotificationHandler(NotificationMethods.ToolListChangedNotification, notification => + + IMcpClient client = await CreateMcpClientForServer(new() { - listChanged.Writer.TryWrite(notification); - return Task.CompletedTask; + Capabilities = new() + { + NotificationHandlers = [new(NotificationMethods.ToolListChangedNotification, notification => + { + listChanged.Writer.TryWrite(notification); + return Task.CompletedTask; + })], + }, }); + var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + Assert.Equal(16, tools.Count); + var notificationRead = listChanged.Reader.ReadAsync(TestContext.Current.CancellationToken); Assert.False(notificationRead.IsCompleted); @@ -622,12 +628,17 @@ public async Task HandlesIProgressParameter() { ConcurrentQueue notifications = new(); - IMcpClient client = await CreateMcpClientForServer(); - client.AddNotificationHandler(NotificationMethods.ProgressNotification, notification => + IMcpClient client = await CreateMcpClientForServer(new() { - ProgressNotification pn = JsonSerializer.Deserialize(notification.Params)!; - notifications.Enqueue(pn); - return Task.CompletedTask; + Capabilities = new() + { + NotificationHandlers = [new(NotificationMethods.ProgressNotification, notification => + { + ProgressNotification pn = JsonSerializer.Deserialize(notification.Params)!; + notifications.Enqueue(pn); + return Task.CompletedTask; + })], + }, }); var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); diff --git a/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj b/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj index c8dc1ca9..16e6a891 100644 --- a/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj +++ b/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj @@ -11,6 +11,13 @@ ModelContextProtocol.Tests + + + true + + runtime; build; native; contentfiles; analyzers; buildtransitive diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs index 619dcfde..0a5bec09 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs @@ -634,8 +634,6 @@ public Task SendRequestAsync(JsonRpcRequest request, Cancellati public Implementation? ClientInfo => throw new NotImplementedException(); public McpServerOptions ServerOptions => throw new NotImplementedException(); public IServiceProvider? Services => throw new NotImplementedException(); - public void AddNotificationHandler(string method, Func handler) => - throw new NotImplementedException(); public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) => throw new NotImplementedException(); public Task RunAsync(CancellationToken cancellationToken = default) => @@ -649,13 +647,16 @@ public async Task NotifyProgress_Should_Be_Handled() var options = CreateOptions(); var notificationReceived = new TaskCompletionSource(); + options.Capabilities = new() + { + NotificationHandlers = [new(NotificationMethods.ProgressNotification, notification => + { + notificationReceived.TrySetResult(notification); + return Task.CompletedTask; + })], + }; var server = McpServerFactory.Create(transport, options, LoggerFactory); - server.AddNotificationHandler(NotificationMethods.ProgressNotification, notification => - { - notificationReceived.SetResult(notification); - return Task.CompletedTask; - }); Task serverTask = server.RunAsync(TestContext.Current.CancellationToken); diff --git a/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs b/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs index 0622e656..11caaef8 100644 --- a/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs @@ -213,12 +213,6 @@ public async Task ConnectAndReceiveNotification_InMemoryServer() await using InMemoryTestSseServer server = new(CreatePortNumber(), LoggerFactory.CreateLogger()); await server.StartAsync(); - - var defaultOptions = new McpClientOptions - { - ClientInfo = new() { Name = "IntegrationTestClient", Version = "1.0.0" } - }; - var defaultConfig = new McpServerConfig { Id = "test_server", @@ -229,24 +223,28 @@ public async Task ConnectAndReceiveNotification_InMemoryServer() }; // Act + var receivedNotification = new TaskCompletionSource(); await using var client = await McpClientFactory.CreateAsync( defaultConfig, - defaultOptions, + new() + { + Capabilities = new() + { + NotificationHandlers = [new("test/notification", args => + { + var msg = args.Params?["message"]?.GetValue(); + receivedNotification.SetResult(msg); + + return Task.CompletedTask; + })], + }, + }, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); // Wait for SSE connection to be established await server.WaitForConnectionAsync(TimeSpan.FromSeconds(10)); - var receivedNotification = new TaskCompletionSource(); - client.AddNotificationHandler("test/notification", (args) => - { - var msg = args.Params?["message"]?.GetValue(); - receivedNotification.SetResult(msg); - - return Task.CompletedTask; - }); - // Act await server.SendTestNotificationAsync("Hello from server!"); From 824d01d1dd71e209d3d231b6ac4a014550e41152 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Fri, 4 Apr 2025 13:01:58 -0400 Subject: [PATCH 2/2] Address more feedback and further cleanup --- .../Tools/WeatherTools.cs | 1 - src/ModelContextProtocol/Client/McpClient.cs | 24 +++++---------- .../Client/McpClientFactory.cs | 1 - .../Client/McpClientOptions.cs | 10 ------- .../Client/McpClientTool.cs | 1 - .../Configuration/McpServerOptionsSetup.cs | 15 +--------- .../Transport/StdioClientTransport.cs | 1 - .../Transport/StdioServerTransport.cs | 4 +-- .../Protocol/Types/ListRootsRequestParams.cs | 4 +-- src/ModelContextProtocol/Server/McpServer.cs | 12 ++++++-- .../Server/McpServerExtensions.cs | 1 - .../Server/McpServerOptions.cs | 2 +- .../Shared/McpEndpoint.cs | 5 +++- src/ModelContextProtocol/Shared/McpSession.cs | 1 - .../Program.cs | 1 - .../Program.cs | 1 - .../Client/McpClientFactoryTests.cs | 30 +++++-------------- .../ClientIntegrationTests.cs | 5 ---- .../DiagnosticTests.cs | 2 -- .../Server/McpServerFactoryTests.cs | 4 +-- .../Server/McpServerTests.cs | 5 ++-- .../SseIntegrationTests.cs | 17 ----------- .../SseServerIntegrationTestFixture.cs | 8 +---- .../SseServerIntegrationTests.cs | 6 ++-- .../Transport/StdioServerTransportTests.cs | 8 ----- 25 files changed, 41 insertions(+), 128 deletions(-) diff --git a/samples/QuickstartWeatherServer/Tools/WeatherTools.cs b/samples/QuickstartWeatherServer/Tools/WeatherTools.cs index 8463e350..9f04ff25 100644 --- a/samples/QuickstartWeatherServer/Tools/WeatherTools.cs +++ b/samples/QuickstartWeatherServer/Tools/WeatherTools.cs @@ -1,7 +1,6 @@ using ModelContextProtocol; using ModelContextProtocol.Server; using System.ComponentModel; -using System.Net.Http.Json; using System.Text.Json; namespace QuickstartWeatherServer.Tools; diff --git a/src/ModelContextProtocol/Client/McpClient.cs b/src/ModelContextProtocol/Client/McpClient.cs index ea21055e..64a89680 100644 --- a/src/ModelContextProtocol/Client/McpClient.cs +++ b/src/ModelContextProtocol/Client/McpClient.cs @@ -5,8 +5,6 @@ using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Shared; using ModelContextProtocol.Utils.Json; -using System.Diagnostics; -using System.Reflection; using System.Text.Json; namespace ModelContextProtocol.Client; @@ -14,8 +12,11 @@ namespace ModelContextProtocol.Client; /// internal sealed class McpClient : McpEndpoint, IMcpClient { - /// Cached naming information used for client name/version when none is specified. - private static readonly AssemblyName s_asmName = (Assembly.GetEntryAssembly() ?? Assembly.GetExecutingAssembly()).GetName(); + private static Implementation DefaultImplementation { get; } = new() + { + Name = DefaultAssemblyName.Name ?? nameof(McpClient), + Version = DefaultAssemblyName.Version?.ToString() ?? "1.0.0", + }; private readonly IClientTransport _clientTransport; private readonly McpClientOptions _options; @@ -37,17 +38,9 @@ internal sealed class McpClient : McpEndpoint, IMcpClient public McpClient(IClientTransport clientTransport, McpClientOptions? options, McpServerConfig serverConfig, ILoggerFactory? loggerFactory) : base(loggerFactory) { - _clientTransport = clientTransport; + options ??= new(); - if (options?.ClientInfo is null) - { - options = options?.Clone() ?? new(); - options.ClientInfo = new() - { - Name = s_asmName.Name ?? nameof(McpClient), - Version = s_asmName.Version?.ToString() ?? "1.0.0", - }; - } + _clientTransport = clientTransport; _options = options; EndpointName = $"Client ({serverConfig.Id}: {serverConfig.Name})"; @@ -122,14 +115,13 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default) try { // Send initialize request - Debug.Assert(_options.ClientInfo is not null, "ClientInfo should be set by the constructor"); var initializeResponse = await this.SendRequestAsync( RequestMethods.Initialize, new InitializeRequestParams { ProtocolVersion = _options.ProtocolVersion, Capabilities = _options.Capabilities ?? new ClientCapabilities(), - ClientInfo = _options.ClientInfo! + ClientInfo = _options.ClientInfo ?? DefaultImplementation, }, McpJsonUtilities.JsonContext.Default.InitializeRequestParams, McpJsonUtilities.JsonContext.Default.InitializeResult, diff --git a/src/ModelContextProtocol/Client/McpClientFactory.cs b/src/ModelContextProtocol/Client/McpClientFactory.cs index 219c2640..751df190 100644 --- a/src/ModelContextProtocol/Client/McpClientFactory.cs +++ b/src/ModelContextProtocol/Client/McpClientFactory.cs @@ -5,7 +5,6 @@ using ModelContextProtocol.Utils; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; -using System.Reflection; namespace ModelContextProtocol.Client; diff --git a/src/ModelContextProtocol/Client/McpClientOptions.cs b/src/ModelContextProtocol/Client/McpClientOptions.cs index 4ed05af4..5fe81ede 100644 --- a/src/ModelContextProtocol/Client/McpClientOptions.cs +++ b/src/ModelContextProtocol/Client/McpClientOptions.cs @@ -28,14 +28,4 @@ public class McpClientOptions /// Timeout for initialization sequence. /// public TimeSpan InitializationTimeout { get; set; } = TimeSpan.FromSeconds(60); - - /// Creates a shallow clone of the options. - internal McpClientOptions Clone() => - new() - { - ClientInfo = ClientInfo, - Capabilities = Capabilities, - ProtocolVersion = ProtocolVersion, - InitializationTimeout = InitializationTimeout - }; } diff --git a/src/ModelContextProtocol/Client/McpClientTool.cs b/src/ModelContextProtocol/Client/McpClientTool.cs index 58cb02d5..9dfee5a0 100644 --- a/src/ModelContextProtocol/Client/McpClientTool.cs +++ b/src/ModelContextProtocol/Client/McpClientTool.cs @@ -1,6 +1,5 @@ using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Utils.Json; -using ModelContextProtocol.Utils; using Microsoft.Extensions.AI; using System.Text.Json; diff --git a/src/ModelContextProtocol/Configuration/McpServerOptionsSetup.cs b/src/ModelContextProtocol/Configuration/McpServerOptionsSetup.cs index 2ef0300e..c778fe81 100644 --- a/src/ModelContextProtocol/Configuration/McpServerOptionsSetup.cs +++ b/src/ModelContextProtocol/Configuration/McpServerOptionsSetup.cs @@ -1,5 +1,4 @@ -using System.Reflection; -using ModelContextProtocol.Server; +using ModelContextProtocol.Server; using Microsoft.Extensions.Options; using ModelContextProtocol.Utils; @@ -25,18 +24,6 @@ public void Configure(McpServerOptions options) { Throw.IfNull(options); - // Configure the option's server information based on the current process, - // if it otherwise lacks server information. - if (options.ServerInfo is not { } serverInfo) - { - var assemblyName = Assembly.GetEntryAssembly()?.GetName(); - options.ServerInfo = new() - { - Name = assemblyName?.Name ?? "McpServer", - Version = assemblyName?.Version?.ToString() ?? "1.0.0", - }; - } - // Collect all of the provided tools into a tools collection. If the options already has // a collection, add to it, otherwise create a new one. We want to maintain the identity // of an existing collection in case someone has provided their own derived type, wants diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs index 89c782c1..11a0c137 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs @@ -2,7 +2,6 @@ using Microsoft.Extensions.Logging.Abstractions; using ModelContextProtocol.Logging; using ModelContextProtocol.Utils; -using System.ComponentModel; using System.Diagnostics; using System.Text; diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioServerTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StdioServerTransport.cs index 58077dbb..b3ae2e41 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StdioServerTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StdioServerTransport.cs @@ -70,9 +70,7 @@ public StdioServerTransport(string serverName, ILoggerFactory? loggerFactory = n private static string GetServerName(McpServerOptions serverOptions) { Throw.IfNull(serverOptions); - Throw.IfNull(serverOptions.ServerInfo); - Throw.IfNull(serverOptions.ServerInfo.Name); - return serverOptions.ServerInfo.Name; + return serverOptions.ServerInfo?.Name ?? McpServer.DefaultImplementation.Name; } } diff --git a/src/ModelContextProtocol/Protocol/Types/ListRootsRequestParams.cs b/src/ModelContextProtocol/Protocol/Types/ListRootsRequestParams.cs index a5eec7a1..273251e6 100644 --- a/src/ModelContextProtocol/Protocol/Types/ListRootsRequestParams.cs +++ b/src/ModelContextProtocol/Protocol/Types/ListRootsRequestParams.cs @@ -1,6 +1,4 @@ -using ModelContextProtocol.Protocol.Messages; - -namespace ModelContextProtocol.Protocol.Types; +namespace ModelContextProtocol.Protocol.Types; /// /// A request from the server to get a list of root URIs from the client. diff --git a/src/ModelContextProtocol/Server/McpServer.cs b/src/ModelContextProtocol/Server/McpServer.cs index f04371b2..9b6898d8 100644 --- a/src/ModelContextProtocol/Server/McpServer.cs +++ b/src/ModelContextProtocol/Server/McpServer.cs @@ -11,6 +11,12 @@ namespace ModelContextProtocol.Server; /// internal sealed class McpServer : McpEndpoint, IMcpServer { + internal static Implementation DefaultImplementation { get; } = new() + { + Name = DefaultAssemblyName.Name ?? nameof(McpServer), + Version = DefaultAssemblyName.Version?.ToString() ?? "1.0.0", + }; + private readonly EventHandler? _toolsChangedDelegate; private readonly EventHandler? _promptsChangedDelegate; @@ -32,9 +38,11 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? Throw.IfNull(transport); Throw.IfNull(options); + options ??= new(); + ServerOptions = options; Services = serviceProvider; - _endpointName = $"Server ({options.ServerInfo.Name} {options.ServerInfo.Version})"; + _endpointName = $"Server ({options.ServerInfo?.Name ?? DefaultImplementation.Name} {options.ServerInfo?.Version ?? DefaultImplementation.Version})"; _toolsChangedDelegate = delegate { @@ -158,7 +166,7 @@ private void SetInitializeHandler(McpServerOptions options) { ProtocolVersion = options.ProtocolVersion, Instructions = options.ServerInstructions, - ServerInfo = options.ServerInfo, + ServerInfo = options.ServerInfo ?? DefaultImplementation, Capabilities = ServerCapabilities ?? new(), }); }, diff --git a/src/ModelContextProtocol/Server/McpServerExtensions.cs b/src/ModelContextProtocol/Server/McpServerExtensions.cs index b59992c9..06ec596c 100644 --- a/src/ModelContextProtocol/Server/McpServerExtensions.cs +++ b/src/ModelContextProtocol/Server/McpServerExtensions.cs @@ -1,5 +1,4 @@ using Microsoft.Extensions.AI; -using ModelContextProtocol.Client; using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Utils; diff --git a/src/ModelContextProtocol/Server/McpServerOptions.cs b/src/ModelContextProtocol/Server/McpServerOptions.cs index f17ba15d..06b0a260 100644 --- a/src/ModelContextProtocol/Server/McpServerOptions.cs +++ b/src/ModelContextProtocol/Server/McpServerOptions.cs @@ -14,7 +14,7 @@ public class McpServerOptions /// /// Information about this server implementation. /// - public required Implementation ServerInfo { get; set; } + public Implementation? ServerInfo { get; set; } /// /// Server capabilities to advertise to the server. diff --git a/src/ModelContextProtocol/Shared/McpEndpoint.cs b/src/ModelContextProtocol/Shared/McpEndpoint.cs index 8587da98..c26af4b1 100644 --- a/src/ModelContextProtocol/Shared/McpEndpoint.cs +++ b/src/ModelContextProtocol/Shared/McpEndpoint.cs @@ -6,7 +6,7 @@ using ModelContextProtocol.Server; using ModelContextProtocol.Utils; using System.Diagnostics.CodeAnalysis; -using System.Text.Json.Serialization.Metadata; +using System.Reflection; namespace ModelContextProtocol.Shared; @@ -19,6 +19,9 @@ namespace ModelContextProtocol.Shared; /// internal abstract class McpEndpoint : IAsyncDisposable { + /// Cached naming information used for name/version when none is specified. + internal static AssemblyName DefaultAssemblyName { get; } = (Assembly.GetEntryAssembly() ?? Assembly.GetExecutingAssembly()).GetName(); + private McpSession? _session; private CancellationTokenSource? _sessionCts; diff --git a/src/ModelContextProtocol/Shared/McpSession.cs b/src/ModelContextProtocol/Shared/McpSession.cs index dae92686..3d06f7fc 100644 --- a/src/ModelContextProtocol/Shared/McpSession.cs +++ b/src/ModelContextProtocol/Shared/McpSession.cs @@ -3,7 +3,6 @@ using ModelContextProtocol.Logging; using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; -using ModelContextProtocol.Server; using ModelContextProtocol.Utils; using ModelContextProtocol.Utils.Json; using System.Collections.Concurrent; diff --git a/tests/ModelContextProtocol.TestServer/Program.cs b/tests/ModelContextProtocol.TestServer/Program.cs index 830a042a..ba6cca21 100644 --- a/tests/ModelContextProtocol.TestServer/Program.cs +++ b/tests/ModelContextProtocol.TestServer/Program.cs @@ -38,7 +38,6 @@ private static async Task Main(string[] args) McpServerOptions options = new() { - ServerInfo = new Implementation() { Name = "TestServer", Version = "1.0.0" }, Capabilities = new ServerCapabilities() { Tools = ConfigureTools(), diff --git a/tests/ModelContextProtocol.TestSseServer/Program.cs b/tests/ModelContextProtocol.TestSseServer/Program.cs index e4bd996b..43ae3207 100644 --- a/tests/ModelContextProtocol.TestSseServer/Program.cs +++ b/tests/ModelContextProtocol.TestSseServer/Program.cs @@ -25,7 +25,6 @@ private static void ConfigureSerilog(ILoggingBuilder loggingBuilder) private static void ConfigureOptions(McpServerOptions options) { - options.ServerInfo = new Implementation() { Name = "TestServer", Version = "1.0.0" }; options.Capabilities = new ServerCapabilities() { Tools = new(), diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs index ae58023f..aa2f773e 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs @@ -11,29 +11,24 @@ namespace ModelContextProtocol.Tests.Client; public class McpClientFactoryTests { - private readonly McpClientOptions _defaultOptions = new() - { - ClientInfo = new() { Name = "TestClient", Version = "1.0.0" } - }; - [Fact] public async Task CreateAsync_WithInvalidArgs_Throws() { - await Assert.ThrowsAsync("serverConfig", () => McpClientFactory.CreateAsync((McpServerConfig)null!, _defaultOptions, cancellationToken: TestContext.Current.CancellationToken)); + await Assert.ThrowsAsync("serverConfig", () => McpClientFactory.CreateAsync((McpServerConfig)null!, cancellationToken: TestContext.Current.CancellationToken)); await Assert.ThrowsAsync("serverConfig", () => McpClientFactory.CreateAsync(new McpServerConfig() { Name = "name", Id = "id", TransportType = "somethingunsupported", - }, _defaultOptions, cancellationToken: TestContext.Current.CancellationToken)); + }, cancellationToken: TestContext.Current.CancellationToken)); await Assert.ThrowsAsync(() => McpClientFactory.CreateAsync(new McpServerConfig() { Name = "name", Id = "id", TransportType = TransportTypes.StdIo, - }, _defaultOptions, (_, __) => null!, cancellationToken: TestContext.Current.CancellationToken)); + }, createTransportFunc: (_, __) => null!, cancellationToken: TestContext.Current.CancellationToken)); } [Fact] @@ -78,8 +73,7 @@ public async Task CreateAsync_WithValidStdioConfig_CreatesNewClient() // Act await using var client = await McpClientFactory.CreateAsync( serverConfig, - _defaultOptions, - (_, __) => new NopTransport(), + createTransportFunc: (_, __) => new NopTransport(), cancellationToken: TestContext.Current.CancellationToken); // Assert @@ -102,8 +96,7 @@ public async Task CreateAsync_WithNoTransportOptions_CreatesNewClient() // Act await using var client = await McpClientFactory.CreateAsync( serverConfig, - _defaultOptions, - (_, __) => new NopTransport(), + createTransportFunc: (_, __) => new NopTransport(), cancellationToken: TestContext.Current.CancellationToken); // Assert @@ -126,8 +119,7 @@ public async Task CreateAsync_WithValidSseConfig_CreatesNewClient() // Act await using var client = await McpClientFactory.CreateAsync( serverConfig, - _defaultOptions, - (_, __) => new NopTransport(), + createTransportFunc: (_, __) => new NopTransport(), cancellationToken: TestContext.Current.CancellationToken); // Assert @@ -157,8 +149,7 @@ public async Task CreateAsync_WithSse_CreatesCorrectTransportOptions() // Act await using var client = await McpClientFactory.CreateAsync( serverConfig, - _defaultOptions, - (_, __) => new NopTransport(), + createTransportFunc: (_, __) => new NopTransport(), cancellationToken: TestContext.Current.CancellationToken); // Assert @@ -186,7 +177,7 @@ public async Task McpFactory_WithInvalidTransportOptions_ThrowsFormatException(s }; // act & assert - await Assert.ThrowsAsync(() => McpClientFactory.CreateAsync(config, _defaultOptions, cancellationToken: TestContext.Current.CancellationToken)); + await Assert.ThrowsAsync(() => McpClientFactory.CreateAsync(config, cancellationToken: TestContext.Current.CancellationToken)); } [Theory] @@ -205,11 +196,6 @@ public async Task CreateAsync_WithCapabilitiesOptions(Type transportType) var clientOptions = new McpClientOptions { - ClientInfo = new Implementation - { - Name = "TestClient", - Version = "1.0.0.0" - }, Capabilities = new ClientCapabilities { Sampling = new SamplingCapability diff --git a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs index d808e40e..6f0c80f1 100644 --- a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs @@ -5,10 +5,7 @@ using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Tests.Utils; using OpenAI; -using System.Text.Encodings.Web; using System.Text.Json; -using System.Text.Json.Serialization; -using System.Text.Json.Serialization.Metadata; namespace ModelContextProtocol.Tests; @@ -367,7 +364,6 @@ public async Task Sampling_Stdio(string clientId) int samplingHandlerCalls = 0; await using var client = await _fixture.CreateClientAsync(clientId, new() { - ClientInfo = new() { Name = "Sampling_Stdio", Version = "1.0.0" }, Capabilities = new() { Sampling = new() @@ -532,7 +528,6 @@ public async Task SamplingViaChatClient_RequestResponseProperlyPropagated() .CreateSamplingHandler(); await using var client = await McpClientFactory.CreateAsync(_fixture.EverythingServerConfig, new() { - ClientInfo = new() { Name = nameof(SamplingViaChatClient_RequestResponseProperlyPropagated), Version = "1.0.0" }, Capabilities = new() { Sampling = new() diff --git a/tests/ModelContextProtocol.Tests/DiagnosticTests.cs b/tests/ModelContextProtocol.Tests/DiagnosticTests.cs index 583ae274..fd93eff6 100644 --- a/tests/ModelContextProtocol.Tests/DiagnosticTests.cs +++ b/tests/ModelContextProtocol.Tests/DiagnosticTests.cs @@ -1,6 +1,5 @@ using ModelContextProtocol.Client; using ModelContextProtocol.Protocol.Transport; -using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server; using OpenTelemetry.Trace; using System.Diagnostics; @@ -49,7 +48,6 @@ private static async Task RunConnected(Func action await using (IMcpServer server = McpServerFactory.Create(serverTransport, new() { - ServerInfo = new Implementation { Name = "TestServer", Version = "1.0.0" }, Capabilities = new() { Tools = new() diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerFactoryTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerFactoryTests.cs index ae640ecc..034a30bd 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerFactoryTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerFactoryTests.cs @@ -1,5 +1,4 @@ -using ModelContextProtocol.Protocol.Types; -using ModelContextProtocol.Server; +using ModelContextProtocol.Server; using ModelContextProtocol.Tests.Utils; namespace ModelContextProtocol.Tests.Server; @@ -13,7 +12,6 @@ public McpServerFactoryTests(ITestOutputHelper testOutputHelper) { _options = new McpServerOptions { - ServerInfo = new Implementation { Name = "TestServer", Version = "1.0" }, ProtocolVersion = "1.0", InitializationTimeout = TimeSpan.FromSeconds(30) }; diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs index 0a5bec09..fb462748 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs @@ -24,7 +24,6 @@ private static McpServerOptions CreateOptions(ServerCapabilities? capabilities = { return new McpServerOptions { - ServerInfo = new Implementation { Name = "TestServer", Version = "1.0" }, ProtocolVersion = "2024", InitializationTimeout = TimeSpan.FromSeconds(30), Capabilities = capabilities, @@ -189,8 +188,8 @@ await Can_Handle_Requests( { var result = JsonSerializer.Deserialize(response); Assert.NotNull(result); - Assert.Equal("TestServer", result.ServerInfo.Name); - Assert.Equal("1.0", result.ServerInfo.Version); + Assert.Equal("ModelContextProtocol.Tests", result.ServerInfo.Name); + Assert.Equal("1.0.0.0", result.ServerInfo.Version); Assert.Equal("2024", result.ProtocolVersion); }); } diff --git a/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs b/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs index 11caaef8..91e99898 100644 --- a/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs @@ -29,11 +29,6 @@ public async Task ConnectAndReceiveMessage_InMemoryServer() await using InMemoryTestSseServer server = new(CreatePortNumber(), LoggerFactory.CreateLogger()); await server.StartAsync(); - var defaultOptions = new McpClientOptions - { - ClientInfo = new() { Name = "IntegrationTestClient", Version = "1.0.0" } - }; - var defaultConfig = new McpServerConfig { Id = "test_server", @@ -46,7 +41,6 @@ public async Task ConnectAndReceiveMessage_InMemoryServer() // Act await using var client = await McpClientFactory.CreateAsync( defaultConfig, - defaultOptions, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); @@ -120,11 +114,6 @@ public async Task Sampling_Sse_EverythingServer() int samplingHandlerCalls = 0; var defaultOptions = new McpClientOptions { - ClientInfo = new() - { - Name = "IntegrationTestClient", - Version = "1.0.0" - }, Capabilities = new() { Sampling = new() @@ -175,11 +164,6 @@ public async Task ConnectAndReceiveMessage_InMemoryServer_WithFullEndpointEventU server.UseFullUrlForEndpointEvent = true; await server.StartAsync(); - var defaultOptions = new McpClientOptions - { - ClientInfo = new() { Name = "IntegrationTestClient", Version = "1.0.0" } - }; - var defaultConfig = new McpServerConfig { Id = "test_server", @@ -192,7 +176,6 @@ public async Task ConnectAndReceiveMessage_InMemoryServer_WithFullEndpointEventU // Act await using var client = await McpClientFactory.CreateAsync( defaultConfig, - defaultOptions, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); diff --git a/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs b/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs index 238c7747..1720bc69 100644 --- a/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs +++ b/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs @@ -1,5 +1,4 @@ -using ModelContextProtocol.Client; -using ModelContextProtocol.Protocol.Transport; +using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Test.Utils; using ModelContextProtocol.Tests.Utils; using ModelContextProtocol.TestSseServer; @@ -32,11 +31,6 @@ public SseServerIntegrationTestFixture() _serverTask = Program.MainAsync([port.ToString()], new XunitLoggerProvider(_delegatingTestOutputHelper), _stopCts.Token); } - public static McpClientOptions CreateDefaultClientOptions() => new() - { - ClientInfo = new() { Name = "IntegrationTestClient", Version = "1.0.0" }, - }; - public void Initialize(ITestOutputHelper output) { _delegatingTestOutputHelper.CurrentTestOutputHelper = output; diff --git a/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs b/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs index a71766e2..eaead4e0 100644 --- a/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs @@ -27,7 +27,7 @@ private Task GetClientAsync(McpClientOptions? options = null) { return McpClientFactory.CreateAsync( _fixture.DefaultConfig, - options ?? SseServerIntegrationTestFixture.CreateDefaultClientOptions(), + options, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); } @@ -220,8 +220,8 @@ public async Task Sampling_Sse_TestServer() // Set up the sampling handler int samplingHandlerCalls = 0; #pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously - var options = SseServerIntegrationTestFixture.CreateDefaultClientOptions(); - options.Capabilities ??= new(); + McpClientOptions options = new(); + options.Capabilities = new(); options.Capabilities.Sampling ??= new(); options.Capabilities.Sampling.SamplingHandler = async (_, _, _) => { diff --git a/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs index 33793555..7e79356e 100644 --- a/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs @@ -1,7 +1,6 @@ using Microsoft.Extensions.Options; using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; -using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server; using ModelContextProtocol.Tests.Utils; using ModelContextProtocol.Utils.Json; @@ -21,11 +20,6 @@ public StdioServerTransportTests(ITestOutputHelper testOutputHelper) { _serverOptions = new McpServerOptions { - ServerInfo = new Implementation - { - Name = "Test Server", - Version = "1.0" - }, ProtocolVersion = "2.0", InitializationTimeout = TimeSpan.FromSeconds(10), ServerInstructions = "Test Instructions" @@ -49,8 +43,6 @@ public void Constructor_Throws_For_Null_Options() Assert.Throws("serverOptions", () => new StdioServerTransport((IOptions)null!)); Assert.Throws("serverOptions", () => new StdioServerTransport((McpServerOptions)null!)); - Assert.Throws("serverOptions.ServerInfo", () => new StdioServerTransport(new McpServerOptions() { ServerInfo = null! })); - Assert.Throws("serverOptions.ServerInfo.Name", () => new StdioServerTransport(new McpServerOptions() { ServerInfo = new() { Name = null!, Version = "" } })); } [Fact]