From 1f51437a330e23fbd4ad1f2af49cda8290f5e425 Mon Sep 17 00:00:00 2001 From: "Wenzel, Toni" Date: Sat, 29 Mar 2025 17:29:37 +0100 Subject: [PATCH] Add InMemoryTransport --- .../McpServerBuilderExtensions.Transports.cs | 58 ++++++++-- .../Transport/InMemoryServerTransport.cs | 102 ++++++++++++++++++ .../Server/McpServerTests.cs | 26 ++--- .../Utils/TestServerTransport.cs | 74 +++---------- 4 files changed, 180 insertions(+), 80 deletions(-) create mode 100644 src/ModelContextProtocol/Protocol/Transport/InMemoryServerTransport.cs diff --git a/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Transports.cs b/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Transports.cs index 1f357d32..2659066d 100644 --- a/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Transports.cs +++ b/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Transports.cs @@ -1,8 +1,10 @@ -using ModelContextProtocol.Configuration; +using System.Diagnostics.CodeAnalysis; +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.Configuration; using ModelContextProtocol.Hosting; +using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Utils; -using Microsoft.Extensions.DependencyInjection; namespace ModelContextProtocol; @@ -16,23 +18,67 @@ public static partial class McpServerBuilderExtensions /// /// The builder instance. public static IMcpServerBuilder WithStdioServerTransport(this IMcpServerBuilder builder) + { + return builder.WithServerTransport(); + } + + /// + /// Adds a server transport that uses SSE via a HttpListener for communication. + /// + /// The builder instance. + public static IMcpServerBuilder WithHttpListenerSseServerTransport(this IMcpServerBuilder builder) + { + return builder.WithServerTransport(); + } + + /// + /// Adds a server transport for in-memory communication. + /// + /// The builder instance. + public static IMcpServerBuilder WithInMemoryServerTransport(this IMcpServerBuilder builder) + { + return builder.WithServerTransport(); + } + + /// + /// Adds a server transport for in-memory communication. + /// + /// The builder instance. + /// Delegate to handle messages. + public static IMcpServerBuilder WithInMemoryServerTransport(this IMcpServerBuilder builder, Func> handleMessageDelegate) + { + var transport = new InMemoryServerTransport + { + HandleMessage = handleMessageDelegate + }; + + return builder.WithServerTransport(transport); + } + + /// + /// Adds a server transport for communication. + /// + /// The type of the server transport to use. + /// The builder instance. + public static IMcpServerBuilder WithServerTransport<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] TTransport>(this IMcpServerBuilder builder) where TTransport : class, IServerTransport { Throw.IfNull(builder); - builder.Services.AddSingleton(); + builder.Services.AddSingleton(); builder.Services.AddHostedService(); return builder; } /// - /// Adds a server transport that uses SSE via a HttpListener for communication. + /// Adds a server transport for communication. /// + /// Instance of the server transport. /// The builder instance. - public static IMcpServerBuilder WithHttpListenerSseServerTransport(this IMcpServerBuilder builder) + public static IMcpServerBuilder WithServerTransport(this IMcpServerBuilder builder, IServerTransport serverTransport) { Throw.IfNull(builder); - builder.Services.AddSingleton(); + builder.Services.AddSingleton(serverTransport); builder.Services.AddHostedService(); return builder; } diff --git a/src/ModelContextProtocol/Protocol/Transport/InMemoryServerTransport.cs b/src/ModelContextProtocol/Protocol/Transport/InMemoryServerTransport.cs new file mode 100644 index 00000000..ba60558a --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Transport/InMemoryServerTransport.cs @@ -0,0 +1,102 @@ +using System.Threading.Channels; +using ModelContextProtocol.Protocol.Messages; + +namespace ModelContextProtocol.Protocol.Transport; + +/// +/// InMemory server transport for special scenarios or testing. +/// +public class InMemoryServerTransport : IServerTransport +{ + private readonly Channel _messageChannel; + private bool _isStarted; + + /// + public bool IsConnected => _isStarted; + + /// + public ChannelReader MessageReader => _messageChannel; + + /// + /// Delegate to handle messages before sending them. + /// + public Func>? HandleMessage { get; set; } + + /// + /// Initializes a new instance of the class. + /// + public InMemoryServerTransport() + { + _messageChannel = Channel.CreateUnbounded(new UnboundedChannelOptions + { + SingleReader = true, + SingleWriter = true, + }); + + // default message handler + HandleMessage = (m, _) => Task.FromResult(CreateResponseMessage(m)); + } + + /// +#if NET8_0_OR_GREATER + public ValueTask DisposeAsync() => ValueTask.CompletedTask; +#else + public ValueTask DisposeAsync() => new ValueTask(Task.CompletedTask); +#endif + + /// + public virtual async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) + { + IJsonRpcMessage? response = message; + + if (HandleMessage != null) + response = await HandleMessage(message, cancellationToken); + + if (response != null) + await WriteMessageAsync(response, cancellationToken); + } + + /// + public virtual Task StartListeningAsync(CancellationToken cancellationToken = default) + { + _isStarted = true; + return Task.CompletedTask; + } + + /// + /// Writes a message to the channel. + /// + protected virtual async Task WriteMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) + { + await _messageChannel.Writer.WriteAsync(message, cancellationToken); + } + + /// + /// Creates a response message for the given request. + /// + /// + /// + protected virtual IJsonRpcMessage? CreateResponseMessage(IJsonRpcMessage message) + { + if (message is JsonRpcRequest request) + { + return new JsonRpcResponse + { + Id = request.Id, + Result = CreateMessageResult(request) + }; + } + + return message; + } + + /// + /// Creates a result object for the given request. + /// + /// + /// + protected virtual object? CreateMessageResult(JsonRpcRequest request) + { + return null; + } +} diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs index 3f4dd1d8..82b8e883 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs @@ -1,4 +1,5 @@ -using Microsoft.Extensions.AI; +using System.Reflection; +using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using ModelContextProtocol.Client; @@ -8,7 +9,6 @@ using ModelContextProtocol.Server; using ModelContextProtocol.Tests.Utils; using Moq; -using System.Reflection; namespace ModelContextProtocol.Tests.Server; @@ -132,9 +132,9 @@ public async Task StartAsync_Sets_Initialized_After_Transport_Responses_Initiali // Send initialized notification await transport.SendMessageAsync(new JsonRpcNotification - { - Method = "notifications/initialized" - }, TestContext.Current.CancellationToken); + { + Method = "notifications/initialized" + }, TestContext.Current.CancellationToken); await Task.Delay(50, TestContext.Current.CancellationToken); @@ -389,7 +389,7 @@ await Can_Handle_Requests( }, ListResourcesHandler = (request, ct) => throw new NotImplementedException(), } - }, + }, method: "resources/read", configureOptions: null, assertResult: response => @@ -450,7 +450,7 @@ public async Task Can_Handle_List_Prompts_Requests_Throws_Exception_If_No_Handle public async Task Can_Handle_Get_Prompts_Requests() { await Can_Handle_Requests( - new ServerCapabilities + new ServerCapabilities { Prompts = new() { @@ -479,7 +479,7 @@ public async Task Can_Handle_Get_Prompts_Requests_Throws_Exception_If_No_Handler public async Task Can_Handle_List_Tools_Requests() { await Can_Handle_Requests( - new ServerCapabilities + new ServerCapabilities { Tools = new() { @@ -528,7 +528,7 @@ await Can_Handle_Requests( }, ListToolsHandler = (request, ct) => throw new NotImplementedException(), } - }, + }, method: "tools/call", configureOptions: null, assertResult: response => @@ -559,10 +559,12 @@ private async Task Can_Handle_Requests(ServerCapabilities? serverCapabilities, s var receivedMessage = new TaskCompletionSource(); - transport.OnMessageSent = (message) => + transport.HandleMessage = (message, _) => { if (message is JsonRpcResponse response && response.Id.AsNumber == 55) receivedMessage.SetResult(response); + + return Task.FromResult((IJsonRpcMessage?)message); }; await transport.SendMessageAsync( @@ -582,7 +584,7 @@ await transport.SendMessageAsync( private async Task Throws_Exception_If_No_Handler_Assigned(ServerCapabilities serverCapabilities, string method, string expectedError) { - await using var transport = new TestServerTransport(); + await using var transport = new InMemoryServerTransport(); var options = CreateOptions(serverCapabilities); Assert.Throws(() => McpServerFactory.Create(transport, options, LoggerFactory, _serviceProvider)); @@ -680,7 +682,7 @@ public Task SendRequestAsync(JsonRpcRequest request, CancellationToken can public Implementation? ClientInfo => throw new NotImplementedException(); public McpServerOptions ServerOptions => throw new NotImplementedException(); public IServiceProvider? Services => throw new NotImplementedException(); - public void AddNotificationHandler(string method, Func handler) => + public void AddNotificationHandler(string method, Func handler) => throw new NotImplementedException(); public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) => throw new NotImplementedException(); diff --git a/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs b/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs index f38d933e..17167332 100644 --- a/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs +++ b/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs @@ -1,83 +1,33 @@ -using System.Threading.Channels; -using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Protocol.Types; namespace ModelContextProtocol.Tests.Utils; -public class TestServerTransport : IServerTransport +public class TestServerTransport : InMemoryServerTransport { - private readonly Channel _messageChannel; - private bool _isStarted; - - public bool IsConnected => _isStarted; - - public ChannelReader MessageReader => _messageChannel; - public List SentMessages { get; } = []; - public Action? OnMessageSent { get; set; } - - public TestServerTransport() - { - _messageChannel = Channel.CreateUnbounded(new UnboundedChannelOptions - { - SingleReader = true, - SingleWriter = true, - }); - } - - public ValueTask DisposeAsync() => ValueTask.CompletedTask; - - public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) + public override Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) { SentMessages.Add(message); - if (message is JsonRpcRequest request) - { - if (request.Method == "roots/list") - await ListRoots(request, cancellationToken); - else if (request.Method == "sampling/createMessage") - await Sampling(request, cancellationToken); - else - await WriteMessageAsync(request, cancellationToken); - } - else if (message is JsonRpcNotification notification) - { - await WriteMessageAsync(notification, cancellationToken); - } - - OnMessageSent?.Invoke(message); - } - public Task StartListeningAsync(CancellationToken cancellationToken = default) - { - _isStarted = true; - return Task.CompletedTask; + return base.SendMessageAsync(message, cancellationToken); } - private async Task ListRoots(JsonRpcRequest request, CancellationToken cancellationToken) + protected override object? CreateMessageResult(JsonRpcRequest request) { - await WriteMessageAsync(new JsonRpcResponse + if (request.Method == "roots/list") { - Id = request.Id, - Result = new ModelContextProtocol.Protocol.Types.ListRootsResult + return new ModelContextProtocol.Protocol.Types.ListRootsResult { Roots = [] - } - }, cancellationToken); - } + }; + } - private async Task Sampling(JsonRpcRequest request, CancellationToken cancellationToken) - { - await WriteMessageAsync(new JsonRpcResponse - { - Id = request.Id, - Result = new CreateMessageResult { Content = new(), Model = "model", Role = "role" } - }, cancellationToken); - } + if (request.Method == "sampling/createMessage") + return new CreateMessageResult { Content = new(), Model = "model", Role = "role" }; - protected async Task WriteMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) - { - await _messageChannel.Writer.WriteAsync(message, cancellationToken); + return base.CreateMessageResult(request); } }