Skip to content

Enable servers to log to the clients via ILogger #229

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/ModelContextProtocol/Client/McpClientExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Logging;
using ModelContextProtocol.Protocol.Messages;
using ModelContextProtocol.Protocol.Types;
using ModelContextProtocol.Server;
using ModelContextProtocol.Utils;
using ModelContextProtocol.Utils.Json;
using System.Runtime.CompilerServices;
Expand Down Expand Up @@ -631,6 +633,15 @@ public static Task SetLoggingLevel(this IMcpClient client, LoggingLevel level, C
cancellationToken: cancellationToken);
}

/// <summary>
/// Configures the minimum logging level for the server.
/// </summary>
/// <param name="client">The client.</param>
/// <param name="level">The minimum log level of messages to be generated.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
public static Task SetLoggingLevel(this IMcpClient client, LogLevel level, CancellationToken cancellationToken = default) =>
SetLoggingLevel(client, McpServer.ToLoggingLevel(level), cancellationToken);

/// <summary>Convers a dictionary with <see cref="object"/> values to a dictionary with <see cref="JsonElement"/> values.</summary>
private static IReadOnlyDictionary<string, JsonElement>? ToArgumentsDictionary(
IReadOnlyDictionary<string, object?>? arguments, JsonSerializerOptions options)
Expand Down
2 changes: 1 addition & 1 deletion src/ModelContextProtocol/McpEndpointExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ internal static async Task<TResult> SendRequestAsync<TParameters, TResult>(
}

/// <summary>
/// Sends a notification to the server with parameters.
/// Sends a notification to the server with no parameters.
/// </summary>
/// <param name="client">The client.</param>
/// <param name="method">The notification method name.</param>
Expand Down
7 changes: 5 additions & 2 deletions src/ModelContextProtocol/Protocol/Types/EmptyResult.cs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
namespace ModelContextProtocol.Protocol.Types;
using System.Text.Json.Serialization;

namespace ModelContextProtocol.Protocol.Types;

/// <summary>
/// An empty result object.
/// <see href="https://github.com/modelcontextprotocol/specification/blob/main/schema/">See the schema for details</see>
/// </summary>
public class EmptyResult
{

[JsonIgnore]
internal static Task<EmptyResult> CompletedTask { get; } = Task.FromResult(new EmptyResult());
}
3 changes: 3 additions & 0 deletions src/ModelContextProtocol/Server/IMcpServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ public interface IMcpServer : IMcpEndpoint
/// </summary>
IServiceProvider? Services { get; }

/// <summary>Gets the last logging level set by the client, or <see langword="null"/> if it's never been set.</summary>
LoggingLevel? LoggingLevel { get; }

/// <summary>
/// Runs the server, listening for and handling client requests.
/// </summary>
Expand Down
60 changes: 49 additions & 11 deletions src/ModelContextProtocol/Server/McpServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
using ModelContextProtocol.Shared;
using ModelContextProtocol.Utils;
using ModelContextProtocol.Utils.Json;
using System.Diagnostics;
using System.Runtime.CompilerServices;

namespace ModelContextProtocol.Server;

Expand All @@ -26,6 +26,13 @@ internal sealed class McpServer : McpEndpoint, IMcpServer
private string _endpointName;
private int _started;

/// <summary>Holds a boxed <see cref="LoggingLevel"/> value for the server.</summary>
/// <remarks>
/// Initialized to non-null the first time SetLevel is used. This is stored as a strong box
/// rather than a nullable to be able to manipulate it atomically.
/// </remarks>
private StrongBox<LoggingLevel>? _loggingLevel;

/// <summary>
/// Creates a new instance of <see cref="McpServer"/>.
/// </summary>
Expand Down Expand Up @@ -105,6 +112,9 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory?
/// <inheritdoc />
public override string EndpointName => _endpointName;

/// <inheritdoc />
public LoggingLevel? LoggingLevel => _loggingLevel?.Value;

/// <inheritdoc />
public async Task RunAsync(CancellationToken cancellationToken = default)
{
Expand Down Expand Up @@ -441,20 +451,48 @@ await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false)

private void SetSetLoggingLevelHandler(McpServerOptions options)
{
if (options.Capabilities?.Logging is not { } loggingCapability)
{
return;
}

if (loggingCapability.SetLoggingLevelHandler is not { } setLoggingLevelHandler)
{
throw new McpException("Logging capability was enabled, but SetLoggingLevelHandler was not specified.");
}
// We don't require that the handler be provided, as we always store the provided
// log level to the server.
var setLoggingLevelHandler = options.Capabilities?.Logging?.SetLoggingLevelHandler;

RequestHandlers.Set(
RequestMethods.LoggingSetLevel,
(request, cancellationToken) => setLoggingLevelHandler(new(this, request), cancellationToken),
(request, cancellationToken) =>
{
// Store the provided level.
if (request is not null)
{
if (_loggingLevel is null)
{
Interlocked.CompareExchange(ref _loggingLevel, new(request.Level), null);
}

_loggingLevel.Value = request.Level;
}

// If a handler was provided, now delegate to it.
if (setLoggingLevelHandler is not null)
{
return setLoggingLevelHandler(new(this, request), cancellationToken);
}

// Otherwise, consider it handled.
return EmptyResult.CompletedTask;
},
McpJsonUtilities.JsonContext.Default.SetLevelRequestParams,
McpJsonUtilities.JsonContext.Default.EmptyResult);
}

/// <summary>Maps a <see cref="LogLevel"/> to a <see cref="LoggingLevel"/>.</summary>
internal static LoggingLevel ToLoggingLevel(LogLevel level) =>
level switch
{
LogLevel.Trace => Protocol.Types.LoggingLevel.Debug,
LogLevel.Debug => Protocol.Types.LoggingLevel.Debug,
LogLevel.Information => Protocol.Types.LoggingLevel.Info,
LogLevel.Warning => Protocol.Types.LoggingLevel.Warning,
LogLevel.Error => Protocol.Types.LoggingLevel.Error,
LogLevel.Critical => Protocol.Types.LoggingLevel.Critical,
_ => Protocol.Types.LoggingLevel.Emergency,
};
}
69 changes: 67 additions & 2 deletions src/ModelContextProtocol/Server/McpServerExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Logging;
using ModelContextProtocol.Protocol.Messages;
using ModelContextProtocol.Protocol.Types;
using ModelContextProtocol.Utils;
using ModelContextProtocol.Utils.Json;
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;

namespace ModelContextProtocol.Server;

Expand All @@ -28,7 +30,7 @@ public static Task<CreateMessageResult> RequestSamplingAsync(

return server.SendRequestAsync(
RequestMethods.SamplingCreateMessage,
request,
request,
McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams,
McpJsonUtilities.JsonContext.Default.CreateMessageResult,
cancellationToken: cancellationToken);
Expand All @@ -46,7 +48,7 @@ public static Task<CreateMessageResult> RequestSamplingAsync(
/// <exception cref="ArgumentNullException"><paramref name="messages"/> is <see langword="null"/>.</exception>
/// <exception cref="ArgumentException">The client does not support sampling.</exception>
public static async Task<ChatResponse> RequestSamplingAsync(
this IMcpServer server,
this IMcpServer server,
IEnumerable<ChatMessage> messages, ChatOptions? options = default, CancellationToken cancellationToken = default)
{
Throw.IfNull(server);
Expand Down Expand Up @@ -153,6 +155,16 @@ public static IChatClient AsSamplingChatClient(this IMcpServer server)
return new SamplingChatClient(server);
}

/// <summary>Gets an <see cref="ILogger"/> on which logged messages will be sent as notifications to the client.</summary>
/// <param name="server">The server to wrap as an <see cref="ILogger"/>.</param>
/// <returns>An <see cref="ILogger"/> that can be used to log to the client..</returns>
public static ILoggerProvider AsClientLoggerProvider(this IMcpServer server)
{
Throw.IfNull(server);

return new ClientLoggerProvider(server);
}

/// <summary>
/// Requests the client to list the roots it exposes.
/// </summary>
Expand Down Expand Up @@ -210,4 +222,57 @@ async IAsyncEnumerable<ChatResponseUpdate> IChatClient.GetStreamingResponseAsync
/// <inheritdoc/>
void IDisposable.Dispose() { } // nop
}

/// <summary>
/// Provides an <see cref="ILoggerProvider"/> implementation for creating loggers
/// that send logging message notifications to the client for logged messages.
/// </summary>
private sealed class ClientLoggerProvider(IMcpServer server) : ILoggerProvider
{
/// <inheritdoc />
public ILogger CreateLogger(string categoryName)
{
Throw.IfNull(categoryName);

return new ClientLogger(server, categoryName);
}

/// <inheritdoc />
void IDisposable.Dispose() { }

private sealed class ClientLogger(IMcpServer server, string categoryName) : ILogger
{
/// <inheritdoc />
public IDisposable? BeginScope<TState>(TState state) where TState : notnull =>
null;

/// <inheritdoc />
public bool IsEnabled(LogLevel logLevel) =>
server?.LoggingLevel is { } loggingLevel &&
McpServer.ToLoggingLevel(logLevel) >= loggingLevel;

/// <inheritdoc />
public void Log<TState>(LogLevel logLevel, EventId eventId, TState state, Exception? exception, Func<TState, Exception?, string> formatter)
{
if (!IsEnabled(logLevel))
{
return;
}

Throw.IfNull(formatter);

Log(logLevel, formatter(state, exception));

void Log(LogLevel logLevel, string message)
{
_ = server.SendNotificationAsync(NotificationMethods.LoggingMessageNotification, new LoggingMessageNotificationParams()
{
Level = McpServer.ToLoggingLevel(logLevel),
Data = JsonSerializer.SerializeToElement(message, McpJsonUtilities.JsonContext.Default.String),
Logger = categoryName,
});
}
}
}
}
}
Loading