From 8cdc2824f1e080d9f71117d1477322e81add5b65 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Mon, 7 Apr 2025 12:58:33 -0400 Subject: [PATCH 1/2] Enable servers to log to the client via ILogger --- .../Client/McpClientExtensions.cs | 11 +++ .../McpEndpointExtensions.cs | 2 +- .../Protocol/Types/EmptyResult.cs | 7 +- src/ModelContextProtocol/Server/IMcpServer.cs | 3 + src/ModelContextProtocol/Server/McpServer.cs | 60 +++++++++++--- .../Server/McpServerExtensions.cs | 55 ++++++++++++- .../Client/McpClientExtensionsTests.cs | 82 ++++++++++++++++++- .../Server/McpServerTests.cs | 1 + 8 files changed, 203 insertions(+), 18 deletions(-) diff --git a/src/ModelContextProtocol/Client/McpClientExtensions.cs b/src/ModelContextProtocol/Client/McpClientExtensions.cs index 2fc4eb5a..e7d9a86a 100644 --- a/src/ModelContextProtocol/Client/McpClientExtensions.cs +++ b/src/ModelContextProtocol/Client/McpClientExtensions.cs @@ -1,6 +1,8 @@ using Microsoft.Extensions.AI; +using Microsoft.Extensions.Logging; using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Types; +using ModelContextProtocol.Server; using ModelContextProtocol.Utils; using ModelContextProtocol.Utils.Json; using System.Runtime.CompilerServices; @@ -631,6 +633,15 @@ public static Task SetLoggingLevel(this IMcpClient client, LoggingLevel level, C cancellationToken: cancellationToken); } + /// + /// Configures the minimum logging level for the server. + /// + /// The client. + /// The minimum log level of messages to be generated. + /// The to monitor for cancellation requests. The default is . + public static Task SetLoggingLevel(this IMcpClient client, LogLevel level, CancellationToken cancellationToken = default) => + SetLoggingLevel(client, McpServer.ToLoggingLevel(level), cancellationToken); + /// Convers a dictionary with values to a dictionary with values. private static IReadOnlyDictionary? ToArgumentsDictionary( IReadOnlyDictionary? arguments, JsonSerializerOptions options) diff --git a/src/ModelContextProtocol/McpEndpointExtensions.cs b/src/ModelContextProtocol/McpEndpointExtensions.cs index b099019e..3a7a7721 100644 --- a/src/ModelContextProtocol/McpEndpointExtensions.cs +++ b/src/ModelContextProtocol/McpEndpointExtensions.cs @@ -83,7 +83,7 @@ internal static async Task SendRequestAsync( } /// - /// Sends a notification to the server with parameters. + /// Sends a notification to the server with no parameters. /// /// The client. /// The notification method name. diff --git a/src/ModelContextProtocol/Protocol/Types/EmptyResult.cs b/src/ModelContextProtocol/Protocol/Types/EmptyResult.cs index 3dc1a8de..1651c42c 100644 --- a/src/ModelContextProtocol/Protocol/Types/EmptyResult.cs +++ b/src/ModelContextProtocol/Protocol/Types/EmptyResult.cs @@ -1,4 +1,6 @@ -namespace ModelContextProtocol.Protocol.Types; +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Protocol.Types; /// /// An empty result object. @@ -6,5 +8,6 @@ /// public class EmptyResult { - + [JsonIgnore] + internal static Task CompletedTask { get; } = Task.FromResult(new EmptyResult()); } \ No newline at end of file diff --git a/src/ModelContextProtocol/Server/IMcpServer.cs b/src/ModelContextProtocol/Server/IMcpServer.cs index 19b3967a..cd8df6fa 100644 --- a/src/ModelContextProtocol/Server/IMcpServer.cs +++ b/src/ModelContextProtocol/Server/IMcpServer.cs @@ -25,6 +25,9 @@ public interface IMcpServer : IMcpEndpoint /// IServiceProvider? Services { get; } + /// Gets the last logging level set by the client, or if it's never been set. + LoggingLevel? LoggingLevel { get; } + /// /// Runs the server, listening for and handling client requests. /// diff --git a/src/ModelContextProtocol/Server/McpServer.cs b/src/ModelContextProtocol/Server/McpServer.cs index 81214f0b..96345fef 100644 --- a/src/ModelContextProtocol/Server/McpServer.cs +++ b/src/ModelContextProtocol/Server/McpServer.cs @@ -5,7 +5,7 @@ using ModelContextProtocol.Shared; using ModelContextProtocol.Utils; using ModelContextProtocol.Utils.Json; -using System.Diagnostics; +using System.Runtime.CompilerServices; namespace ModelContextProtocol.Server; @@ -26,6 +26,13 @@ internal sealed class McpServer : McpEndpoint, IMcpServer private string _endpointName; private int _started; + /// Holds a boxed value for the server. + /// + /// Initialized to non-null the first time SetLevel is used. This is stored as a strong box + /// rather than a nullable to be able to manipulate it atomically. + /// + private StrongBox? _loggingLevel; + /// /// Creates a new instance of . /// @@ -105,6 +112,9 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? /// public override string EndpointName => _endpointName; + /// + public LoggingLevel? LoggingLevel => _loggingLevel?.Value; + /// public async Task RunAsync(CancellationToken cancellationToken = default) { @@ -441,20 +451,48 @@ await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false) private void SetSetLoggingLevelHandler(McpServerOptions options) { - if (options.Capabilities?.Logging is not { } loggingCapability) - { - return; - } - - if (loggingCapability.SetLoggingLevelHandler is not { } setLoggingLevelHandler) - { - throw new McpException("Logging capability was enabled, but SetLoggingLevelHandler was not specified."); - } + // We don't require that the handler be provided, as we always store the provided + // log level to the server. + var setLoggingLevelHandler = options.Capabilities?.Logging?.SetLoggingLevelHandler; RequestHandlers.Set( RequestMethods.LoggingSetLevel, - (request, cancellationToken) => setLoggingLevelHandler(new(this, request), cancellationToken), + (request, cancellationToken) => + { + // Store the provided level. + if (request is not null) + { + if (_loggingLevel is null) + { + Interlocked.CompareExchange(ref _loggingLevel, new(request.Level), null); + } + + _loggingLevel.Value = request.Level; + } + + // If a handler was provided, now delegate to it. + if (setLoggingLevelHandler is not null) + { + return setLoggingLevelHandler(new(this, request), cancellationToken); + } + + // Otherwise, consider it handled. + return EmptyResult.CompletedTask; + }, McpJsonUtilities.JsonContext.Default.SetLevelRequestParams, McpJsonUtilities.JsonContext.Default.EmptyResult); } + + /// Maps a to a . + internal static LoggingLevel ToLoggingLevel(LogLevel level) => + level switch + { + LogLevel.Trace => Protocol.Types.LoggingLevel.Debug, + LogLevel.Debug => Protocol.Types.LoggingLevel.Debug, + LogLevel.Warning => Protocol.Types.LoggingLevel.Warning, + LogLevel.Error => Protocol.Types.LoggingLevel.Error, + LogLevel.Critical => Protocol.Types.LoggingLevel.Critical, + LogLevel.Information => Protocol.Types.LoggingLevel.Info, + _ => Protocol.Types.LoggingLevel.Debug, + }; } \ No newline at end of file diff --git a/src/ModelContextProtocol/Server/McpServerExtensions.cs b/src/ModelContextProtocol/Server/McpServerExtensions.cs index 06ec596c..3772808e 100644 --- a/src/ModelContextProtocol/Server/McpServerExtensions.cs +++ b/src/ModelContextProtocol/Server/McpServerExtensions.cs @@ -1,10 +1,12 @@ using Microsoft.Extensions.AI; +using Microsoft.Extensions.Logging; using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Utils; using ModelContextProtocol.Utils.Json; using System.Runtime.CompilerServices; using System.Text; +using System.Text.Json; namespace ModelContextProtocol.Server; @@ -28,7 +30,7 @@ public static Task RequestSamplingAsync( return server.SendRequestAsync( RequestMethods.SamplingCreateMessage, - request, + request, McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams, McpJsonUtilities.JsonContext.Default.CreateMessageResult, cancellationToken: cancellationToken); @@ -46,7 +48,7 @@ public static Task RequestSamplingAsync( /// is . /// The client does not support sampling. public static async Task RequestSamplingAsync( - this IMcpServer server, + this IMcpServer server, IEnumerable messages, ChatOptions? options = default, CancellationToken cancellationToken = default) { Throw.IfNull(server); @@ -153,6 +155,16 @@ public static IChatClient AsSamplingChatClient(this IMcpServer server) return new SamplingChatClient(server); } + /// Gets an on which logged messages will be sent as notifications to the client. + /// The server to wrap as an . + /// An that can be used to log to the client.. + public static ILogger AsClientLogger(this IMcpServer server) + { + Throw.IfNull(server); + + return new ClientLogger(server); + } + /// /// Requests the client to list the roots it exposes. /// @@ -210,4 +222,43 @@ async IAsyncEnumerable IChatClient.GetStreamingResponseAsync /// void IDisposable.Dispose() { } // nop } + + /// + /// Provides an implementation for sending logging message notifications + /// to the client for logged messages. + /// + private sealed class ClientLogger(IMcpServer server) : ILogger + { + /// + public IDisposable? BeginScope(TState state) where TState : notnull => + null; + + /// + public bool IsEnabled(LogLevel logLevel) => + server?.LoggingLevel is { } loggingLevel && + McpServer.ToLoggingLevel(logLevel) >= loggingLevel; + + /// + public void Log(LogLevel logLevel, EventId eventId, TState state, Exception? exception, Func formatter) + { + if (!IsEnabled(logLevel)) + { + return; + } + + Throw.IfNull(formatter); + + Log(logLevel, formatter(state, exception)); + + void Log(LogLevel logLevel, string message) + { + _ = server.SendNotificationAsync(NotificationMethods.LoggingMessageNotification, new LoggingMessageNotificationParams() + { + Level = McpServer.ToLoggingLevel(logLevel), + Data = JsonSerializer.SerializeToElement(message, McpJsonUtilities.JsonContext.Default.String), + Logger = eventId.Name, + }); + } + } + } } diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs index 07ea3d18..38e92ab9 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs @@ -1,5 +1,6 @@ using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; using ModelContextProtocol.Client; using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; @@ -10,6 +11,7 @@ using System.IO.Pipelines; using System.Text.Json; using System.Text.Json.Serialization.Metadata; +using System.Threading.Channels; namespace ModelContextProtocol.Tests.Client; @@ -19,6 +21,7 @@ public class McpClientExtensionsTests : LoggedTest private readonly Pipe _serverToClientPipe = new(); private readonly ServiceProvider _serviceProvider; private readonly CancellationTokenSource _cts; + private readonly IMcpServer _server; private readonly Task _serverTask; public McpClientExtensionsTests(ITestOutputHelper outputHelper) @@ -36,9 +39,9 @@ public McpClientExtensionsTests(ITestOutputHelper outputHelper) sc.AddSingleton(McpServerTool.Create([McpServerTool(Destructive = false, OpenWorld = true)](string i) => $"{i} Result", new() { Name = "ValuesSetViaOptions", Destructive = true, OpenWorld = false, ReadOnly = true })); _serviceProvider = sc.BuildServiceProvider(); - var server = _serviceProvider.GetRequiredService(); + _server = _serviceProvider.GetRequiredService(); _cts = CancellationTokenSource.CreateLinkedTokenSource(TestContext.Current.CancellationToken); - _serverTask = server.RunAsync(cancellationToken: _cts.Token); + _serverTask = _server.RunAsync(cancellationToken: _cts.Token); } [Theory] @@ -374,4 +377,79 @@ public async Task WithDescription_ChangesToolDescription() Assert.Equal("ToolWithNewDescription", redescribedTool.Description); Assert.Equal(originalDescription, tool?.Description); } + + [Fact] + public async Task AsClientLogger_MessagesSentToClient() + { + IMcpClient client = await CreateMcpClientForServer(); + + ILogger logger = _server.AsClientLogger(); + Assert.Null(logger.BeginScope("")); + + Assert.Null(_server.LoggingLevel); + Assert.False(logger.IsEnabled(LogLevel.Trace)); + Assert.False(logger.IsEnabled(LogLevel.Debug)); + Assert.False(logger.IsEnabled(LogLevel.Information)); + Assert.False(logger.IsEnabled(LogLevel.Warning)); + Assert.False(logger.IsEnabled(LogLevel.Error)); + Assert.False(logger.IsEnabled(LogLevel.Critical)); + + await client.SetLoggingLevel(LoggingLevel.Info, TestContext.Current.CancellationToken); + + DateTime start = DateTime.UtcNow; + while (_server.LoggingLevel is null) + { + await Task.Delay(1, TestContext.Current.CancellationToken); + Assert.True(DateTime.UtcNow - start < TimeSpan.FromSeconds(10), "Timed out waiting for logging level to be set"); + } + + Assert.Equal(LoggingLevel.Info, _server.LoggingLevel); + Assert.False(logger.IsEnabled(LogLevel.Trace)); + Assert.False(logger.IsEnabled(LogLevel.Debug)); + Assert.True(logger.IsEnabled(LogLevel.Information)); + Assert.True(logger.IsEnabled(LogLevel.Warning)); + Assert.True(logger.IsEnabled(LogLevel.Error)); + Assert.True(logger.IsEnabled(LogLevel.Critical)); + + List data = []; + var channel = Channel.CreateUnbounded(); + + await using (client.RegisterNotificationHandler(NotificationMethods.LoggingMessageNotification, + (notification, cancellationToken) => + { + Assert.True(channel.Writer.TryWrite(JsonSerializer.Deserialize(notification.Params))); + return Task.CompletedTask; + })) + { + logger.LogTrace("Trace {Message}", "message"); + logger.LogDebug("Debug {Message}", "message"); + logger.LogInformation("Information {Message}", "message"); + logger.LogWarning("Warning {Message}", "message"); + logger.LogError("Error {Message}", "message"); + logger.LogCritical("Critical {Message}", "message"); + + for (int i = 0; i < 4; i++) + { + var m = await channel.Reader.ReadAsync(TestContext.Current.CancellationToken); + Assert.NotNull(m); + Assert.NotNull(m.Data); + + string? s = JsonSerializer.Deserialize(m.Data.Value); + Assert.NotNull(s); + data.Add(s); + } + + channel.Writer.Complete(); + } + + Assert.False(await channel.Reader.WaitToReadAsync(TestContext.Current.CancellationToken)); + Assert.Equal( + [ + "Critical message", + "Error message", + "Information message", + "Warning message", + ], + data.OrderBy(s => s)); + } } \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs index ef149b77..a3024702 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs @@ -618,6 +618,7 @@ public Task SendRequestAsync(JsonRpcRequest request, Cancellati public Implementation? ClientInfo => throw new NotImplementedException(); public McpServerOptions ServerOptions => throw new NotImplementedException(); public IServiceProvider? Services => throw new NotImplementedException(); + public LoggingLevel? LoggingLevel => throw new NotImplementedException(); public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) => throw new NotImplementedException(); public Task RunAsync(CancellationToken cancellationToken = default) => From 8da94b107792fcbd7375d47831ee16c09c266057 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Mon, 7 Apr 2025 17:55:55 -0400 Subject: [PATCH 2/2] Address feedback --- src/ModelContextProtocol/Server/McpServer.cs | 4 +- .../Server/McpServerExtensions.cs | 62 ++++++++++++------- .../Client/McpClientExtensionsTests.cs | 33 ++++++++-- 3 files changed, 69 insertions(+), 30 deletions(-) diff --git a/src/ModelContextProtocol/Server/McpServer.cs b/src/ModelContextProtocol/Server/McpServer.cs index 96345fef..284ba77f 100644 --- a/src/ModelContextProtocol/Server/McpServer.cs +++ b/src/ModelContextProtocol/Server/McpServer.cs @@ -489,10 +489,10 @@ internal static LoggingLevel ToLoggingLevel(LogLevel level) => { LogLevel.Trace => Protocol.Types.LoggingLevel.Debug, LogLevel.Debug => Protocol.Types.LoggingLevel.Debug, + LogLevel.Information => Protocol.Types.LoggingLevel.Info, LogLevel.Warning => Protocol.Types.LoggingLevel.Warning, LogLevel.Error => Protocol.Types.LoggingLevel.Error, LogLevel.Critical => Protocol.Types.LoggingLevel.Critical, - LogLevel.Information => Protocol.Types.LoggingLevel.Info, - _ => Protocol.Types.LoggingLevel.Debug, + _ => Protocol.Types.LoggingLevel.Emergency, }; } \ No newline at end of file diff --git a/src/ModelContextProtocol/Server/McpServerExtensions.cs b/src/ModelContextProtocol/Server/McpServerExtensions.cs index 3772808e..9f8c723d 100644 --- a/src/ModelContextProtocol/Server/McpServerExtensions.cs +++ b/src/ModelContextProtocol/Server/McpServerExtensions.cs @@ -158,11 +158,11 @@ public static IChatClient AsSamplingChatClient(this IMcpServer server) /// Gets an on which logged messages will be sent as notifications to the client. /// The server to wrap as an . /// An that can be used to log to the client.. - public static ILogger AsClientLogger(this IMcpServer server) + public static ILoggerProvider AsClientLoggerProvider(this IMcpServer server) { Throw.IfNull(server); - return new ClientLogger(server); + return new ClientLoggerProvider(server); } /// @@ -224,40 +224,54 @@ void IDisposable.Dispose() { } // nop } /// - /// Provides an implementation for sending logging message notifications - /// to the client for logged messages. + /// Provides an implementation for creating loggers + /// that send logging message notifications to the client for logged messages. /// - private sealed class ClientLogger(IMcpServer server) : ILogger + private sealed class ClientLoggerProvider(IMcpServer server) : ILoggerProvider { /// - public IDisposable? BeginScope(TState state) where TState : notnull => - null; + public ILogger CreateLogger(string categoryName) + { + Throw.IfNull(categoryName); - /// - public bool IsEnabled(LogLevel logLevel) => - server?.LoggingLevel is { } loggingLevel && - McpServer.ToLoggingLevel(logLevel) >= loggingLevel; + return new ClientLogger(server, categoryName); + } /// - public void Log(LogLevel logLevel, EventId eventId, TState state, Exception? exception, Func formatter) + void IDisposable.Dispose() { } + + private sealed class ClientLogger(IMcpServer server, string categoryName) : ILogger { - if (!IsEnabled(logLevel)) + /// + public IDisposable? BeginScope(TState state) where TState : notnull => + null; + + /// + public bool IsEnabled(LogLevel logLevel) => + server?.LoggingLevel is { } loggingLevel && + McpServer.ToLoggingLevel(logLevel) >= loggingLevel; + + /// + public void Log(LogLevel logLevel, EventId eventId, TState state, Exception? exception, Func formatter) { - return; - } + if (!IsEnabled(logLevel)) + { + return; + } - Throw.IfNull(formatter); + Throw.IfNull(formatter); - Log(logLevel, formatter(state, exception)); + Log(logLevel, formatter(state, exception)); - void Log(LogLevel logLevel, string message) - { - _ = server.SendNotificationAsync(NotificationMethods.LoggingMessageNotification, new LoggingMessageNotificationParams() + void Log(LogLevel logLevel, string message) { - Level = McpServer.ToLoggingLevel(logLevel), - Data = JsonSerializer.SerializeToElement(message, McpJsonUtilities.JsonContext.Default.String), - Logger = eventId.Name, - }); + _ = server.SendNotificationAsync(NotificationMethods.LoggingMessageNotification, new LoggingMessageNotificationParams() + { + Level = McpServer.ToLoggingLevel(logLevel), + Data = JsonSerializer.SerializeToElement(message, McpJsonUtilities.JsonContext.Default.String), + Logger = categoryName, + }); + } } } } diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs index 38e92ab9..aa538fa0 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs @@ -379,11 +379,16 @@ public async Task WithDescription_ChangesToolDescription() } [Fact] - public async Task AsClientLogger_MessagesSentToClient() + public async Task AsClientLoggerProvider_MessagesSentToClient() { IMcpClient client = await CreateMcpClientForServer(); - ILogger logger = _server.AsClientLogger(); + ILoggerProvider loggerProvider = _server.AsClientLoggerProvider(); + Assert.Throws("categoryName", () => loggerProvider.CreateLogger(null!)); + + ILogger logger = loggerProvider.CreateLogger("TestLogger"); + Assert.NotNull(logger); + Assert.Null(logger.BeginScope("")); Assert.Null(_server.LoggingLevel); @@ -433,9 +438,29 @@ public async Task AsClientLogger_MessagesSentToClient() var m = await channel.Reader.ReadAsync(TestContext.Current.CancellationToken); Assert.NotNull(m); Assert.NotNull(m.Data); - - string? s = JsonSerializer.Deserialize(m.Data.Value); + + Assert.Equal("TestLogger", m.Logger); + + string ? s = JsonSerializer.Deserialize(m.Data.Value); Assert.NotNull(s); + + if (s.Contains("Information")) + { + Assert.Equal(LoggingLevel.Info, m.Level); + } + else if (s.Contains("Warning")) + { + Assert.Equal(LoggingLevel.Warning, m.Level); + } + else if (s.Contains("Error")) + { + Assert.Equal(LoggingLevel.Error, m.Level); + } + else if (s.Contains("Critical")) + { + Assert.Equal(LoggingLevel.Critical, m.Level); + } + data.Add(s); }