diff --git a/samples/QuickstartWeatherServer/Tools/WeatherTools.cs b/samples/QuickstartWeatherServer/Tools/WeatherTools.cs index 8463e350..9f04ff25 100644 --- a/samples/QuickstartWeatherServer/Tools/WeatherTools.cs +++ b/samples/QuickstartWeatherServer/Tools/WeatherTools.cs @@ -1,7 +1,6 @@ using ModelContextProtocol; using ModelContextProtocol.Server; using System.ComponentModel; -using System.Net.Http.Json; using System.Text.Json; namespace QuickstartWeatherServer.Tools; diff --git a/src/ModelContextProtocol/Client/McpClient.cs b/src/ModelContextProtocol/Client/McpClient.cs index a38edf6d..64a89680 100644 --- a/src/ModelContextProtocol/Client/McpClient.cs +++ b/src/ModelContextProtocol/Client/McpClient.cs @@ -12,6 +12,12 @@ namespace ModelContextProtocol.Client; /// internal sealed class McpClient : McpEndpoint, IMcpClient { + private static Implementation DefaultImplementation { get; } = new() + { + Name = DefaultAssemblyName.Name ?? nameof(McpClient), + Version = DefaultAssemblyName.Version?.ToString() ?? "1.0.0", + }; + private readonly IClientTransport _clientTransport; private readonly McpClientOptions _options; @@ -29,43 +35,53 @@ internal sealed class McpClient : McpEndpoint, IMcpClient /// Options for the client, defining protocol version and capabilities. /// The server configuration. /// The logger factory. - public McpClient(IClientTransport clientTransport, McpClientOptions options, McpServerConfig serverConfig, ILoggerFactory? loggerFactory) + public McpClient(IClientTransport clientTransport, McpClientOptions? options, McpServerConfig serverConfig, ILoggerFactory? loggerFactory) : base(loggerFactory) { + options ??= new(); + _clientTransport = clientTransport; _options = options; EndpointName = $"Client ({serverConfig.Id}: {serverConfig.Name})"; - if (options.Capabilities?.Sampling is { } samplingCapability) + if (options.Capabilities is { } capabilities) { - if (samplingCapability.SamplingHandler is not { } samplingHandler) + if (capabilities.NotificationHandlers is { } notificationHandlers) { - throw new InvalidOperationException($"Sampling capability was set but it did not provide a handler."); + NotificationHandlers.AddRange(notificationHandlers); } - SetRequestHandler( - RequestMethods.SamplingCreateMessage, - (request, cancellationToken) => samplingHandler( - request, - request?.Meta?.ProgressToken is { } token ? new TokenProgress(this, token) : NullProgress.Instance, - cancellationToken), - McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams, - McpJsonUtilities.JsonContext.Default.CreateMessageResult); - } - - if (options.Capabilities?.Roots is { } rootsCapability) - { - if (rootsCapability.RootsHandler is not { } rootsHandler) + if (capabilities.Sampling is { } samplingCapability) { - throw new InvalidOperationException($"Roots capability was set but it did not provide a handler."); + if (samplingCapability.SamplingHandler is not { } samplingHandler) + { + throw new InvalidOperationException($"Sampling capability was set but it did not provide a handler."); + } + + RequestHandlers.Set( + RequestMethods.SamplingCreateMessage, + (request, cancellationToken) => samplingHandler( + request, + request?.Meta?.ProgressToken is { } token ? new TokenProgress(this, token) : NullProgress.Instance, + cancellationToken), + McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams, + McpJsonUtilities.JsonContext.Default.CreateMessageResult); } - SetRequestHandler( - RequestMethods.RootsList, - rootsHandler, - McpJsonUtilities.JsonContext.Default.ListRootsRequestParams, - McpJsonUtilities.JsonContext.Default.ListRootsResult); + if (capabilities.Roots is { } rootsCapability) + { + if (rootsCapability.RootsHandler is not { } rootsHandler) + { + throw new InvalidOperationException($"Roots capability was set but it did not provide a handler."); + } + + RequestHandlers.Set( + RequestMethods.RootsList, + rootsHandler, + McpJsonUtilities.JsonContext.Default.ListRootsRequestParams, + McpJsonUtilities.JsonContext.Default.ListRootsResult); + } } } @@ -96,20 +112,20 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default) using var initializationCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); initializationCts.CancelAfter(_options.InitializationTimeout); - try - { - // Send initialize request - var initializeResponse = await this.SendRequestAsync( - RequestMethods.Initialize, - new InitializeRequestParams - { - ProtocolVersion = _options.ProtocolVersion, - Capabilities = _options.Capabilities ?? new ClientCapabilities(), - ClientInfo = _options.ClientInfo - }, - McpJsonUtilities.JsonContext.Default.InitializeRequestParams, - McpJsonUtilities.JsonContext.Default.InitializeResult, - cancellationToken: initializationCts.Token).ConfigureAwait(false); + try + { + // Send initialize request + var initializeResponse = await this.SendRequestAsync( + RequestMethods.Initialize, + new InitializeRequestParams + { + ProtocolVersion = _options.ProtocolVersion, + Capabilities = _options.Capabilities ?? new ClientCapabilities(), + ClientInfo = _options.ClientInfo ?? DefaultImplementation, + }, + McpJsonUtilities.JsonContext.Default.InitializeRequestParams, + McpJsonUtilities.JsonContext.Default.InitializeResult, + cancellationToken: initializationCts.Token).ConfigureAwait(false); // Store server information _logger.ServerCapabilitiesReceived(EndpointName, diff --git a/src/ModelContextProtocol/Client/McpClientFactory.cs b/src/ModelContextProtocol/Client/McpClientFactory.cs index c1bb3770..751df190 100644 --- a/src/ModelContextProtocol/Client/McpClientFactory.cs +++ b/src/ModelContextProtocol/Client/McpClientFactory.cs @@ -5,30 +5,12 @@ using ModelContextProtocol.Utils; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; -using System.Reflection; namespace ModelContextProtocol.Client; /// Provides factory methods for creating MCP clients. public static class McpClientFactory { - /// Default client options to use when none are supplied. - private static readonly McpClientOptions s_defaultClientOptions = CreateDefaultClientOptions(); - - /// Creates default client options to use when no options are supplied. - private static McpClientOptions CreateDefaultClientOptions() - { - var asmName = (Assembly.GetEntryAssembly() ?? Assembly.GetExecutingAssembly()).GetName(); - return new() - { - ClientInfo = new() - { - Name = asmName.Name ?? "McpClient", - Version = asmName.Version?.ToString() ?? "1.0.0", - }, - }; - } - /// Creates an , connecting it to the specified server. /// Configuration for the target server to which the client should connect. /// @@ -52,7 +34,6 @@ public static async Task CreateAsync( { Throw.IfNull(serverConfig); - clientOptions ??= s_defaultClientOptions; createTransportFunc ??= CreateTransport; string endpointName = $"Client ({serverConfig.Id}: {serverConfig.Name})"; diff --git a/src/ModelContextProtocol/Client/McpClientOptions.cs b/src/ModelContextProtocol/Client/McpClientOptions.cs index de61f18f..5fe81ede 100644 --- a/src/ModelContextProtocol/Client/McpClientOptions.cs +++ b/src/ModelContextProtocol/Client/McpClientOptions.cs @@ -12,7 +12,7 @@ public class McpClientOptions /// /// Information about this client implementation. /// - public required Implementation ClientInfo { get; set; } + public Implementation? ClientInfo { get; set; } /// /// Client capabilities to advertise to the server. diff --git a/src/ModelContextProtocol/Client/McpClientTool.cs b/src/ModelContextProtocol/Client/McpClientTool.cs index 58cb02d5..9dfee5a0 100644 --- a/src/ModelContextProtocol/Client/McpClientTool.cs +++ b/src/ModelContextProtocol/Client/McpClientTool.cs @@ -1,6 +1,5 @@ using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Utils.Json; -using ModelContextProtocol.Utils; using Microsoft.Extensions.AI; using System.Text.Json; diff --git a/src/ModelContextProtocol/Configuration/McpServerOptionsSetup.cs b/src/ModelContextProtocol/Configuration/McpServerOptionsSetup.cs index 2ef0300e..c778fe81 100644 --- a/src/ModelContextProtocol/Configuration/McpServerOptionsSetup.cs +++ b/src/ModelContextProtocol/Configuration/McpServerOptionsSetup.cs @@ -1,5 +1,4 @@ -using System.Reflection; -using ModelContextProtocol.Server; +using ModelContextProtocol.Server; using Microsoft.Extensions.Options; using ModelContextProtocol.Utils; @@ -25,18 +24,6 @@ public void Configure(McpServerOptions options) { Throw.IfNull(options); - // Configure the option's server information based on the current process, - // if it otherwise lacks server information. - if (options.ServerInfo is not { } serverInfo) - { - var assemblyName = Assembly.GetEntryAssembly()?.GetName(); - options.ServerInfo = new() - { - Name = assemblyName?.Name ?? "McpServer", - Version = assemblyName?.Version?.ToString() ?? "1.0.0", - }; - } - // Collect all of the provided tools into a tools collection. If the options already has // a collection, add to it, otherwise create a new one. We want to maintain the identity // of an existing collection in case someone has provided their own derived type, wants diff --git a/src/ModelContextProtocol/IMcpEndpoint.cs b/src/ModelContextProtocol/IMcpEndpoint.cs index a053ad64..95d6dbc5 100644 --- a/src/ModelContextProtocol/IMcpEndpoint.cs +++ b/src/ModelContextProtocol/IMcpEndpoint.cs @@ -15,20 +15,4 @@ public interface IMcpEndpoint : IAsyncDisposable /// The message. /// The to monitor for cancellation requests. The default is . Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default); - - /// - /// Adds a handler for server notifications of a specific method. - /// - /// The notification method to handle. - /// The async handler function to process notifications. - /// - /// - /// Each method may have multiple handlers. Adding a handler for a method that already has one - /// will not replace the existing handler. - /// - /// - /// provides constants for common notification methods. - /// - /// - void AddNotificationHandler(string method, Func handler); } diff --git a/src/ModelContextProtocol/Protocol/Transport/McpTransportException.cs b/src/ModelContextProtocol/Protocol/Transport/McpTransportException.cs index bc49376d..d84de632 100644 --- a/src/ModelContextProtocol/Protocol/Transport/McpTransportException.cs +++ b/src/ModelContextProtocol/Protocol/Transport/McpTransportException.cs @@ -30,7 +30,7 @@ public McpTransportException(string message) /// /// The message that describes the error. /// The exception that is the cause of the current exception. - public McpTransportException(string message, Exception innerException) + public McpTransportException(string message, Exception? innerException) : base(message, innerException) { } diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs index af304c89..b86c886b 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs @@ -21,10 +21,22 @@ public StdioClientSessionTransport(StdioClientTransportOptions options, Process /// public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) { - if (_process.HasExited) + Exception? processException = null; + bool hasExited = false; + try + { + hasExited = _process.HasExited; + } + catch (Exception e) + { + processException = e; + hasExited = true; + } + + if (hasExited) { Logger.TransportNotConnected(EndpointName); - throw new McpTransportException("Transport is not connected"); + throw new McpTransportException("Transport is not connected", processException); } await base.SendMessageAsync(message, cancellationToken).ConfigureAwait(false); @@ -33,7 +45,7 @@ public override async Task SendMessageAsync(IJsonRpcMessage message, Cancellatio /// protected override ValueTask CleanupAsync(CancellationToken cancellationToken) { - StdioClientTransport.DisposeProcess(_process, processStarted: true, Logger, _options.ShutdownTimeout, EndpointName); + StdioClientTransport.DisposeProcess(_process, processRunning: true, Logger, _options.ShutdownTimeout, EndpointName); return base.CleanupAsync(cancellationToken); } diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs index 774677f1..11a0c137 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs @@ -129,13 +129,25 @@ public async Task ConnectAsync(CancellationToken cancellationToken = } internal static void DisposeProcess( - Process? process, bool processStarted, ILogger logger, TimeSpan shutdownTimeout, string endpointName) + Process? process, bool processRunning, ILogger logger, TimeSpan shutdownTimeout, string endpointName) { if (process is not null) { + if (processRunning) + { + try + { + processRunning = !process.HasExited; + } + catch + { + processRunning = false; + } + } + try { - if (processStarted && !process.HasExited) + if (processRunning) { // Wait for the process to exit. // Kill the while process tree because the process may spawn child processes diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioServerTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StdioServerTransport.cs index 58077dbb..b3ae2e41 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StdioServerTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StdioServerTransport.cs @@ -70,9 +70,7 @@ public StdioServerTransport(string serverName, ILoggerFactory? loggerFactory = n private static string GetServerName(McpServerOptions serverOptions) { Throw.IfNull(serverOptions); - Throw.IfNull(serverOptions.ServerInfo); - Throw.IfNull(serverOptions.ServerInfo.Name); - return serverOptions.ServerInfo.Name; + return serverOptions.ServerInfo?.Name ?? McpServer.DefaultImplementation.Name; } } diff --git a/src/ModelContextProtocol/Protocol/Types/Capabilities.cs b/src/ModelContextProtocol/Protocol/Types/Capabilities.cs index c0cf4197..a29a530a 100644 --- a/src/ModelContextProtocol/Protocol/Types/Capabilities.cs +++ b/src/ModelContextProtocol/Protocol/Types/Capabilities.cs @@ -1,4 +1,5 @@ -using ModelContextProtocol.Server; +using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Server; using System.Text.Json.Serialization; namespace ModelContextProtocol.Protocol.Types; @@ -26,6 +27,14 @@ public class ClientCapabilities /// [JsonPropertyName("sampling")] public SamplingCapability? Sampling { get; set; } + + /// Gets or sets notification handlers to register with the client. + /// + /// When constructed, the client will enumerate these handlers, which may contain multiple handlers per key. + /// The client will not re-enumerate the sequence. + /// + [JsonIgnore] + public IEnumerable>>? NotificationHandlers { get; set; } } /// diff --git a/src/ModelContextProtocol/Protocol/Types/ListRootsRequestParams.cs b/src/ModelContextProtocol/Protocol/Types/ListRootsRequestParams.cs index a5eec7a1..273251e6 100644 --- a/src/ModelContextProtocol/Protocol/Types/ListRootsRequestParams.cs +++ b/src/ModelContextProtocol/Protocol/Types/ListRootsRequestParams.cs @@ -1,6 +1,4 @@ -using ModelContextProtocol.Protocol.Messages; - -namespace ModelContextProtocol.Protocol.Types; +namespace ModelContextProtocol.Protocol.Types; /// /// A request from the server to get a list of root URIs from the client. diff --git a/src/ModelContextProtocol/Protocol/Types/ServerCapabilities.cs b/src/ModelContextProtocol/Protocol/Types/ServerCapabilities.cs index 296e4e62..9f8e3ac8 100644 --- a/src/ModelContextProtocol/Protocol/Types/ServerCapabilities.cs +++ b/src/ModelContextProtocol/Protocol/Types/ServerCapabilities.cs @@ -1,4 +1,5 @@ -using System.Text.Json.Serialization; +using ModelContextProtocol.Protocol.Messages; +using System.Text.Json.Serialization; namespace ModelContextProtocol.Protocol.Types; @@ -37,4 +38,12 @@ public class ServerCapabilities /// [JsonPropertyName("tools")] public ToolsCapability? Tools { get; set; } + + /// Gets or sets notification handlers to register with the server. + /// + /// When constructed, the server will enumerate these handlers, which may contain multiple handlers per key. + /// The server will not re-enumerate the sequence. + /// + [JsonIgnore] + public IEnumerable>>? NotificationHandlers { get; set; } } diff --git a/src/ModelContextProtocol/Server/McpServer.cs b/src/ModelContextProtocol/Server/McpServer.cs index 764736e2..9b6898d8 100644 --- a/src/ModelContextProtocol/Server/McpServer.cs +++ b/src/ModelContextProtocol/Server/McpServer.cs @@ -11,6 +11,12 @@ namespace ModelContextProtocol.Server; /// internal sealed class McpServer : McpEndpoint, IMcpServer { + internal static Implementation DefaultImplementation { get; } = new() + { + Name = DefaultAssemblyName.Name ?? nameof(McpServer), + Version = DefaultAssemblyName.Version?.ToString() ?? "1.0.0", + }; + private readonly EventHandler? _toolsChangedDelegate; private readonly EventHandler? _promptsChangedDelegate; @@ -32,9 +38,11 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? Throw.IfNull(transport); Throw.IfNull(options); + options ??= new(); + ServerOptions = options; Services = serviceProvider; - _endpointName = $"Server ({options.ServerInfo.Name} {options.ServerInfo.Version})"; + _endpointName = $"Server ({options.ServerInfo?.Name ?? DefaultImplementation.Name} {options.ServerInfo?.Version ?? DefaultImplementation.Version})"; _toolsChangedDelegate = delegate { @@ -51,7 +59,7 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? }); }; - AddNotificationHandler(NotificationMethods.InitializedNotification, _ => + NotificationHandlers.Add(NotificationMethods.InitializedNotification, _ => { if (ServerOptions.Capabilities?.Tools?.ToolCollection is { } tools) { @@ -66,6 +74,11 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? return Task.CompletedTask; }); + if (options.Capabilities?.NotificationHandlers is { } notificationHandlers) + { + NotificationHandlers.AddRange(notificationHandlers); + } + SetToolsHandler(options); SetInitializeHandler(options); SetCompletionHandler(options); @@ -131,7 +144,7 @@ public override async ValueTask DisposeUnsynchronizedAsync() private void SetPingHandler() { - SetRequestHandler(RequestMethods.Ping, + RequestHandlers.Set(RequestMethods.Ping, (request, _) => Task.FromResult(new PingResult()), McpJsonUtilities.JsonContext.Default.JsonNode, McpJsonUtilities.JsonContext.Default.PingResult); @@ -139,7 +152,7 @@ private void SetPingHandler() private void SetInitializeHandler(McpServerOptions options) { - SetRequestHandler(RequestMethods.Initialize, + RequestHandlers.Set(RequestMethods.Initialize, (request, _) => { ClientCapabilities = request?.Capabilities ?? new(); @@ -153,7 +166,7 @@ private void SetInitializeHandler(McpServerOptions options) { ProtocolVersion = options.ProtocolVersion, Instructions = options.ServerInstructions, - ServerInfo = options.ServerInfo, + ServerInfo = options.ServerInfo ?? DefaultImplementation, Capabilities = ServerCapabilities ?? new(), }); }, @@ -164,7 +177,7 @@ private void SetInitializeHandler(McpServerOptions options) private void SetCompletionHandler(McpServerOptions options) { // This capability is not optional, so return an empty result if there is no handler. - SetRequestHandler(RequestMethods.CompletionComplete, + RequestHandlers.Set(RequestMethods.CompletionComplete, options.GetCompletionHandler is { } handler ? (request, ct) => handler(new(this, request), ct) : (request, ct) => Task.FromResult(new CompleteResult() { Completion = new() { Values = [], Total = 0, HasMore = false } }), @@ -190,20 +203,20 @@ private void SetResourcesHandler(McpServerOptions options) listResourcesHandler ??= (static (_, _) => Task.FromResult(new ListResourcesResult())); - SetRequestHandler( + RequestHandlers.Set( RequestMethods.ResourcesList, (request, ct) => listResourcesHandler(new(this, request), ct), McpJsonUtilities.JsonContext.Default.ListResourcesRequestParams, McpJsonUtilities.JsonContext.Default.ListResourcesResult); - SetRequestHandler( + RequestHandlers.Set( RequestMethods.ResourcesRead, (request, ct) => readResourceHandler(new(this, request), ct), McpJsonUtilities.JsonContext.Default.ReadResourceRequestParams, McpJsonUtilities.JsonContext.Default.ReadResourceResult); listResourceTemplatesHandler ??= (static (_, _) => Task.FromResult(new ListResourceTemplatesResult())); - SetRequestHandler( + RequestHandlers.Set( RequestMethods.ResourcesTemplatesList, (request, ct) => listResourceTemplatesHandler(new(this, request), ct), McpJsonUtilities.JsonContext.Default.ListResourceTemplatesRequestParams, @@ -221,13 +234,13 @@ private void SetResourcesHandler(McpServerOptions options) throw new McpException("Resources capability was enabled with subscribe support, but SubscribeToResources and/or UnsubscribeFromResources handlers were not specified."); } - SetRequestHandler( + RequestHandlers.Set( RequestMethods.ResourcesSubscribe, (request, ct) => subscribeHandler(new(this, request), ct), McpJsonUtilities.JsonContext.Default.SubscribeRequestParams, McpJsonUtilities.JsonContext.Default.EmptyResult); - SetRequestHandler( + RequestHandlers.Set( RequestMethods.ResourcesUnsubscribe, (request, ct) => unsubscribeHandler(new(this, request), ct), McpJsonUtilities.JsonContext.Default.UnsubscribeRequestParams, @@ -314,13 +327,13 @@ await originalListPromptsHandler(request, cancellationToken).ConfigureAwait(fals } } - SetRequestHandler( + RequestHandlers.Set( RequestMethods.PromptsList, (request, ct) => listPromptsHandler(new(this, request), ct), McpJsonUtilities.JsonContext.Default.ListPromptsRequestParams, McpJsonUtilities.JsonContext.Default.ListPromptsResult); - SetRequestHandler( + RequestHandlers.Set( RequestMethods.PromptsGet, (request, ct) => getPromptHandler(new(this, request), ct), McpJsonUtilities.JsonContext.Default.GetPromptRequestParams, @@ -407,13 +420,13 @@ await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false) } } - SetRequestHandler( + RequestHandlers.Set( RequestMethods.ToolsList, (request, ct) => listToolsHandler(new(this, request), ct), McpJsonUtilities.JsonContext.Default.ListToolsRequestParams, McpJsonUtilities.JsonContext.Default.ListToolsResult); - SetRequestHandler( + RequestHandlers.Set( RequestMethods.ToolsCall, (request, ct) => callToolHandler(new(this, request), ct), McpJsonUtilities.JsonContext.Default.CallToolRequestParams, @@ -432,7 +445,7 @@ private void SetSetLoggingLevelHandler(McpServerOptions options) throw new McpException("Logging capability was enabled, but SetLoggingLevelHandler was not specified."); } - SetRequestHandler( + RequestHandlers.Set( RequestMethods.LoggingSetLevel, (request, ct) => setLoggingLevelHandler(new(this, request), ct), McpJsonUtilities.JsonContext.Default.SetLevelRequestParams, diff --git a/src/ModelContextProtocol/Server/McpServerExtensions.cs b/src/ModelContextProtocol/Server/McpServerExtensions.cs index b59992c9..06ec596c 100644 --- a/src/ModelContextProtocol/Server/McpServerExtensions.cs +++ b/src/ModelContextProtocol/Server/McpServerExtensions.cs @@ -1,5 +1,4 @@ using Microsoft.Extensions.AI; -using ModelContextProtocol.Client; using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Utils; diff --git a/src/ModelContextProtocol/Server/McpServerOptions.cs b/src/ModelContextProtocol/Server/McpServerOptions.cs index f17ba15d..06b0a260 100644 --- a/src/ModelContextProtocol/Server/McpServerOptions.cs +++ b/src/ModelContextProtocol/Server/McpServerOptions.cs @@ -14,7 +14,7 @@ public class McpServerOptions /// /// Information about this server implementation. /// - public required Implementation ServerInfo { get; set; } + public Implementation? ServerInfo { get; set; } /// /// Server capabilities to advertise to the server. diff --git a/src/ModelContextProtocol/Shared/McpEndpoint.cs b/src/ModelContextProtocol/Shared/McpEndpoint.cs index 8b50a805..c26af4b1 100644 --- a/src/ModelContextProtocol/Shared/McpEndpoint.cs +++ b/src/ModelContextProtocol/Shared/McpEndpoint.cs @@ -6,7 +6,7 @@ using ModelContextProtocol.Server; using ModelContextProtocol.Utils; using System.Diagnostics.CodeAnalysis; -using System.Text.Json.Serialization.Metadata; +using System.Reflection; namespace ModelContextProtocol.Shared; @@ -19,8 +19,8 @@ namespace ModelContextProtocol.Shared; /// internal abstract class McpEndpoint : IAsyncDisposable { - private readonly RequestHandlers _requestHandlers = []; - private readonly NotificationHandlers _notificationHandlers = []; + /// Cached naming information used for name/version when none is specified. + internal static AssemblyName DefaultAssemblyName { get; } = (Assembly.GetEntryAssembly() ?? Assembly.GetExecutingAssembly()).GetName(); private McpSession? _session; private CancellationTokenSource? _sessionCts; @@ -39,16 +39,9 @@ protected McpEndpoint(ILoggerFactory? loggerFactory = null) _logger = loggerFactory?.CreateLogger(GetType()) ?? NullLogger.Instance; } - protected void SetRequestHandler( - string method, - Func> handler, - JsonTypeInfo requestTypeInfo, - JsonTypeInfo responseTypeInfo) + protected RequestHandlers RequestHandlers { get; } = []; - => _requestHandlers.Set(method, handler, requestTypeInfo, responseTypeInfo); - - public void AddNotificationHandler(string method, Func handler) - => _notificationHandlers.Add(method, handler); + protected NotificationHandlers NotificationHandlers { get; } = []; public Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) => GetSessionOrThrow().SendRequestAsync(request, cancellationToken); @@ -70,7 +63,7 @@ public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancella protected void StartSession(ITransport sessionTransport) { _sessionCts = new CancellationTokenSource(); - _session = new McpSession(this is IMcpServer, sessionTransport, EndpointName, _requestHandlers, _notificationHandlers, _logger); + _session = new McpSession(this is IMcpServer, sessionTransport, EndpointName, RequestHandlers, NotificationHandlers, _logger); MessageProcessingTask = _session.ProcessMessagesAsync(_sessionCts.Token); } diff --git a/src/ModelContextProtocol/Shared/McpSession.cs b/src/ModelContextProtocol/Shared/McpSession.cs index dae92686..3d06f7fc 100644 --- a/src/ModelContextProtocol/Shared/McpSession.cs +++ b/src/ModelContextProtocol/Shared/McpSession.cs @@ -3,7 +3,6 @@ using ModelContextProtocol.Logging; using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; -using ModelContextProtocol.Server; using ModelContextProtocol.Utils; using ModelContextProtocol.Utils.Json; using System.Collections.Concurrent; diff --git a/src/ModelContextProtocol/Shared/NotificationHandlers.cs b/src/ModelContextProtocol/Shared/NotificationHandlers.cs index 1fdb0578..c3b5dfe7 100644 --- a/src/ModelContextProtocol/Shared/NotificationHandlers.cs +++ b/src/ModelContextProtocol/Shared/NotificationHandlers.cs @@ -1,16 +1,28 @@ using ModelContextProtocol.Protocol.Messages; -using System.Collections.Concurrent; namespace ModelContextProtocol.Shared; -internal sealed class NotificationHandlers : ConcurrentDictionary>> +internal sealed class NotificationHandlers : Dictionary>> { + /// Adds a notification handler as part of configuring the endpoint. + /// This method is not thread-safe and should only be used serially as part of configuring the instance. public void Add(string method, Func handler) { - var handlers = GetOrAdd(method, _ => []); - lock (handlers) + if (!TryGetValue(method, out var handlers)) { - handlers.Add(handler); + this[method] = handlers = []; + } + + handlers.Add(handler); + } + + /// Adds notification handlers as part of configuring the endpoint. + /// This method is not thread-safe and should only be used serially as part of configuring the instance. + public void AddRange(IEnumerable>> handlers) + { + foreach (var handler in handlers) + { + Add(handler.Key, handler.Value); } } } diff --git a/tests/ModelContextProtocol.TestServer/Program.cs b/tests/ModelContextProtocol.TestServer/Program.cs index 830a042a..ba6cca21 100644 --- a/tests/ModelContextProtocol.TestServer/Program.cs +++ b/tests/ModelContextProtocol.TestServer/Program.cs @@ -38,7 +38,6 @@ private static async Task Main(string[] args) McpServerOptions options = new() { - ServerInfo = new Implementation() { Name = "TestServer", Version = "1.0.0" }, Capabilities = new ServerCapabilities() { Tools = ConfigureTools(), diff --git a/tests/ModelContextProtocol.TestSseServer/Program.cs b/tests/ModelContextProtocol.TestSseServer/Program.cs index e4bd996b..43ae3207 100644 --- a/tests/ModelContextProtocol.TestSseServer/Program.cs +++ b/tests/ModelContextProtocol.TestSseServer/Program.cs @@ -25,7 +25,6 @@ private static void ConfigureSerilog(ILoggingBuilder loggingBuilder) private static void ConfigureOptions(McpServerOptions options) { - options.ServerInfo = new Implementation() { Name = "TestServer", Version = "1.0.0" }; options.Capabilities = new ServerCapabilities() { Tools = new(), diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs index ae58023f..aa2f773e 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs @@ -11,29 +11,24 @@ namespace ModelContextProtocol.Tests.Client; public class McpClientFactoryTests { - private readonly McpClientOptions _defaultOptions = new() - { - ClientInfo = new() { Name = "TestClient", Version = "1.0.0" } - }; - [Fact] public async Task CreateAsync_WithInvalidArgs_Throws() { - await Assert.ThrowsAsync("serverConfig", () => McpClientFactory.CreateAsync((McpServerConfig)null!, _defaultOptions, cancellationToken: TestContext.Current.CancellationToken)); + await Assert.ThrowsAsync("serverConfig", () => McpClientFactory.CreateAsync((McpServerConfig)null!, cancellationToken: TestContext.Current.CancellationToken)); await Assert.ThrowsAsync("serverConfig", () => McpClientFactory.CreateAsync(new McpServerConfig() { Name = "name", Id = "id", TransportType = "somethingunsupported", - }, _defaultOptions, cancellationToken: TestContext.Current.CancellationToken)); + }, cancellationToken: TestContext.Current.CancellationToken)); await Assert.ThrowsAsync(() => McpClientFactory.CreateAsync(new McpServerConfig() { Name = "name", Id = "id", TransportType = TransportTypes.StdIo, - }, _defaultOptions, (_, __) => null!, cancellationToken: TestContext.Current.CancellationToken)); + }, createTransportFunc: (_, __) => null!, cancellationToken: TestContext.Current.CancellationToken)); } [Fact] @@ -78,8 +73,7 @@ public async Task CreateAsync_WithValidStdioConfig_CreatesNewClient() // Act await using var client = await McpClientFactory.CreateAsync( serverConfig, - _defaultOptions, - (_, __) => new NopTransport(), + createTransportFunc: (_, __) => new NopTransport(), cancellationToken: TestContext.Current.CancellationToken); // Assert @@ -102,8 +96,7 @@ public async Task CreateAsync_WithNoTransportOptions_CreatesNewClient() // Act await using var client = await McpClientFactory.CreateAsync( serverConfig, - _defaultOptions, - (_, __) => new NopTransport(), + createTransportFunc: (_, __) => new NopTransport(), cancellationToken: TestContext.Current.CancellationToken); // Assert @@ -126,8 +119,7 @@ public async Task CreateAsync_WithValidSseConfig_CreatesNewClient() // Act await using var client = await McpClientFactory.CreateAsync( serverConfig, - _defaultOptions, - (_, __) => new NopTransport(), + createTransportFunc: (_, __) => new NopTransport(), cancellationToken: TestContext.Current.CancellationToken); // Assert @@ -157,8 +149,7 @@ public async Task CreateAsync_WithSse_CreatesCorrectTransportOptions() // Act await using var client = await McpClientFactory.CreateAsync( serverConfig, - _defaultOptions, - (_, __) => new NopTransport(), + createTransportFunc: (_, __) => new NopTransport(), cancellationToken: TestContext.Current.CancellationToken); // Assert @@ -186,7 +177,7 @@ public async Task McpFactory_WithInvalidTransportOptions_ThrowsFormatException(s }; // act & assert - await Assert.ThrowsAsync(() => McpClientFactory.CreateAsync(config, _defaultOptions, cancellationToken: TestContext.Current.CancellationToken)); + await Assert.ThrowsAsync(() => McpClientFactory.CreateAsync(config, cancellationToken: TestContext.Current.CancellationToken)); } [Theory] @@ -205,11 +196,6 @@ public async Task CreateAsync_WithCapabilitiesOptions(Type transportType) var clientOptions = new McpClientOptions { - ClientInfo = new Implementation - { - Name = "TestClient", - Version = "1.0.0.0" - }, Capabilities = new ClientCapabilities { Sampling = new SamplingCapability diff --git a/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs b/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs index a77ae6b8..400caade 100644 --- a/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs +++ b/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs @@ -8,7 +8,6 @@ public class ClientIntegrationTestFixture { private ILoggerFactory? _loggerFactory; - public McpClientOptions DefaultOptions { get; } public McpServerConfig EverythingServerConfig { get; } public McpServerConfig TestServerConfig { get; } @@ -16,11 +15,6 @@ public class ClientIntegrationTestFixture public ClientIntegrationTestFixture() { - DefaultOptions = new() - { - ClientInfo = new() { Name = "IntegrationTestClient", Version = "1.0.0" }, - }; - EverythingServerConfig = new() { Id = "everything", @@ -63,5 +57,5 @@ public Task CreateClientAsync(string clientId, McpClientOptions? cli "everything" => EverythingServerConfig, "test_server" => TestServerConfig, _ => throw new ArgumentException($"Unknown client ID: {clientId}") - }, clientOptions ?? DefaultOptions, loggerFactory: _loggerFactory); + }, clientOptions, loggerFactory: _loggerFactory); } \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs index 705a0779..6f0c80f1 100644 --- a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs @@ -5,10 +5,7 @@ using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Tests.Utils; using OpenAI; -using System.Text.Encodings.Web; using System.Text.Json; -using System.Text.Json.Serialization; -using System.Text.Json.Serialization.Metadata; namespace ModelContextProtocol.Tests; @@ -256,13 +253,22 @@ public async Task SubscribeResource_Stdio() // act TaskCompletionSource tcs = new(); - await using var client = await _fixture.CreateClientAsync(clientId); - client.AddNotificationHandler(NotificationMethods.ResourceUpdatedNotification, (notification) => + await using var client = await _fixture.CreateClientAsync(clientId, new() { - var notificationParams = JsonSerializer.Deserialize(notification.Params); - tcs.TrySetResult(true); - return Task.CompletedTask; + Capabilities = new() + { + NotificationHandlers = + [ + new(NotificationMethods.ResourceUpdatedNotification, notification => + { + var notificationParams = JsonSerializer.Deserialize(notification.Params); + tcs.TrySetResult(true); + return Task.CompletedTask; + }) + ] + } }); + await client.SubscribeToResourceAsync("test://static/resource/1", TestContext.Current.CancellationToken); await tcs.Task; @@ -277,12 +283,20 @@ public async Task UnsubscribeResource_Stdio() // act TaskCompletionSource receivedNotification = new(); - await using var client = await _fixture.CreateClientAsync(clientId); - client.AddNotificationHandler(NotificationMethods.ResourceUpdatedNotification, (notification) => + await using var client = await _fixture.CreateClientAsync(clientId, new() { - var notificationParams = JsonSerializer.Deserialize(notification.Params); - receivedNotification.TrySetResult(true); - return Task.CompletedTask; + Capabilities = new() + { + NotificationHandlers = + [ + new(NotificationMethods.ResourceUpdatedNotification, (notification) => + { + var notificationParams = JsonSerializer.Deserialize(notification.Params); + receivedNotification.TrySetResult(true); + return Task.CompletedTask; + }) + ] + } }); await client.SubscribeToResourceAsync("test://static/resource/1", TestContext.Current.CancellationToken); @@ -350,7 +364,6 @@ public async Task Sampling_Stdio(string clientId) int samplingHandlerCalls = 0; await using var client = await _fixture.CreateClientAsync(clientId, new() { - ClientInfo = new() { Name = "Sampling_Stdio", Version = "1.0.0" }, Capabilities = new() { Sampling = new() @@ -483,7 +496,6 @@ public async Task ListToolsAsync_UsingEverythingServer_ToolsAreProperlyCalled() // Get the MCP client and tools from it. await using var client = await McpClientFactory.CreateAsync( _fixture.EverythingServerConfig, - _fixture.DefaultOptions, cancellationToken: TestContext.Current.CancellationToken); var mappedTools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); @@ -516,7 +528,6 @@ public async Task SamplingViaChatClient_RequestResponseProperlyPropagated() .CreateSamplingHandler(); await using var client = await McpClientFactory.CreateAsync(_fixture.EverythingServerConfig, new() { - ClientInfo = new() { Name = nameof(SamplingViaChatClient_RequestResponseProperlyPropagated), Version = "1.0.0" }, Capabilities = new() { Sampling = new() @@ -543,15 +554,23 @@ public async Task SamplingViaChatClient_RequestResponseProperlyPropagated() public async Task SetLoggingLevel_ReceivesLoggingMessages(string clientId) { TaskCompletionSource receivedNotification = new(); - await using var client = await _fixture.CreateClientAsync(clientId); - client.AddNotificationHandler(NotificationMethods.LoggingMessageNotification, (notification) => + await using var client = await _fixture.CreateClientAsync(clientId, new() { - var loggingMessageNotificationParameters = JsonSerializer.Deserialize(notification.Params); - if (loggingMessageNotificationParameters is not null) + Capabilities = new() { - receivedNotification.TrySetResult(true); + NotificationHandlers = + [ + new(NotificationMethods.LoggingMessageNotification, (notification) => + { + var loggingMessageNotificationParameters = JsonSerializer.Deserialize(notification.Params); + if (loggingMessageNotificationParameters is not null) + { + receivedNotification.TrySetResult(true); + } + return Task.CompletedTask; + }) + ] } - return Task.CompletedTask; }); // act diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs index bbbbb05d..ed1be347 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs @@ -117,7 +117,7 @@ public async ValueTask DisposeAsync() Dispose(); } - private async Task CreateMcpClientForServer() + private async Task CreateMcpClientForServer(McpClientOptions? options = null) { return await McpClientFactory.CreateAsync( new McpServerConfig() @@ -126,6 +126,7 @@ private async Task CreateMcpClientForServer() Name = "TestServer", TransportType = "ignored", }, + options, createTransportFunc: (_, _) => new StreamClientTransport( serverInput: _clientToServerPipe.Writer.AsStream(), serverOutput: _serverToClientPipe.Reader.AsStream(), @@ -175,18 +176,23 @@ public async Task Can_List_And_Call_Registered_Prompts() [Fact] public async Task Can_Be_Notified_Of_Prompt_Changes() { - IMcpClient client = await CreateMcpClientForServer(); - - var prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken); - Assert.Equal(6, prompts.Count); - Channel listChanged = Channel.CreateUnbounded(); - client.AddNotificationHandler("notifications/prompts/list_changed", notification => + + IMcpClient client = await CreateMcpClientForServer(new() { - listChanged.Writer.TryWrite(notification); - return Task.CompletedTask; + Capabilities = new() + { + NotificationHandlers = [new("notifications/prompts/list_changed", notification => + { + listChanged.Writer.TryWrite(notification); + return Task.CompletedTask; + })], + }, }); + var prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken); + Assert.Equal(6, prompts.Count); + var notificationRead = listChanged.Reader.ReadAsync(TestContext.Current.CancellationToken); Assert.False(notificationRead.IsCompleted); diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index 79ae117f..827c7056 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -141,15 +141,16 @@ public async ValueTask DisposeAsync() Dispose(); } - private async Task CreateMcpClientForServer() + private async Task CreateMcpClientForServer(McpClientOptions? options = null) { return await McpClientFactory.CreateAsync( - new McpServerConfig() - { - Id = "TestServer", - Name = "TestServer", - TransportType = "ignored", - }, + new McpServerConfig() + { + Id = "TestServer", + Name = "TestServer", + TransportType = "ignored", + }, + options, createTransportFunc: (_, _) => new StreamClientTransport( serverInput: _clientToServerPipe.Writer.AsStream(), _serverToClientPipe.Reader.AsStream(), @@ -242,18 +243,23 @@ public async Task Can_Create_Multiple_Servers_From_Options_And_List_Registered_T [Fact] public async Task Can_Be_Notified_Of_Tool_Changes() { - IMcpClient client = await CreateMcpClientForServer(); - - var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); - Assert.Equal(16, tools.Count); - Channel listChanged = Channel.CreateUnbounded(); - client.AddNotificationHandler(NotificationMethods.ToolListChangedNotification, notification => + + IMcpClient client = await CreateMcpClientForServer(new() { - listChanged.Writer.TryWrite(notification); - return Task.CompletedTask; + Capabilities = new() + { + NotificationHandlers = [new(NotificationMethods.ToolListChangedNotification, notification => + { + listChanged.Writer.TryWrite(notification); + return Task.CompletedTask; + })], + }, }); + var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + Assert.Equal(16, tools.Count); + var notificationRead = listChanged.Reader.ReadAsync(TestContext.Current.CancellationToken); Assert.False(notificationRead.IsCompleted); @@ -622,12 +628,17 @@ public async Task HandlesIProgressParameter() { ConcurrentQueue notifications = new(); - IMcpClient client = await CreateMcpClientForServer(); - client.AddNotificationHandler(NotificationMethods.ProgressNotification, notification => + IMcpClient client = await CreateMcpClientForServer(new() { - ProgressNotification pn = JsonSerializer.Deserialize(notification.Params)!; - notifications.Enqueue(pn); - return Task.CompletedTask; + Capabilities = new() + { + NotificationHandlers = [new(NotificationMethods.ProgressNotification, notification => + { + ProgressNotification pn = JsonSerializer.Deserialize(notification.Params)!; + notifications.Enqueue(pn); + return Task.CompletedTask; + })], + }, }); var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); diff --git a/tests/ModelContextProtocol.Tests/DiagnosticTests.cs b/tests/ModelContextProtocol.Tests/DiagnosticTests.cs index 583ae274..fd93eff6 100644 --- a/tests/ModelContextProtocol.Tests/DiagnosticTests.cs +++ b/tests/ModelContextProtocol.Tests/DiagnosticTests.cs @@ -1,6 +1,5 @@ using ModelContextProtocol.Client; using ModelContextProtocol.Protocol.Transport; -using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server; using OpenTelemetry.Trace; using System.Diagnostics; @@ -49,7 +48,6 @@ private static async Task RunConnected(Func action await using (IMcpServer server = McpServerFactory.Create(serverTransport, new() { - ServerInfo = new Implementation { Name = "TestServer", Version = "1.0.0" }, Capabilities = new() { Tools = new() diff --git a/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj b/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj index c8dc1ca9..16e6a891 100644 --- a/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj +++ b/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj @@ -11,6 +11,13 @@ ModelContextProtocol.Tests + + + true + + runtime; build; native; contentfiles; analyzers; buildtransitive diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerFactoryTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerFactoryTests.cs index ae640ecc..034a30bd 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerFactoryTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerFactoryTests.cs @@ -1,5 +1,4 @@ -using ModelContextProtocol.Protocol.Types; -using ModelContextProtocol.Server; +using ModelContextProtocol.Server; using ModelContextProtocol.Tests.Utils; namespace ModelContextProtocol.Tests.Server; @@ -13,7 +12,6 @@ public McpServerFactoryTests(ITestOutputHelper testOutputHelper) { _options = new McpServerOptions { - ServerInfo = new Implementation { Name = "TestServer", Version = "1.0" }, ProtocolVersion = "1.0", InitializationTimeout = TimeSpan.FromSeconds(30) }; diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs index 619dcfde..fb462748 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs @@ -24,7 +24,6 @@ private static McpServerOptions CreateOptions(ServerCapabilities? capabilities = { return new McpServerOptions { - ServerInfo = new Implementation { Name = "TestServer", Version = "1.0" }, ProtocolVersion = "2024", InitializationTimeout = TimeSpan.FromSeconds(30), Capabilities = capabilities, @@ -189,8 +188,8 @@ await Can_Handle_Requests( { var result = JsonSerializer.Deserialize(response); Assert.NotNull(result); - Assert.Equal("TestServer", result.ServerInfo.Name); - Assert.Equal("1.0", result.ServerInfo.Version); + Assert.Equal("ModelContextProtocol.Tests", result.ServerInfo.Name); + Assert.Equal("1.0.0.0", result.ServerInfo.Version); Assert.Equal("2024", result.ProtocolVersion); }); } @@ -634,8 +633,6 @@ public Task SendRequestAsync(JsonRpcRequest request, Cancellati public Implementation? ClientInfo => throw new NotImplementedException(); public McpServerOptions ServerOptions => throw new NotImplementedException(); public IServiceProvider? Services => throw new NotImplementedException(); - public void AddNotificationHandler(string method, Func handler) => - throw new NotImplementedException(); public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) => throw new NotImplementedException(); public Task RunAsync(CancellationToken cancellationToken = default) => @@ -649,13 +646,16 @@ public async Task NotifyProgress_Should_Be_Handled() var options = CreateOptions(); var notificationReceived = new TaskCompletionSource(); + options.Capabilities = new() + { + NotificationHandlers = [new(NotificationMethods.ProgressNotification, notification => + { + notificationReceived.TrySetResult(notification); + return Task.CompletedTask; + })], + }; var server = McpServerFactory.Create(transport, options, LoggerFactory); - server.AddNotificationHandler(NotificationMethods.ProgressNotification, notification => - { - notificationReceived.SetResult(notification); - return Task.CompletedTask; - }); Task serverTask = server.RunAsync(TestContext.Current.CancellationToken); diff --git a/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs b/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs index 0622e656..91e99898 100644 --- a/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs @@ -29,11 +29,6 @@ public async Task ConnectAndReceiveMessage_InMemoryServer() await using InMemoryTestSseServer server = new(CreatePortNumber(), LoggerFactory.CreateLogger()); await server.StartAsync(); - var defaultOptions = new McpClientOptions - { - ClientInfo = new() { Name = "IntegrationTestClient", Version = "1.0.0" } - }; - var defaultConfig = new McpServerConfig { Id = "test_server", @@ -46,7 +41,6 @@ public async Task ConnectAndReceiveMessage_InMemoryServer() // Act await using var client = await McpClientFactory.CreateAsync( defaultConfig, - defaultOptions, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); @@ -120,11 +114,6 @@ public async Task Sampling_Sse_EverythingServer() int samplingHandlerCalls = 0; var defaultOptions = new McpClientOptions { - ClientInfo = new() - { - Name = "IntegrationTestClient", - Version = "1.0.0" - }, Capabilities = new() { Sampling = new() @@ -175,11 +164,6 @@ public async Task ConnectAndReceiveMessage_InMemoryServer_WithFullEndpointEventU server.UseFullUrlForEndpointEvent = true; await server.StartAsync(); - var defaultOptions = new McpClientOptions - { - ClientInfo = new() { Name = "IntegrationTestClient", Version = "1.0.0" } - }; - var defaultConfig = new McpServerConfig { Id = "test_server", @@ -192,7 +176,6 @@ public async Task ConnectAndReceiveMessage_InMemoryServer_WithFullEndpointEventU // Act await using var client = await McpClientFactory.CreateAsync( defaultConfig, - defaultOptions, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); @@ -213,12 +196,6 @@ public async Task ConnectAndReceiveNotification_InMemoryServer() await using InMemoryTestSseServer server = new(CreatePortNumber(), LoggerFactory.CreateLogger()); await server.StartAsync(); - - var defaultOptions = new McpClientOptions - { - ClientInfo = new() { Name = "IntegrationTestClient", Version = "1.0.0" } - }; - var defaultConfig = new McpServerConfig { Id = "test_server", @@ -229,24 +206,28 @@ public async Task ConnectAndReceiveNotification_InMemoryServer() }; // Act + var receivedNotification = new TaskCompletionSource(); await using var client = await McpClientFactory.CreateAsync( defaultConfig, - defaultOptions, + new() + { + Capabilities = new() + { + NotificationHandlers = [new("test/notification", args => + { + var msg = args.Params?["message"]?.GetValue(); + receivedNotification.SetResult(msg); + + return Task.CompletedTask; + })], + }, + }, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); // Wait for SSE connection to be established await server.WaitForConnectionAsync(TimeSpan.FromSeconds(10)); - var receivedNotification = new TaskCompletionSource(); - client.AddNotificationHandler("test/notification", (args) => - { - var msg = args.Params?["message"]?.GetValue(); - receivedNotification.SetResult(msg); - - return Task.CompletedTask; - }); - // Act await server.SendTestNotificationAsync("Hello from server!"); diff --git a/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs b/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs index 238c7747..1720bc69 100644 --- a/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs +++ b/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs @@ -1,5 +1,4 @@ -using ModelContextProtocol.Client; -using ModelContextProtocol.Protocol.Transport; +using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Test.Utils; using ModelContextProtocol.Tests.Utils; using ModelContextProtocol.TestSseServer; @@ -32,11 +31,6 @@ public SseServerIntegrationTestFixture() _serverTask = Program.MainAsync([port.ToString()], new XunitLoggerProvider(_delegatingTestOutputHelper), _stopCts.Token); } - public static McpClientOptions CreateDefaultClientOptions() => new() - { - ClientInfo = new() { Name = "IntegrationTestClient", Version = "1.0.0" }, - }; - public void Initialize(ITestOutputHelper output) { _delegatingTestOutputHelper.CurrentTestOutputHelper = output; diff --git a/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs b/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs index a71766e2..eaead4e0 100644 --- a/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs @@ -27,7 +27,7 @@ private Task GetClientAsync(McpClientOptions? options = null) { return McpClientFactory.CreateAsync( _fixture.DefaultConfig, - options ?? SseServerIntegrationTestFixture.CreateDefaultClientOptions(), + options, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); } @@ -220,8 +220,8 @@ public async Task Sampling_Sse_TestServer() // Set up the sampling handler int samplingHandlerCalls = 0; #pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously - var options = SseServerIntegrationTestFixture.CreateDefaultClientOptions(); - options.Capabilities ??= new(); + McpClientOptions options = new(); + options.Capabilities = new(); options.Capabilities.Sampling ??= new(); options.Capabilities.Sampling.SamplingHandler = async (_, _, _) => { diff --git a/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs index 33793555..7e79356e 100644 --- a/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs @@ -1,7 +1,6 @@ using Microsoft.Extensions.Options; using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; -using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server; using ModelContextProtocol.Tests.Utils; using ModelContextProtocol.Utils.Json; @@ -21,11 +20,6 @@ public StdioServerTransportTests(ITestOutputHelper testOutputHelper) { _serverOptions = new McpServerOptions { - ServerInfo = new Implementation - { - Name = "Test Server", - Version = "1.0" - }, ProtocolVersion = "2.0", InitializationTimeout = TimeSpan.FromSeconds(10), ServerInstructions = "Test Instructions" @@ -49,8 +43,6 @@ public void Constructor_Throws_For_Null_Options() Assert.Throws("serverOptions", () => new StdioServerTransport((IOptions)null!)); Assert.Throws("serverOptions", () => new StdioServerTransport((McpServerOptions)null!)); - Assert.Throws("serverOptions.ServerInfo", () => new StdioServerTransport(new McpServerOptions() { ServerInfo = null! })); - Assert.Throws("serverOptions.ServerInfo.Name", () => new StdioServerTransport(new McpServerOptions() { ServerInfo = new() { Name = null!, Version = "" } })); } [Fact]