Skip to content

Move notification handler registrations to capabilities #207

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion samples/QuickstartWeatherServer/Tools/WeatherTools.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
using ModelContextProtocol;
using ModelContextProtocol.Server;
using System.ComponentModel;
using System.Net.Http.Json;
using System.Text.Json;

namespace QuickstartWeatherServer.Tools;
Expand Down
90 changes: 53 additions & 37 deletions src/ModelContextProtocol/Client/McpClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@ namespace ModelContextProtocol.Client;
/// <inheritdoc/>
internal sealed class McpClient : McpEndpoint, IMcpClient
{
private static Implementation DefaultImplementation { get; } = new()
{
Name = DefaultAssemblyName.Name ?? nameof(McpClient),
Version = DefaultAssemblyName.Version?.ToString() ?? "1.0.0",
};

private readonly IClientTransport _clientTransport;
private readonly McpClientOptions _options;

Expand All @@ -29,43 +35,53 @@ internal sealed class McpClient : McpEndpoint, IMcpClient
/// <param name="options">Options for the client, defining protocol version and capabilities.</param>
/// <param name="serverConfig">The server configuration.</param>
/// <param name="loggerFactory">The logger factory.</param>
public McpClient(IClientTransport clientTransport, McpClientOptions options, McpServerConfig serverConfig, ILoggerFactory? loggerFactory)
public McpClient(IClientTransport clientTransport, McpClientOptions? options, McpServerConfig serverConfig, ILoggerFactory? loggerFactory)
: base(loggerFactory)
{
options ??= new();

_clientTransport = clientTransport;
_options = options;

EndpointName = $"Client ({serverConfig.Id}: {serverConfig.Name})";

if (options.Capabilities?.Sampling is { } samplingCapability)
if (options.Capabilities is { } capabilities)
{
if (samplingCapability.SamplingHandler is not { } samplingHandler)
if (capabilities.NotificationHandlers is { } notificationHandlers)
{
throw new InvalidOperationException($"Sampling capability was set but it did not provide a handler.");
NotificationHandlers.AddRange(notificationHandlers);
}

SetRequestHandler(
RequestMethods.SamplingCreateMessage,
(request, cancellationToken) => samplingHandler(
request,
request?.Meta?.ProgressToken is { } token ? new TokenProgress(this, token) : NullProgress.Instance,
cancellationToken),
McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams,
McpJsonUtilities.JsonContext.Default.CreateMessageResult);
}

if (options.Capabilities?.Roots is { } rootsCapability)
{
if (rootsCapability.RootsHandler is not { } rootsHandler)
if (capabilities.Sampling is { } samplingCapability)
{
throw new InvalidOperationException($"Roots capability was set but it did not provide a handler.");
if (samplingCapability.SamplingHandler is not { } samplingHandler)
{
throw new InvalidOperationException($"Sampling capability was set but it did not provide a handler.");
}

RequestHandlers.Set(
RequestMethods.SamplingCreateMessage,
(request, cancellationToken) => samplingHandler(
request,
request?.Meta?.ProgressToken is { } token ? new TokenProgress(this, token) : NullProgress.Instance,
cancellationToken),
McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams,
McpJsonUtilities.JsonContext.Default.CreateMessageResult);
}

SetRequestHandler(
RequestMethods.RootsList,
rootsHandler,
McpJsonUtilities.JsonContext.Default.ListRootsRequestParams,
McpJsonUtilities.JsonContext.Default.ListRootsResult);
if (capabilities.Roots is { } rootsCapability)
{
if (rootsCapability.RootsHandler is not { } rootsHandler)
{
throw new InvalidOperationException($"Roots capability was set but it did not provide a handler.");
}

RequestHandlers.Set(
RequestMethods.RootsList,
rootsHandler,
McpJsonUtilities.JsonContext.Default.ListRootsRequestParams,
McpJsonUtilities.JsonContext.Default.ListRootsResult);
}
}
}

Expand Down Expand Up @@ -96,20 +112,20 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default)
using var initializationCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
initializationCts.CancelAfter(_options.InitializationTimeout);

try
{
// Send initialize request
var initializeResponse = await this.SendRequestAsync(
RequestMethods.Initialize,
new InitializeRequestParams
{
ProtocolVersion = _options.ProtocolVersion,
Capabilities = _options.Capabilities ?? new ClientCapabilities(),
ClientInfo = _options.ClientInfo
},
McpJsonUtilities.JsonContext.Default.InitializeRequestParams,
McpJsonUtilities.JsonContext.Default.InitializeResult,
cancellationToken: initializationCts.Token).ConfigureAwait(false);
try
{
// Send initialize request
var initializeResponse = await this.SendRequestAsync(
RequestMethods.Initialize,
new InitializeRequestParams
{
ProtocolVersion = _options.ProtocolVersion,
Capabilities = _options.Capabilities ?? new ClientCapabilities(),
ClientInfo = _options.ClientInfo ?? DefaultImplementation,
},
McpJsonUtilities.JsonContext.Default.InitializeRequestParams,
McpJsonUtilities.JsonContext.Default.InitializeResult,
cancellationToken: initializationCts.Token).ConfigureAwait(false);

// Store server information
_logger.ServerCapabilitiesReceived(EndpointName,
Expand Down
19 changes: 0 additions & 19 deletions src/ModelContextProtocol/Client/McpClientFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,12 @@
using ModelContextProtocol.Utils;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using System.Reflection;

namespace ModelContextProtocol.Client;

/// <summary>Provides factory methods for creating MCP clients.</summary>
public static class McpClientFactory
{
/// <summary>Default client options to use when none are supplied.</summary>
private static readonly McpClientOptions s_defaultClientOptions = CreateDefaultClientOptions();

/// <summary>Creates default client options to use when no options are supplied.</summary>
private static McpClientOptions CreateDefaultClientOptions()
{
var asmName = (Assembly.GetEntryAssembly() ?? Assembly.GetExecutingAssembly()).GetName();
return new()
{
ClientInfo = new()
{
Name = asmName.Name ?? "McpClient",
Version = asmName.Version?.ToString() ?? "1.0.0",
},
};
}

/// <summary>Creates an <see cref="IMcpClient"/>, connecting it to the specified server.</summary>
/// <param name="serverConfig">Configuration for the target server to which the client should connect.</param>
/// <param name="clientOptions">
Expand All @@ -52,7 +34,6 @@ public static async Task<IMcpClient> CreateAsync(
{
Throw.IfNull(serverConfig);

clientOptions ??= s_defaultClientOptions;
createTransportFunc ??= CreateTransport;

string endpointName = $"Client ({serverConfig.Id}: {serverConfig.Name})";
Expand Down
2 changes: 1 addition & 1 deletion src/ModelContextProtocol/Client/McpClientOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ public class McpClientOptions
/// <summary>
/// Information about this client implementation.
/// </summary>
public required Implementation ClientInfo { get; set; }
public Implementation? ClientInfo { get; set; }

/// <summary>
/// Client capabilities to advertise to the server.
Expand Down
1 change: 0 additions & 1 deletion src/ModelContextProtocol/Client/McpClientTool.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using ModelContextProtocol.Protocol.Types;
using ModelContextProtocol.Utils.Json;
using ModelContextProtocol.Utils;
using Microsoft.Extensions.AI;
using System.Text.Json;

Expand Down
15 changes: 1 addition & 14 deletions src/ModelContextProtocol/Configuration/McpServerOptionsSetup.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using System.Reflection;
using ModelContextProtocol.Server;
using ModelContextProtocol.Server;
using Microsoft.Extensions.Options;
using ModelContextProtocol.Utils;

Expand All @@ -25,18 +24,6 @@ public void Configure(McpServerOptions options)
{
Throw.IfNull(options);

// Configure the option's server information based on the current process,
// if it otherwise lacks server information.
if (options.ServerInfo is not { } serverInfo)
{
var assemblyName = Assembly.GetEntryAssembly()?.GetName();
options.ServerInfo = new()
{
Name = assemblyName?.Name ?? "McpServer",
Version = assemblyName?.Version?.ToString() ?? "1.0.0",
};
}

// Collect all of the provided tools into a tools collection. If the options already has
// a collection, add to it, otherwise create a new one. We want to maintain the identity
// of an existing collection in case someone has provided their own derived type, wants
Expand Down
16 changes: 0 additions & 16 deletions src/ModelContextProtocol/IMcpEndpoint.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,4 @@ public interface IMcpEndpoint : IAsyncDisposable
/// <param name="message">The message.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default);

/// <summary>
/// Adds a handler for server notifications of a specific method.
/// </summary>
/// <param name="method">The notification method to handle.</param>
/// <param name="handler">The async handler function to process notifications.</param>
/// <remarks>
/// <para>
/// Each method may have multiple handlers. Adding a handler for a method that already has one
/// will not replace the existing handler.
/// </para>
/// <para>
/// <see cref="NotificationMethods"> provides constants for common notification methods.</see>
/// </para>
/// </remarks>
void AddNotificationHandler(string method, Func<JsonRpcNotification, Task> handler);
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public McpTransportException(string message)
/// </summary>
/// <param name="message">The message that describes the error.</param>
/// <param name="innerException">The exception that is the cause of the current exception.</param>
public McpTransportException(string message, Exception innerException)
public McpTransportException(string message, Exception? innerException)
: base(message, innerException)
{
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,22 @@ public StdioClientSessionTransport(StdioClientTransportOptions options, Process
/// <inheritdoc/>
public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
{
if (_process.HasExited)
Exception? processException = null;
bool hasExited = false;
try
{
hasExited = _process.HasExited;
}
catch (Exception e)
{
processException = e;
hasExited = true;
}

if (hasExited)
{
Logger.TransportNotConnected(EndpointName);
throw new McpTransportException("Transport is not connected");
throw new McpTransportException("Transport is not connected", processException);
}

await base.SendMessageAsync(message, cancellationToken).ConfigureAwait(false);
Expand All @@ -33,7 +45,7 @@ public override async Task SendMessageAsync(IJsonRpcMessage message, Cancellatio
/// <inheritdoc/>
protected override ValueTask CleanupAsync(CancellationToken cancellationToken)
{
StdioClientTransport.DisposeProcess(_process, processStarted: true, Logger, _options.ShutdownTimeout, EndpointName);
StdioClientTransport.DisposeProcess(_process, processRunning: true, Logger, _options.ShutdownTimeout, EndpointName);

return base.CleanupAsync(cancellationToken);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,13 +129,25 @@ public async Task<ITransport> ConnectAsync(CancellationToken cancellationToken =
}

internal static void DisposeProcess(
Process? process, bool processStarted, ILogger logger, TimeSpan shutdownTimeout, string endpointName)
Process? process, bool processRunning, ILogger logger, TimeSpan shutdownTimeout, string endpointName)
{
if (process is not null)
{
if (processRunning)
{
try
{
processRunning = !process.HasExited;
}
catch
{
processRunning = false;
}
}

try
{
if (processStarted && !process.HasExited)
if (processRunning)
{
// Wait for the process to exit.
// Kill the while process tree because the process may spawn child processes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,7 @@ public StdioServerTransport(string serverName, ILoggerFactory? loggerFactory = n
private static string GetServerName(McpServerOptions serverOptions)
{
Throw.IfNull(serverOptions);
Throw.IfNull(serverOptions.ServerInfo);
Throw.IfNull(serverOptions.ServerInfo.Name);

return serverOptions.ServerInfo.Name;
return serverOptions.ServerInfo?.Name ?? McpServer.DefaultImplementation.Name;
}
}
11 changes: 10 additions & 1 deletion src/ModelContextProtocol/Protocol/Types/Capabilities.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using ModelContextProtocol.Server;
using ModelContextProtocol.Protocol.Messages;
using ModelContextProtocol.Server;
using System.Text.Json.Serialization;

namespace ModelContextProtocol.Protocol.Types;
Expand Down Expand Up @@ -26,6 +27,14 @@ public class ClientCapabilities
/// </summary>
[JsonPropertyName("sampling")]
public SamplingCapability? Sampling { get; set; }

/// <summary>Gets or sets notification handlers to register with the client.</summary>
/// <remarks>
/// When constructed, the client will enumerate these handlers, which may contain multiple handlers per key.
/// The client will not re-enumerate the sequence.
/// </remarks>
[JsonIgnore]
public IEnumerable<KeyValuePair<string, Func<JsonRpcNotification, Task>>>? NotificationHandlers { get; set; }
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
using ModelContextProtocol.Protocol.Messages;

namespace ModelContextProtocol.Protocol.Types;
namespace ModelContextProtocol.Protocol.Types;

/// <summary>
/// A request from the server to get a list of root URIs from the client.
Expand Down
11 changes: 10 additions & 1 deletion src/ModelContextProtocol/Protocol/Types/ServerCapabilities.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Text.Json.Serialization;
using ModelContextProtocol.Protocol.Messages;
using System.Text.Json.Serialization;

namespace ModelContextProtocol.Protocol.Types;

Expand Down Expand Up @@ -37,4 +38,12 @@ public class ServerCapabilities
/// </summary>
[JsonPropertyName("tools")]
public ToolsCapability? Tools { get; set; }

/// <summary>Gets or sets notification handlers to register with the server.</summary>
/// <remarks>
/// When constructed, the server will enumerate these handlers, which may contain multiple handlers per key.
/// The server will not re-enumerate the sequence.
/// </remarks>
[JsonIgnore]
public IEnumerable<KeyValuePair<string, Func<JsonRpcNotification, Task>>>? NotificationHandlers { get; set; }
}
Loading