From 005a16df1ed76611ddcaf69a6789ac1a5305c05b Mon Sep 17 00:00:00 2001 From: Diego Colombo Date: Tue, 25 Mar 2025 12:52:23 +0000 Subject: [PATCH 01/13] updated git ignore to ignore lut and ncrunch files --- .gitignore | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.gitignore b/.gitignore index 1dc1beb8..efd47c77 100644 --- a/.gitignore +++ b/.gitignore @@ -73,3 +73,9 @@ _NCrunch_* nCrunchTemp_* *.orig + +*.ncrunchsolution + +*.lutconfig + +.NCrunch_ModelContextProtocol/ From c26df4ff6e8f00d52fd773e93e5695edbc9eb368 Mon Sep 17 00:00:00 2001 From: Diego Colombo Date: Tue, 25 Mar 2025 12:52:46 +0000 Subject: [PATCH 02/13] additional ingore --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index efd47c77..18f24147 100644 --- a/.gitignore +++ b/.gitignore @@ -79,3 +79,5 @@ nCrunchTemp_* *.lutconfig .NCrunch_ModelContextProtocol/ + +*.ncrunchproject From 354e56f263dab012b39b60925c6d809685f9fd30 Mon Sep 17 00:00:00 2001 From: Diego Colombo Date: Tue, 25 Mar 2025 17:07:01 +0000 Subject: [PATCH 03/13] Create memory transport for in process mcp --- Directory.Packages.props | 3 +- .../ModelContextProtocol.csproj | 3 +- .../Protocol/Transport/InMemoryTransport.cs | 402 ++++++++++++++++++ .../Transport/InMemoryTransportTests.cs | 136 ++++++ 4 files changed, 541 insertions(+), 3 deletions(-) create mode 100644 src/ModelContextProtocol/Protocol/Transport/InMemoryTransport.cs create mode 100644 tests/ModelContextProtocol.Tests/Transport/InMemoryTransportTests.cs diff --git a/Directory.Packages.props b/Directory.Packages.props index 852f63a2..68236407 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -11,15 +11,14 @@ + - - diff --git a/src/ModelContextProtocol/ModelContextProtocol.csproj b/src/ModelContextProtocol/ModelContextProtocol.csproj index 0e677692..b6b40438 100644 --- a/src/ModelContextProtocol/ModelContextProtocol.csproj +++ b/src/ModelContextProtocol/ModelContextProtocol.csproj @@ -14,8 +14,9 @@ - + + diff --git a/src/ModelContextProtocol/Protocol/Transport/InMemoryTransport.cs b/src/ModelContextProtocol/Protocol/Transport/InMemoryTransport.cs new file mode 100644 index 00000000..fba5427f --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Transport/InMemoryTransport.cs @@ -0,0 +1,402 @@ +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; + +using ModelContextProtocol.Configuration; +using ModelContextProtocol.Logging; +using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Server; + +using System.Diagnostics.CodeAnalysis; +using System.Threading.Channels; + + +namespace ModelContextProtocol.Protocol.Transport; + +/// +/// Provides an in memory implementation of the MCP transport protocol over shared memory. +/// This transport enables efficient in-process MCP functionality and easier test authoring. +/// +/// +/// +/// The InMemoryTransport allows both client and server to communicate within the same process, +/// which is particularly useful for testing scenarios and embedded applications. +/// +/// +/// This implementation requires dynamic code access for tool registration and might not work in Native AOT. +/// +/// +[RequiresUnreferencedCode(RequiresUnreferencedCodeMessage)] +public sealed class InMemoryTransport : TransportBase, IServerTransport, IClientTransport +{ + private const string RequiresUnreferencedCodeMessage = "This method requires dynamic lookup of method metadata and might not work in Native AOT."; + + private readonly string _endpointName = "InMemoryTransport"; + private readonly ILogger _logger; + private readonly Type[] _toolTypes; + private readonly Channel _sharedChannel; + private readonly SemaphoreSlim _connectionLock; + private CancellationTokenSource? _cancellationTokenSource; + private Task? _processingTask; + private Task? _serverTask; + private IMcpServer? _server; + private volatile bool _disposed; + + /// + /// Initializes a new instance of the class. + /// + /// Optional logger factory for logging transport operations. + /// The tool types to be registered with the transport. + /// Thrown when no tool types are provided. + private InMemoryTransport( + ILoggerFactory? loggerFactory, + IEnumerable toolTypes) + : base(loggerFactory) + { + var arrayOfToolTypes = toolTypes as Type[] ?? toolTypes.ToArray(); + if (arrayOfToolTypes.Length == 0) + { + throw new ArgumentException( + "At least one tool type must be provided", + nameof(toolTypes)); + } + + _toolTypes = arrayOfToolTypes; + _logger = loggerFactory?.CreateLogger() + ?? NullLogger.Instance; + _connectionLock = new SemaphoreSlim(1, 1); + + var options = new UnboundedChannelOptions + { + SingleReader = true, + SingleWriter = true, + AllowSynchronousContinuations = true + }; + _sharedChannel = Channel.CreateUnbounded(options); + } + + /// + /// Creates a new instance of the class. + /// + /// Optional logger factory for logging transport operations. + /// One or more tool types to be registered with the transport. + /// A new instance of . + /// Thrown when no tool types are provided. + /// Thrown when tool types is null. + [RequiresUnreferencedCode(RequiresUnreferencedCodeMessage)] + public static InMemoryTransport Create( + ILoggerFactory? loggerFactory, + params IEnumerable[] toolTypes) + { + if (toolTypes is null) + { + throw new ArgumentNullException(nameof(toolTypes)); + } + + if (toolTypes.Length == 0) + { + throw new ArgumentException( + "At least one tool type enumerable must be provided", + nameof(toolTypes)); + } + + return new InMemoryTransport( + loggerFactory, + toolTypes.SelectMany(x => x)); + } + + /// + public override async Task SendMessageAsync( + IJsonRpcMessage message, + CancellationToken cancellationToken = default) + { + // During disposal, allow final message to complete if we're still connected + if (_disposed && !IsConnected) + { + throw new ObjectDisposedException(nameof(InMemoryTransport)); + } + + if (!IsConnected) + { + _logger.TransportNotConnected(_endpointName); + throw new McpTransportException("Transport is not connected"); + } + + string id = "(no id)"; + if (message is IJsonRpcMessageWithId messageWithId) + { + id = messageWithId.Id.ToString(); + } + + try + { + _logger.TransportSendingMessage(_endpointName, id); + + // Only write to shared channel, let HandleMessageReceivedAsync handle base channel write + await _sharedChannel.Writer.WriteAsync(message, cancellationToken); + + // Wait briefly for the message to be processed if we're disposing + if (_disposed) + { + await Task.Delay(50, cancellationToken); + } + + _logger.TransportSentMessage(_endpointName, id); + } + catch (Exception ex) + { + _logger.TransportSendFailed(_endpointName, id, ex); + throw new McpTransportException("Failed to send message", ex); + } + } + + /// + Task IClientTransport.ConnectAsync(CancellationToken cancellationToken) + { + ThrowIfDisposed(); +#pragma warning disable IL2026 + return ConnectInternalAsync(cancellationToken); +#pragma warning restore IL2026 + } + + /// + Task IServerTransport.StartListeningAsync(CancellationToken cancellationToken) + { + ThrowIfDisposed(); +#pragma warning disable IL2026 + return ConnectInternalAsync(cancellationToken); +#pragma warning restore IL2026 + } + + /// + public override async ValueTask DisposeAsync() + { + await CleanupAsync(CancellationToken.None).ConfigureAwait(false); + } + + /// + /// Connects the transport and initializes the server. + /// + /// A token to observe for cancellation requests. + /// A task representing the asynchronous operation. + /// Thrown when the transport is already connected or when connection fails. + [RequiresUnreferencedCode(RequiresUnreferencedCodeMessage)] + private async Task ConnectInternalAsync( + CancellationToken cancellationToken) + { + await _connectionLock.WaitAsync(cancellationToken); + try + { + if (IsConnected) + { + _logger.TransportAlreadyConnected(_endpointName); + throw new McpTransportException("Transport is already connected"); + } + + _logger.TransportConnecting(_endpointName); + + // Create a service collection and builder to set up the in-memory server + var services = new ServiceCollection(); + services.AddSingleton(this); + services.AddSingleton(this); + + // Configure server options + var serverOptions = new McpServerOptions + { + ServerInfo = new Protocol.Types.Implementation { Name = "InMemoryServer", Version = "1.0" }, + ProtocolVersion = "2024", + Capabilities = new Protocol.Types.ServerCapabilities() + }; + + services.AddOptions().Configure(options => + { + options.ServerInfo = serverOptions.ServerInfo; + options.ProtocolVersion = serverOptions.ProtocolVersion; + options.Capabilities = serverOptions.Capabilities; + }); + + // Create a server builder to register the tools + var builder = new DefaultMcpServerBuilder(services); + + // Register the provided tool types + if (_toolTypes.Length > 0) + { + _logger.LogDebug("Registering {Count} tool types", _toolTypes.Length); + builder.WithTools(_toolTypes); + } + + try + { + // Create IServiceProvider instance manually + var serviceProvider = services.BuildServiceProvider(); + + // Create and initialize the server using a logger factory + var loggerFactory = NullLoggerFactory.Instance; + _server = McpServerFactory.Create(this, serverOptions, loggerFactory, serviceProvider); + + // Create cancellation source to manage all tasks + _cancellationTokenSource = new CancellationTokenSource(); + + // Start the server as fire-and-forget (don't await) + _logger.LogDebug("Starting server (fire-and-forget)"); + _serverTask = _server.StartAsync(_cancellationTokenSource.Token); + + // Start a background task to process messages from the shared channel + _processingTask = Task.Run(async () => + { + try + { + _logger.LogDebug("Starting message processing task"); + await foreach (var message in _sharedChannel.Reader.ReadAllAsync(_cancellationTokenSource.Token)) + { + _logger.LogTrace("Transport reading message from channel: {Message}", message); + await HandleMessageReceivedAsync(message, _cancellationTokenSource.Token); + } + } + catch (OperationCanceledException) when (_cancellationTokenSource.Token.IsCancellationRequested) + { + // Expected when cancellation is requested + _logger.TransportReadMessagesCancelled(_endpointName); + } + catch (Exception ex) + { + _logger.TransportReadMessagesFailed(_endpointName, ex); + } + }, _cancellationTokenSource.Token); + + // Short delay to allow background tasks to start + await Task.Delay(10, cancellationToken); + + // Only set connected if initialization succeeded + SetConnected(true); + _logger.LogDebug("Transport connected for {EndpointName}", _endpointName); + } + catch (Exception ex) + { + _logger.TransportConnectFailed(_endpointName, ex); + await CleanupAsync(cancellationToken); + throw new McpTransportException("Failed to connect transport", ex); + } + } + finally + { + _connectionLock.Release(); + } + } + + /// + /// Cleans up resources used by the transport. + /// + /// A token to observe for cancellation requests. + /// A task representing the asynchronous operation. + private async Task CleanupAsync(CancellationToken cancellationToken = default) + { + if (_disposed) + { + return; + } + + _logger.TransportCleaningUp(_endpointName); + + try + { + // Mark as disposed to prevent new operations but keep connection alive + _disposed = true; + + // First wait for the processing task to complete + if (_processingTask != null) + { + try + { + _logger.TransportWaitingForReadTask(_endpointName); + await Task.WhenAny(_processingTask, Task.Delay(500, cancellationToken)); + } + catch (Exception ex) + { + _logger.TransportCleanupReadTaskFailed(_endpointName, ex); + } + } + + // Complete the shared channel + _sharedChannel.Writer.Complete(); + + // Then cancel the server and tasks + if (_cancellationTokenSource != null) + { + await _cancellationTokenSource.CancelAsync(); + _cancellationTokenSource.Dispose(); + _cancellationTokenSource = null; + _processingTask = null; + _serverTask = null; + } + + // Dispose server with timeout + if (_server != null) + { + try + { + await _server.DisposeAsync().ConfigureAwait(false); + } + catch (OperationCanceledException) + { + // Timeout is acceptable + } + catch (Exception ex) + { + _logger.TransportShutdownFailed(_endpointName, ex); + } + _server = null; + } + + // Dispose connection lock + _connectionLock.Dispose(); + } + finally + { + // Set connected to false last + SetConnected(false); + _logger.TransportCleanedUp(_endpointName); + } + } + + /// + /// Throws an if the transport has been disposed. + /// + /// Thrown when the transport has been disposed. + private void ThrowIfDisposed() + { + if (_disposed) + { + throw new ObjectDisposedException(nameof(InMemoryTransport)); + } + } + + /// + /// Processes a message received from the shared channel. + /// + /// The message to process. + /// A token to observe for cancellation requests. + /// A task representing the asynchronous operation. + private async Task HandleMessageReceivedAsync(IJsonRpcMessage message, CancellationToken cancellationToken) + { + string id = "(no id)"; + if (message is IJsonRpcMessageWithId messageWithId) + { + id = messageWithId.Id.ToString(); + } + + try + { + _logger.TransportReceivedMessageParsed(_endpointName, id); + + // Write the message to the base class message channel + await base.WriteMessageAsync(message, cancellationToken); + + _logger.TransportMessageWritten(_endpointName, id); + } + catch (Exception ex) + { + _logger.LogError(ex, "Transport message write failed for {EndpointName} with ID {MessageId}", _endpointName, id); + } + } +} diff --git a/tests/ModelContextProtocol.Tests/Transport/InMemoryTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/InMemoryTransportTests.cs new file mode 100644 index 00000000..765918d8 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Transport/InMemoryTransportTests.cs @@ -0,0 +1,136 @@ +using Microsoft.Extensions.Logging.Abstractions; + +using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Protocol.Transport; +using ModelContextProtocol.Server; + +namespace ModelContextProtocol.Tests.Transport; + +public class InMemoryTransportTests +{ + private readonly Type[] _toolTypes = [typeof(TestTool)]; + + [Fact] + public async Task Constructor_Should_Initialize_With_Valid_Parameters() + { + // Act + await using var transport = InMemoryTransport.Create(NullLoggerFactory.Instance, _toolTypes); + + // Assert + Assert.NotNull(transport); + } + + [Fact] + public void Constructor_Throws_For_Null_Parameters() + { + Assert.Throws("toolTypes", + () => InMemoryTransport.Create(NullLoggerFactory.Instance, Array.Empty())); + + Assert.Throws("toolTypes", + () => InMemoryTransport.Create(NullLoggerFactory.Instance, null!)); + } + + [Fact] + public async Task ConnectAsync_Should_Set_Connected_State() + { + // Arrange + var transport = InMemoryTransport.Create(NullLoggerFactory.Instance, _toolTypes); + + // Act + var clientTransport = (IClientTransport)transport; + await clientTransport.ConnectAsync(TestContext.Current.CancellationToken); + + // Assert + Assert.True(transport.IsConnected); + + await transport.DisposeAsync(); + } + + [Fact] + public async Task StartListeningAsync_Should_Set_Connected_State() + { + // Arrange + await using var transport = InMemoryTransport.Create(NullLoggerFactory.Instance, _toolTypes); + + // Act + var serverTransport = (IServerTransport)transport; + await serverTransport.StartListeningAsync(TestContext.Current.CancellationToken); + + // Assert + Assert.True(transport.IsConnected); + } + + [Theory] + [InlineData("Hello, World!")] + [InlineData("上下文伺服器")] + [InlineData("🔍 🚀 👍")] + public async Task SendMessageAsync_Should_Preserve_Characters(string messageText) + { + // Arrange + var transport = InMemoryTransport.Create(NullLoggerFactory.Instance, _toolTypes); + + IServerTransport serverTransport = transport; + await serverTransport.StartListeningAsync(TestContext.Current.CancellationToken); + + // Ensure transport is fully initialized + await Task.Delay(100, TestContext.Current.CancellationToken); + + var chineseMessage = new JsonRpcRequest + { + Method = "test", + Id = RequestId.FromNumber(44), + Params = new Dictionary + { + ["text"] = messageText + } + }; + + // Act & Assert - Chinese + await transport.SendMessageAsync(chineseMessage, TestContext.Current.CancellationToken); + + Assert.True(transport.MessageReader.TryRead(out var receivedMessage)); + Assert.NotNull(receivedMessage); + Assert.IsType(receivedMessage); + var chineseRequest = (JsonRpcRequest)receivedMessage; + var chineseParams = (Dictionary)chineseRequest.Params!; + Assert.Equal(messageText, (string)chineseParams["text"]); + + await transport.DisposeAsync(); + } + + + [Fact] + public async Task SendMessageAsync_Throws_Exception_If_Not_Connected() + { + // Arrange + await using var transport = InMemoryTransport.Create(NullLoggerFactory.Instance, _toolTypes); + + var message = new JsonRpcRequest { Method = "test" }; + + // Act & Assert + await Assert.ThrowsAsync( + () => transport.SendMessageAsync(message, TestContext.Current.CancellationToken)); + } + + [Fact] + public async Task DisposeAsync_Should_Dispose_Resources() + { + // Arrange + var transport = InMemoryTransport.Create(NullLoggerFactory.Instance, _toolTypes); + IServerTransport serverTransport = transport; + await serverTransport.StartListeningAsync(TestContext.Current.CancellationToken); + + // Act + await transport.DisposeAsync(); + + // Assert + Assert.False(transport.IsConnected); + } + + [McpToolType] + private class TestTool + { + [McpTool] + public string Echo(string message) => message; + } +} From 63d6ed8a2203ad8540a0a59bf4e15d67906a53da Mon Sep 17 00:00:00 2001 From: Diego Colombo Date: Tue, 25 Mar 2025 19:16:03 +0000 Subject: [PATCH 04/13] wait for server task to finish --- .../Protocol/Transport/InMemoryTransport.cs | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/src/ModelContextProtocol/Protocol/Transport/InMemoryTransport.cs b/src/ModelContextProtocol/Protocol/Transport/InMemoryTransport.cs index fba5427f..e9d7179d 100644 --- a/src/ModelContextProtocol/Protocol/Transport/InMemoryTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/InMemoryTransport.cs @@ -1,4 +1,4 @@ -using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; @@ -203,9 +203,9 @@ private async Task ConnectInternalAsync( // Configure server options var serverOptions = new McpServerOptions { - ServerInfo = new Protocol.Types.Implementation { Name = "InMemoryServer", Version = "1.0" }, + ServerInfo = new Types.Implementation { Name = "InMemoryServer", Version = "1.0" }, ProtocolVersion = "2024", - Capabilities = new Protocol.Types.ServerCapabilities() + Capabilities = new Types.ServerCapabilities() }; services.AddOptions().Configure(options => @@ -320,6 +320,20 @@ private async Task CleanupAsync(CancellationToken cancellationToken = default) // Complete the shared channel _sharedChannel.Writer.Complete(); + // now wait for the server task to complete + if (_serverTask != null) + { + try + { + _logger.TransportWaitingForReadTask(_endpointName); + await Task.WhenAny(_serverTask, Task.Delay(500, cancellationToken)); + } + catch (Exception ex) + { + _logger.TransportCleanupReadTaskFailed(_endpointName, ex); + } + } + // Then cancel the server and tasks if (_cancellationTokenSource != null) { From a583bac2c8627ad2cb1d4308d840b17d09de2b8c Mon Sep 17 00:00:00 2001 From: Diego Colombo Date: Tue, 25 Mar 2025 22:42:14 +0000 Subject: [PATCH 05/13] merge main and pr feedback --- .../Transport/InMemoryTransportTests.cs | 25 +++++++++---------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/tests/ModelContextProtocol.Tests/Transport/InMemoryTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/InMemoryTransportTests.cs index 765918d8..c83242a9 100644 --- a/tests/ModelContextProtocol.Tests/Transport/InMemoryTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/InMemoryTransportTests.cs @@ -1,12 +1,11 @@ -using Microsoft.Extensions.Logging.Abstractions; - using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Server; +using ModelContextProtocol.Tests.Utils; namespace ModelContextProtocol.Tests.Transport; -public class InMemoryTransportTests +public class InMemoryTransportTests(ITestOutputHelper testOutputHelper) : LoggedTest(testOutputHelper) { private readonly Type[] _toolTypes = [typeof(TestTool)]; @@ -14,7 +13,7 @@ public class InMemoryTransportTests public async Task Constructor_Should_Initialize_With_Valid_Parameters() { // Act - await using var transport = InMemoryTransport.Create(NullLoggerFactory.Instance, _toolTypes); + await using var transport = InMemoryTransport.Create(LoggerFactory, _toolTypes); // Assert Assert.NotNull(transport); @@ -24,17 +23,17 @@ public async Task Constructor_Should_Initialize_With_Valid_Parameters() public void Constructor_Throws_For_Null_Parameters() { Assert.Throws("toolTypes", - () => InMemoryTransport.Create(NullLoggerFactory.Instance, Array.Empty())); + () => InMemoryTransport.Create(LoggerFactory, Array.Empty())); Assert.Throws("toolTypes", - () => InMemoryTransport.Create(NullLoggerFactory.Instance, null!)); + () => InMemoryTransport.Create(LoggerFactory, null!)); } [Fact] public async Task ConnectAsync_Should_Set_Connected_State() { // Arrange - var transport = InMemoryTransport.Create(NullLoggerFactory.Instance, _toolTypes); + var transport = InMemoryTransport.Create(LoggerFactory, _toolTypes); // Act var clientTransport = (IClientTransport)transport; @@ -50,7 +49,7 @@ public async Task ConnectAsync_Should_Set_Connected_State() public async Task StartListeningAsync_Should_Set_Connected_State() { // Arrange - await using var transport = InMemoryTransport.Create(NullLoggerFactory.Instance, _toolTypes); + await using var transport = InMemoryTransport.Create(LoggerFactory, _toolTypes); // Act var serverTransport = (IServerTransport)transport; @@ -67,7 +66,7 @@ public async Task StartListeningAsync_Should_Set_Connected_State() public async Task SendMessageAsync_Should_Preserve_Characters(string messageText) { // Arrange - var transport = InMemoryTransport.Create(NullLoggerFactory.Instance, _toolTypes); + var transport = InMemoryTransport.Create(LoggerFactory, _toolTypes); IServerTransport serverTransport = transport; await serverTransport.StartListeningAsync(TestContext.Current.CancellationToken); @@ -103,7 +102,7 @@ public async Task SendMessageAsync_Should_Preserve_Characters(string messageText public async Task SendMessageAsync_Throws_Exception_If_Not_Connected() { // Arrange - await using var transport = InMemoryTransport.Create(NullLoggerFactory.Instance, _toolTypes); + await using var transport = InMemoryTransport.Create(LoggerFactory, _toolTypes); var message = new JsonRpcRequest { Method = "test" }; @@ -116,7 +115,7 @@ await Assert.ThrowsAsync( public async Task DisposeAsync_Should_Dispose_Resources() { // Arrange - var transport = InMemoryTransport.Create(NullLoggerFactory.Instance, _toolTypes); + var transport = InMemoryTransport.Create(LoggerFactory, _toolTypes); IServerTransport serverTransport = transport; await serverTransport.StartListeningAsync(TestContext.Current.CancellationToken); @@ -127,10 +126,10 @@ public async Task DisposeAsync_Should_Dispose_Resources() Assert.False(transport.IsConnected); } - [McpToolType] + [McpServerToolType] private class TestTool { - [McpTool] + [McpServerTool] public string Echo(string message) => message; } } From 90d3aea2de9dc232aac59c985eceac90c5bb7b0a Mon Sep 17 00:00:00 2001 From: Diego Colombo Date: Wed, 26 Mar 2025 04:08:33 +0000 Subject: [PATCH 06/13] implementing in memory transports --- .../McpServerBuilderExtensions.Transports.cs | 19 +- .../Transport/InMemoryClientTransport.cs | 224 +++++++++ .../Transport/InMemoryServerTransport.cs | 212 +++++++++ .../Protocol/Transport/InMemoryTransport.cs | 424 ++---------------- .../Transport/InMemoryTransportTests.cs | 139 +++--- 5 files changed, 548 insertions(+), 470 deletions(-) create mode 100644 src/ModelContextProtocol/Protocol/Transport/InMemoryClientTransport.cs create mode 100644 src/ModelContextProtocol/Protocol/Transport/InMemoryServerTransport.cs diff --git a/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Transports.cs b/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Transports.cs index 1f357d32..763ab87d 100644 --- a/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Transports.cs +++ b/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Transports.cs @@ -1,8 +1,9 @@ -using ModelContextProtocol.Configuration; +using Microsoft.Extensions.DependencyInjection; + +using ModelContextProtocol.Configuration; using ModelContextProtocol.Hosting; using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Utils; -using Microsoft.Extensions.DependencyInjection; namespace ModelContextProtocol; @@ -11,6 +12,20 @@ namespace ModelContextProtocol; /// public static partial class McpServerBuilderExtensions { + /// + /// Adds a server transport that uses in memory communication. + /// + /// The builder instance. + public static IMcpServerBuilder WithInMemoryServerTransport(this IMcpServerBuilder builder) + { + Throw.IfNull(builder); + var (clientTransport, serverTransport) = InMemoryTransport.Create(); + builder.Services.AddSingleton(serverTransport); + builder.Services.AddSingleton(clientTransport); + builder.Services.AddHostedService(); + return builder; + } + /// /// Adds a server transport that uses stdin/stdout for communication. /// diff --git a/src/ModelContextProtocol/Protocol/Transport/InMemoryClientTransport.cs b/src/ModelContextProtocol/Protocol/Transport/InMemoryClientTransport.cs new file mode 100644 index 00000000..3c3a78ca --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Transport/InMemoryClientTransport.cs @@ -0,0 +1,224 @@ +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; + +using ModelContextProtocol.Logging; +using ModelContextProtocol.Protocol.Messages; + +using System.Threading.Channels; + +namespace ModelContextProtocol.Protocol.Transport; + +/// +/// Provides an in-memory implementation of the MCP client transport. +/// +public sealed class InMemoryClientTransport : TransportBase, IClientTransport +{ + private readonly string _endpointName = "InMemoryClientTransport"; + private readonly ILogger _logger; + private readonly ChannelWriter _outgoingChannel; + private readonly ChannelReader _incomingChannel; + private CancellationTokenSource? _cancellationTokenSource; + private Task? _readTask; + private SemaphoreSlim _connectLock = new SemaphoreSlim(1, 1); + private volatile bool _disposed; + + /// + /// Gets or sets the server transport this client connects to. + /// + internal InMemoryServerTransport? ServerTransport { get; set; } + + /// + /// Initializes a new instance of the class. + /// + /// Optional logger factory for logging transport operations. + /// Channel for sending messages to the server. + /// Channel for receiving messages from the server. + internal InMemoryClientTransport( + ILoggerFactory? loggerFactory, + ChannelWriter outgoingChannel, + ChannelReader incomingChannel) + : base(loggerFactory) + { + _logger = loggerFactory?.CreateLogger() + ?? NullLogger.Instance; + _outgoingChannel = outgoingChannel; + _incomingChannel = incomingChannel; + } + + + + /// + public async Task ConnectAsync(CancellationToken cancellationToken = default) + { + await _connectLock.WaitAsync(cancellationToken).ConfigureAwait(false); + try + { + ThrowIfDisposed(); + + if (IsConnected) + { + _logger.TransportAlreadyConnected(_endpointName); + throw new McpTransportException("Transport is already connected"); + } + + _logger.TransportConnecting(_endpointName); + + try + { + // Start the server if it exists and is not already connected + if (ServerTransport != null && !ServerTransport.IsConnected) + { + await ServerTransport.StartListeningAsync(cancellationToken).ConfigureAwait(false); + } + + _cancellationTokenSource = new CancellationTokenSource(); + _readTask = Task.Run(() => ReadMessagesAsync(_cancellationTokenSource.Token), _cancellationTokenSource.Token); + + SetConnected(true); + } + catch (Exception ex) + { + _logger.TransportConnectFailed(_endpointName, ex); + await CleanupAsync(cancellationToken).ConfigureAwait(false); + throw new McpTransportException("Failed to connect transport", ex); + } + } + finally + { + _connectLock.Release(); + } + } + + /// + public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) + { + ThrowIfDisposed(); + + if (!IsConnected) + { + _logger.TransportNotConnected(_endpointName); + throw new McpTransportException("Transport is not connected"); + } + + string id = "(no id)"; + if (message is IJsonRpcMessageWithId messageWithId) + { + id = messageWithId.Id.ToString(); + } + + try + { + _logger.TransportSendingMessage(_endpointName, id); + await _outgoingChannel.WriteAsync(message, cancellationToken).ConfigureAwait(false); + _logger.TransportSentMessage(_endpointName, id); + } + catch (Exception ex) + { + _logger.TransportSendFailed(_endpointName, id, ex); + throw new McpTransportException("Failed to send message", ex); + } + } + + /// + public override async ValueTask DisposeAsync() + { + await CleanupAsync(CancellationToken.None).ConfigureAwait(false); + GC.SuppressFinalize(this); + } + + private async Task ReadMessagesAsync(CancellationToken cancellationToken) + { + try + { + _logger.TransportEnteringReadMessagesLoop(_endpointName); + + await foreach (var message in _incomingChannel.ReadAllAsync(cancellationToken)) + { + string id = "(no id)"; + if (message is IJsonRpcMessageWithId messageWithId) + { + id = messageWithId.Id.ToString(); + } + + _logger.TransportReceivedMessageParsed(_endpointName, id); + + // Write to the base class's message channel that's exposed via MessageReader + await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false); + + _logger.TransportMessageWritten(_endpointName, id); + } + + _logger.TransportExitingReadMessagesLoop(_endpointName); + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + _logger.TransportReadMessagesCancelled(_endpointName); + // Normal shutdown + } + catch (Exception ex) + { + _logger.TransportReadMessagesFailed(_endpointName, ex); + } + } + + private async Task CleanupAsync(CancellationToken cancellationToken) + { + if (_disposed) + { + return; + } + + _disposed = true; + _logger.TransportCleaningUp(_endpointName); + + try + { + if (_cancellationTokenSource != null) + { + await _cancellationTokenSource.CancelAsync().ConfigureAwait(false); + _cancellationTokenSource.Dispose(); + _cancellationTokenSource = null; + } + + if (_readTask != null) + { + try + { + _logger.TransportWaitingForReadTask(_endpointName); + await _readTask.WaitAsync(TimeSpan.FromSeconds(1), cancellationToken).ConfigureAwait(false); + } + catch (TimeoutException) + { + _logger.TransportCleanupReadTaskTimeout(_endpointName); + } + catch (OperationCanceledException) + { + _logger.TransportCleanupReadTaskCancelled(_endpointName); + } + catch (Exception ex) + { + _logger.TransportCleanupReadTaskFailed(_endpointName, ex); + } + finally + { + _readTask = null; + } + } + + _connectLock.Dispose(); + } + finally + { + SetConnected(false); + _logger.TransportCleanedUp(_endpointName); + } + } + + private void ThrowIfDisposed() + { + if (_disposed) + { + throw new ObjectDisposedException(nameof(InMemoryClientTransport)); + } + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Transport/InMemoryServerTransport.cs b/src/ModelContextProtocol/Protocol/Transport/InMemoryServerTransport.cs new file mode 100644 index 00000000..627d1dcd --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Transport/InMemoryServerTransport.cs @@ -0,0 +1,212 @@ +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using ModelContextProtocol.Logging; +using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Server; +using System.Diagnostics.CodeAnalysis; +using System.Threading.Channels; + +namespace ModelContextProtocol.Protocol.Transport; + +/// +/// Provides an in-memory implementation of the MCP server transport. +/// +public sealed class InMemoryServerTransport : TransportBase, IServerTransport +{ + private readonly string _endpointName = "InMemoryServerTransport"; + private readonly ILogger _logger; + private readonly ChannelReader _incomingChannel; + private readonly ChannelWriter _outgoingChannel; + private CancellationTokenSource? _cancellationTokenSource; + private Task? _readTask; + private SemaphoreSlim _startLock = new SemaphoreSlim(1, 1); + private volatile bool _disposed; + + /// + /// Initializes a new instance of the class. + /// + /// Optional logger factory for logging transport operations. + /// Channel for receiving messages from the client. + /// Channel for sending messages to the client. + internal InMemoryServerTransport( + ILoggerFactory? loggerFactory, + ChannelReader incomingChannel, + ChannelWriter outgoingChannel) + : base(loggerFactory) + { + _logger = loggerFactory?.CreateLogger() + ?? NullLogger.Instance; + _incomingChannel = incomingChannel; + _outgoingChannel = outgoingChannel; + } + + /// + public async Task StartListeningAsync(CancellationToken cancellationToken = default) + { + await _startLock.WaitAsync(cancellationToken).ConfigureAwait(false); + try + { + ThrowIfDisposed(); + + if (IsConnected) + { + _logger.TransportAlreadyConnected(_endpointName); + throw new McpTransportException("Transport is already connected"); + } + + _logger.TransportConnecting(_endpointName); + + try + { + _cancellationTokenSource = new CancellationTokenSource(); + _readTask = Task.Run(() => ReadMessagesAsync(_cancellationTokenSource.Token), _cancellationTokenSource.Token); + + SetConnected(true); + } + catch (Exception ex) + { + _logger.TransportConnectFailed(_endpointName, ex); + await CleanupAsync(cancellationToken).ConfigureAwait(false); + throw new McpTransportException("Failed to connect transport", ex); + } + } + finally + { + _startLock.Release(); + } + } + + /// + public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) + { + ThrowIfDisposed(); + + if (!IsConnected) + { + _logger.TransportNotConnected(_endpointName); + throw new McpTransportException("Transport is not connected"); + } + + string id = "(no id)"; + if (message is IJsonRpcMessageWithId messageWithId) + { + id = messageWithId.Id.ToString(); + } + + try + { + _logger.TransportSendingMessage(_endpointName, id); + await _outgoingChannel.WriteAsync(message, cancellationToken).ConfigureAwait(false); + _logger.TransportSentMessage(_endpointName, id); + } + catch (Exception ex) + { + _logger.TransportSendFailed(_endpointName, id, ex); + throw new McpTransportException("Failed to send message", ex); + } + } + + /// + public override async ValueTask DisposeAsync() + { + await CleanupAsync(CancellationToken.None).ConfigureAwait(false); + GC.SuppressFinalize(this); + } + + private async Task ReadMessagesAsync(CancellationToken cancellationToken) + { + try + { + _logger.TransportEnteringReadMessagesLoop(_endpointName); + + await foreach (var message in _incomingChannel.ReadAllAsync(cancellationToken)) + { + string id = "(no id)"; + if (message is IJsonRpcMessageWithId messageWithId) + { + id = messageWithId.Id.ToString(); + } + + _logger.TransportReceivedMessageParsed(_endpointName, id); + + // Write to the base class's message channel that's exposed via MessageReader + await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false); + + _logger.TransportMessageWritten(_endpointName, id); + } + + _logger.TransportExitingReadMessagesLoop(_endpointName); + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + _logger.TransportReadMessagesCancelled(_endpointName); + // Normal shutdown + } + catch (Exception ex) + { + _logger.TransportReadMessagesFailed(_endpointName, ex); + } + } + + private async Task CleanupAsync(CancellationToken cancellationToken) + { + if (_disposed) + { + return; + } + + _disposed = true; + _logger.TransportCleaningUp(_endpointName); + + try + { + if (_cancellationTokenSource != null) + { + await _cancellationTokenSource.CancelAsync().ConfigureAwait(false); + _cancellationTokenSource.Dispose(); + _cancellationTokenSource = null; + } + + if (_readTask != null) + { + try + { + _logger.TransportWaitingForReadTask(_endpointName); + await _readTask.WaitAsync(TimeSpan.FromSeconds(1), cancellationToken).ConfigureAwait(false); + } + catch (TimeoutException) + { + _logger.TransportCleanupReadTaskTimeout(_endpointName); + } + catch (OperationCanceledException) + { + _logger.TransportCleanupReadTaskCancelled(_endpointName); + } + catch (Exception ex) + { + _logger.TransportCleanupReadTaskFailed(_endpointName, ex); + } + finally + { + _readTask = null; + } + } + + _startLock.Dispose(); + } + finally + { + SetConnected(false); + _logger.TransportCleanedUp(_endpointName); + } + } + + private void ThrowIfDisposed() + { + if (_disposed) + { + throw new ObjectDisposedException(nameof(InMemoryServerTransport)); + } + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Transport/InMemoryTransport.cs b/src/ModelContextProtocol/Protocol/Transport/InMemoryTransport.cs index e9d7179d..d6ff0578 100644 --- a/src/ModelContextProtocol/Protocol/Transport/InMemoryTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/InMemoryTransport.cs @@ -1,416 +1,58 @@ -using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Abstractions; -using ModelContextProtocol.Configuration; -using ModelContextProtocol.Logging; using ModelContextProtocol.Protocol.Messages; -using ModelContextProtocol.Server; -using System.Diagnostics.CodeAnalysis; using System.Threading.Channels; - namespace ModelContextProtocol.Protocol.Transport; /// -/// Provides an in memory implementation of the MCP transport protocol over shared memory. -/// This transport enables efficient in-process MCP functionality and easier test authoring. +/// Factory that creates linked in-memory client and server transports for testing purposes. /// -/// -/// -/// The InMemoryTransport allows both client and server to communicate within the same process, -/// which is particularly useful for testing scenarios and embedded applications. -/// -/// -/// This implementation requires dynamic code access for tool registration and might not work in Native AOT. -/// -/// -[RequiresUnreferencedCode(RequiresUnreferencedCodeMessage)] -public sealed class InMemoryTransport : TransportBase, IServerTransport, IClientTransport +public sealed class InMemoryTransport { - private const string RequiresUnreferencedCodeMessage = "This method requires dynamic lookup of method metadata and might not work in Native AOT."; - - private readonly string _endpointName = "InMemoryTransport"; - private readonly ILogger _logger; - private readonly Type[] _toolTypes; - private readonly Channel _sharedChannel; - private readonly SemaphoreSlim _connectionLock; - private CancellationTokenSource? _cancellationTokenSource; - private Task? _processingTask; - private Task? _serverTask; - private IMcpServer? _server; - private volatile bool _disposed; - /// - /// Initializes a new instance of the class. + /// Creates a new pair of in-memory transports for client and server communication. /// /// Optional logger factory for logging transport operations. - /// The tool types to be registered with the transport. - /// Thrown when no tool types are provided. - private InMemoryTransport( - ILoggerFactory? loggerFactory, - IEnumerable toolTypes) - : base(loggerFactory) + /// A tuple containing client and server transports that communicate with each other. + public static (InMemoryClientTransport ClientTransport, InMemoryServerTransport ServerTransport) Create( + ILoggerFactory? loggerFactory = null) { - var arrayOfToolTypes = toolTypes as Type[] ?? toolTypes.ToArray(); - if (arrayOfToolTypes.Length == 0) - { - throw new ArgumentException( - "At least one tool type must be provided", - nameof(toolTypes)); - } - - _toolTypes = arrayOfToolTypes; - _logger = loggerFactory?.CreateLogger() - ?? NullLogger.Instance; - _connectionLock = new SemaphoreSlim(1, 1); - - var options = new UnboundedChannelOptions + // Configure client-to-server channel - this will be used for: + // 1. Client's outgoing channel + // 2. Server's MessageReader + var clientToServerChannel = Channel.CreateUnbounded(new UnboundedChannelOptions { - SingleReader = true, - SingleWriter = true, + SingleReader = false, // Both server and the server's MessageReader will read + SingleWriter = true, // Client writes AllowSynchronousContinuations = true - }; - _sharedChannel = Channel.CreateUnbounded(options); - } - - /// - /// Creates a new instance of the class. - /// - /// Optional logger factory for logging transport operations. - /// One or more tool types to be registered with the transport. - /// A new instance of . - /// Thrown when no tool types are provided. - /// Thrown when tool types is null. - [RequiresUnreferencedCode(RequiresUnreferencedCodeMessage)] - public static InMemoryTransport Create( - ILoggerFactory? loggerFactory, - params IEnumerable[] toolTypes) - { - if (toolTypes is null) - { - throw new ArgumentNullException(nameof(toolTypes)); - } + }); - if (toolTypes.Length == 0) + // Configure server-to-client channel - this will be used for: + // 1. Server's outgoing channel + // 2. Client's MessageReader + var serverToClientChannel = Channel.CreateUnbounded(new UnboundedChannelOptions { - throw new ArgumentException( - "At least one tool type enumerable must be provided", - nameof(toolTypes)); - } + SingleReader = false, // Both client and the client's MessageReader will read + SingleWriter = true, // Server writes + AllowSynchronousContinuations = true + }); - return new InMemoryTransport( + // Create the client and server transports - they directly expose the channels through MessageReader + var serverTransport = new InMemoryServerTransport( loggerFactory, - toolTypes.SelectMany(x => x)); - } - - /// - public override async Task SendMessageAsync( - IJsonRpcMessage message, - CancellationToken cancellationToken = default) - { - // During disposal, allow final message to complete if we're still connected - if (_disposed && !IsConnected) - { - throw new ObjectDisposedException(nameof(InMemoryTransport)); - } - - if (!IsConnected) - { - _logger.TransportNotConnected(_endpointName); - throw new McpTransportException("Transport is not connected"); - } - - string id = "(no id)"; - if (message is IJsonRpcMessageWithId messageWithId) - { - id = messageWithId.Id.ToString(); - } - - try - { - _logger.TransportSendingMessage(_endpointName, id); - - // Only write to shared channel, let HandleMessageReceivedAsync handle base channel write - await _sharedChannel.Writer.WriteAsync(message, cancellationToken); - - // Wait briefly for the message to be processed if we're disposing - if (_disposed) - { - await Task.Delay(50, cancellationToken); - } - - _logger.TransportSentMessage(_endpointName, id); - } - catch (Exception ex) - { - _logger.TransportSendFailed(_endpointName, id, ex); - throw new McpTransportException("Failed to send message", ex); - } - } - - /// - Task IClientTransport.ConnectAsync(CancellationToken cancellationToken) - { - ThrowIfDisposed(); -#pragma warning disable IL2026 - return ConnectInternalAsync(cancellationToken); -#pragma warning restore IL2026 - } - - /// - Task IServerTransport.StartListeningAsync(CancellationToken cancellationToken) - { - ThrowIfDisposed(); -#pragma warning disable IL2026 - return ConnectInternalAsync(cancellationToken); -#pragma warning restore IL2026 - } - - /// - public override async ValueTask DisposeAsync() - { - await CleanupAsync(CancellationToken.None).ConfigureAwait(false); - } - - /// - /// Connects the transport and initializes the server. - /// - /// A token to observe for cancellation requests. - /// A task representing the asynchronous operation. - /// Thrown when the transport is already connected or when connection fails. - [RequiresUnreferencedCode(RequiresUnreferencedCodeMessage)] - private async Task ConnectInternalAsync( - CancellationToken cancellationToken) - { - await _connectionLock.WaitAsync(cancellationToken); - try - { - if (IsConnected) - { - _logger.TransportAlreadyConnected(_endpointName); - throw new McpTransportException("Transport is already connected"); - } - - _logger.TransportConnecting(_endpointName); + clientToServerChannel.Reader, // incoming: reads messages from client + serverToClientChannel.Writer); // outgoing: writes messages to client - // Create a service collection and builder to set up the in-memory server - var services = new ServiceCollection(); - services.AddSingleton(this); - services.AddSingleton(this); - - // Configure server options - var serverOptions = new McpServerOptions - { - ServerInfo = new Types.Implementation { Name = "InMemoryServer", Version = "1.0" }, - ProtocolVersion = "2024", - Capabilities = new Types.ServerCapabilities() - }; - - services.AddOptions().Configure(options => - { - options.ServerInfo = serverOptions.ServerInfo; - options.ProtocolVersion = serverOptions.ProtocolVersion; - options.Capabilities = serverOptions.Capabilities; - }); - - // Create a server builder to register the tools - var builder = new DefaultMcpServerBuilder(services); - - // Register the provided tool types - if (_toolTypes.Length > 0) - { - _logger.LogDebug("Registering {Count} tool types", _toolTypes.Length); - builder.WithTools(_toolTypes); - } - - try - { - // Create IServiceProvider instance manually - var serviceProvider = services.BuildServiceProvider(); - - // Create and initialize the server using a logger factory - var loggerFactory = NullLoggerFactory.Instance; - _server = McpServerFactory.Create(this, serverOptions, loggerFactory, serviceProvider); - - // Create cancellation source to manage all tasks - _cancellationTokenSource = new CancellationTokenSource(); - - // Start the server as fire-and-forget (don't await) - _logger.LogDebug("Starting server (fire-and-forget)"); - _serverTask = _server.StartAsync(_cancellationTokenSource.Token); - - // Start a background task to process messages from the shared channel - _processingTask = Task.Run(async () => - { - try - { - _logger.LogDebug("Starting message processing task"); - await foreach (var message in _sharedChannel.Reader.ReadAllAsync(_cancellationTokenSource.Token)) - { - _logger.LogTrace("Transport reading message from channel: {Message}", message); - await HandleMessageReceivedAsync(message, _cancellationTokenSource.Token); - } - } - catch (OperationCanceledException) when (_cancellationTokenSource.Token.IsCancellationRequested) - { - // Expected when cancellation is requested - _logger.TransportReadMessagesCancelled(_endpointName); - } - catch (Exception ex) - { - _logger.TransportReadMessagesFailed(_endpointName, ex); - } - }, _cancellationTokenSource.Token); - - // Short delay to allow background tasks to start - await Task.Delay(10, cancellationToken); - - // Only set connected if initialization succeeded - SetConnected(true); - _logger.LogDebug("Transport connected for {EndpointName}", _endpointName); - } - catch (Exception ex) - { - _logger.TransportConnectFailed(_endpointName, ex); - await CleanupAsync(cancellationToken); - throw new McpTransportException("Failed to connect transport", ex); - } - } - finally - { - _connectionLock.Release(); - } - } - - /// - /// Cleans up resources used by the transport. - /// - /// A token to observe for cancellation requests. - /// A task representing the asynchronous operation. - private async Task CleanupAsync(CancellationToken cancellationToken = default) - { - if (_disposed) - { - return; - } - - _logger.TransportCleaningUp(_endpointName); - - try - { - // Mark as disposed to prevent new operations but keep connection alive - _disposed = true; - - // First wait for the processing task to complete - if (_processingTask != null) - { - try - { - _logger.TransportWaitingForReadTask(_endpointName); - await Task.WhenAny(_processingTask, Task.Delay(500, cancellationToken)); - } - catch (Exception ex) - { - _logger.TransportCleanupReadTaskFailed(_endpointName, ex); - } - } - - // Complete the shared channel - _sharedChannel.Writer.Complete(); - - // now wait for the server task to complete - if (_serverTask != null) - { - try - { - _logger.TransportWaitingForReadTask(_endpointName); - await Task.WhenAny(_serverTask, Task.Delay(500, cancellationToken)); - } - catch (Exception ex) - { - _logger.TransportCleanupReadTaskFailed(_endpointName, ex); - } - } - - // Then cancel the server and tasks - if (_cancellationTokenSource != null) - { - await _cancellationTokenSource.CancelAsync(); - _cancellationTokenSource.Dispose(); - _cancellationTokenSource = null; - _processingTask = null; - _serverTask = null; - } - - // Dispose server with timeout - if (_server != null) - { - try - { - await _server.DisposeAsync().ConfigureAwait(false); - } - catch (OperationCanceledException) - { - // Timeout is acceptable - } - catch (Exception ex) - { - _logger.TransportShutdownFailed(_endpointName, ex); - } - _server = null; - } - - // Dispose connection lock - _connectionLock.Dispose(); - } - finally - { - // Set connected to false last - SetConnected(false); - _logger.TransportCleanedUp(_endpointName); - } - } - - /// - /// Throws an if the transport has been disposed. - /// - /// Thrown when the transport has been disposed. - private void ThrowIfDisposed() - { - if (_disposed) - { - throw new ObjectDisposedException(nameof(InMemoryTransport)); - } - } - - /// - /// Processes a message received from the shared channel. - /// - /// The message to process. - /// A token to observe for cancellation requests. - /// A task representing the asynchronous operation. - private async Task HandleMessageReceivedAsync(IJsonRpcMessage message, CancellationToken cancellationToken) - { - string id = "(no id)"; - if (message is IJsonRpcMessageWithId messageWithId) - { - id = messageWithId.Id.ToString(); - } - - try - { - _logger.TransportReceivedMessageParsed(_endpointName, id); + var clientTransport = new InMemoryClientTransport( + loggerFactory, + clientToServerChannel.Writer, // outgoing: writes messages to server + serverToClientChannel.Reader); // incoming: reads messages from server - // Write the message to the base class message channel - await base.WriteMessageAsync(message, cancellationToken); + // Link the transports together + clientTransport.ServerTransport = serverTransport; - _logger.TransportMessageWritten(_endpointName, id); - } - catch (Exception ex) - { - _logger.LogError(ex, "Transport message write failed for {EndpointName} with ID {MessageId}", _endpointName, id); - } + return (clientTransport, serverTransport); } -} +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Transport/InMemoryTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/InMemoryTransportTests.cs index c83242a9..aa9209eb 100644 --- a/tests/ModelContextProtocol.Tests/Transport/InMemoryTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/InMemoryTransportTests.cs @@ -10,122 +10,107 @@ public class InMemoryTransportTests(ITestOutputHelper testOutputHelper) : Logged private readonly Type[] _toolTypes = [typeof(TestTool)]; [Fact] - public async Task Constructor_Should_Initialize_With_Valid_Parameters() + public async Task TransportPair_Should_Create_Valid_Transports() { - // Act - await using var transport = InMemoryTransport.Create(LoggerFactory, _toolTypes); + // Act - create a transport pair + var (clientTransport, serverTransport) = InMemoryTransport.Create(LoggerFactory); // Assert - Assert.NotNull(transport); + Assert.NotNull(clientTransport); + Assert.NotNull(serverTransport); + Assert.False(clientTransport.IsConnected); + Assert.False(serverTransport.IsConnected); + + // Cleanup + await clientTransport.DisposeAsync(); + await serverTransport.DisposeAsync(); } [Fact] - public void Constructor_Throws_For_Null_Parameters() - { - Assert.Throws("toolTypes", - () => InMemoryTransport.Create(LoggerFactory, Array.Empty())); - - Assert.Throws("toolTypes", - () => InMemoryTransport.Create(LoggerFactory, null!)); - } - - [Fact] - public async Task ConnectAsync_Should_Set_Connected_State() + public async Task ClientConnect_Should_StartServer_And_SetConnected() { // Arrange - var transport = InMemoryTransport.Create(LoggerFactory, _toolTypes); + var (clientTransport, serverTransport) = InMemoryTransport.Create(LoggerFactory); // Act - var clientTransport = (IClientTransport)transport; await clientTransport.ConnectAsync(TestContext.Current.CancellationToken); // Assert - Assert.True(transport.IsConnected); + Assert.True(clientTransport.IsConnected); + Assert.True(serverTransport.IsConnected); - await transport.DisposeAsync(); + // Cleanup + await clientTransport.DisposeAsync(); + await serverTransport.DisposeAsync(); } [Fact] - public async Task StartListeningAsync_Should_Set_Connected_State() - { - // Arrange - await using var transport = InMemoryTransport.Create(LoggerFactory, _toolTypes); - - // Act - var serverTransport = (IServerTransport)transport; - await serverTransport.StartListeningAsync(TestContext.Current.CancellationToken); - - // Assert - Assert.True(transport.IsConnected); - } - - [Theory] - [InlineData("Hello, World!")] - [InlineData("上下文伺服器")] - [InlineData("🔍 🚀 👍")] - public async Task SendMessageAsync_Should_Preserve_Characters(string messageText) + public async Task Message_Should_Flow_From_Client_To_Server() { // Arrange - var transport = InMemoryTransport.Create(LoggerFactory, _toolTypes); - - IServerTransport serverTransport = transport; - await serverTransport.StartListeningAsync(TestContext.Current.CancellationToken); - - // Ensure transport is fully initialized - await Task.Delay(100, TestContext.Current.CancellationToken); + var (clientTransport, serverTransport) = InMemoryTransport.Create(LoggerFactory); + await clientTransport.ConnectAsync(TestContext.Current.CancellationToken); - var chineseMessage = new JsonRpcRequest + var message = new JsonRpcRequest { Method = "test", - Id = RequestId.FromNumber(44), - Params = new Dictionary - { - ["text"] = messageText - } + Id = RequestId.FromNumber(123), + Params = new Dictionary { ["text"] = "Hello, World!" } }; - // Act & Assert - Chinese - await transport.SendMessageAsync(chineseMessage, TestContext.Current.CancellationToken); - - Assert.True(transport.MessageReader.TryRead(out var receivedMessage)); + // Act + await clientTransport.SendMessageAsync(message, TestContext.Current.CancellationToken); + await Task.Delay(1, TestContext.Current.CancellationToken); + // Assert + Assert.True(serverTransport.MessageReader.TryRead(out var receivedMessage)); Assert.NotNull(receivedMessage); Assert.IsType(receivedMessage); - var chineseRequest = (JsonRpcRequest)receivedMessage; - var chineseParams = (Dictionary)chineseRequest.Params!; - Assert.Equal(messageText, (string)chineseParams["text"]); - await transport.DisposeAsync(); - } + var request = (JsonRpcRequest)receivedMessage; + Assert.Equal(123, request.Id.AsNumber); + Assert.Equal("test", request.Method); + var requestParams = (Dictionary)request.Params!; + Assert.Equal("Hello, World!", requestParams["text"]); - [Fact] - public async Task SendMessageAsync_Throws_Exception_If_Not_Connected() - { - // Arrange - await using var transport = InMemoryTransport.Create(LoggerFactory, _toolTypes); - - var message = new JsonRpcRequest { Method = "test" }; - - // Act & Assert - await Assert.ThrowsAsync( - () => transport.SendMessageAsync(message, TestContext.Current.CancellationToken)); + // Cleanup + await clientTransport.DisposeAsync(); + await serverTransport.DisposeAsync(); } [Fact] - public async Task DisposeAsync_Should_Dispose_Resources() + public async Task Message_Should_Flow_From_Server_To_Client() { // Arrange - var transport = InMemoryTransport.Create(LoggerFactory, _toolTypes); - IServerTransport serverTransport = transport; - await serverTransport.StartListeningAsync(TestContext.Current.CancellationToken); + var (clientTransport, serverTransport) = InMemoryTransport.Create(LoggerFactory); + await clientTransport.ConnectAsync(TestContext.Current.CancellationToken); - // Act - await transport.DisposeAsync(); + var message = new JsonRpcResponse + { + Id = RequestId.FromNumber(456), + Result = new Dictionary { ["text"] = "Response from server" } + }; + // Act + await serverTransport.SendMessageAsync(message, TestContext.Current.CancellationToken); + await Task.Delay(1, TestContext.Current.CancellationToken); // Assert - Assert.False(transport.IsConnected); + Assert.True(clientTransport.MessageReader.TryRead(out var receivedMessage)); + Assert.NotNull(receivedMessage); + Assert.IsType(receivedMessage); + + var response = (JsonRpcResponse)receivedMessage; + Assert.Equal(456, response.Id.AsNumber); + + var responseResult = (Dictionary)response.Result!; + Assert.Equal("Response from server", responseResult["text"]); + + // Cleanup + await clientTransport.DisposeAsync(); + await serverTransport.DisposeAsync(); } + [McpServerToolType] private class TestTool { From 3fe434535ddb5f286aab2861b7f842f28d00924f Mon Sep 17 00:00:00 2001 From: Diego Colombo Date: Wed, 26 Mar 2025 04:25:59 +0000 Subject: [PATCH 07/13] normalise endings --- .../Protocol/Types/RequestParams.cs | 32 +++++++++---------- .../Protocol/Types/RequestParamsMetadata.cs | 30 ++++++++--------- 2 files changed, 31 insertions(+), 31 deletions(-) diff --git a/src/ModelContextProtocol/Protocol/Types/RequestParams.cs b/src/ModelContextProtocol/Protocol/Types/RequestParams.cs index 73dbc2a5..1185bd08 100644 --- a/src/ModelContextProtocol/Protocol/Types/RequestParams.cs +++ b/src/ModelContextProtocol/Protocol/Types/RequestParams.cs @@ -1,16 +1,16 @@ -using System.Text.Json.Serialization; - -namespace ModelContextProtocol.Protocol.Types; - -/// -/// Base class for all request parameters. -/// See the schema for details -/// -public abstract class RequestParams -{ - /// - /// Metadata related to the tool invocation. - /// - [JsonPropertyName("_meta")] - public RequestParamsMetadata? Meta { get; init; } -} +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Protocol.Types; + +/// +/// Base class for all request parameters. +/// See the schema for details +/// +public abstract class RequestParams +{ + /// + /// Metadata related to the tool invocation. + /// + [JsonPropertyName("_meta")] + public RequestParamsMetadata? Meta { get; init; } +} diff --git a/src/ModelContextProtocol/Protocol/Types/RequestParamsMetadata.cs b/src/ModelContextProtocol/Protocol/Types/RequestParamsMetadata.cs index 919fb8ff..151064e4 100644 --- a/src/ModelContextProtocol/Protocol/Types/RequestParamsMetadata.cs +++ b/src/ModelContextProtocol/Protocol/Types/RequestParamsMetadata.cs @@ -1,15 +1,15 @@ -using System.Text.Json.Serialization; - -namespace ModelContextProtocol.Protocol.Types; - -/// -/// Metadata related to the request. -/// -public class RequestParamsMetadata -{ - /// - /// If specified, the caller is requesting out-of-band progress notifications for this request (as represented by notifications/progress). The value of this parameter is an opaque token that will be attached to any subsequent notifications. The receiver is not obligated to provide these notifications. - /// - [JsonPropertyName("progressToken")] - public object ProgressToken { get; set; } = default!; -} +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Protocol.Types; + +/// +/// Metadata related to the request. +/// +public class RequestParamsMetadata +{ + /// + /// If specified, the caller is requesting out-of-band progress notifications for this request (as represented by notifications/progress). The value of this parameter is an opaque token that will be attached to any subsequent notifications. The receiver is not obligated to provide these notifications. + /// + [JsonPropertyName("progressToken")] + public object ProgressToken { get; set; } = default!; +} From e078d36d0580365c5229ba3e9139cd47a8d78695 Mon Sep 17 00:00:00 2001 From: Diego Colombo Date: Wed, 26 Mar 2025 04:31:54 +0000 Subject: [PATCH 08/13] test transports --- .../Transport/InMemoryTransportTests.cs | 60 ++++++++----------- 1 file changed, 26 insertions(+), 34 deletions(-) diff --git a/tests/ModelContextProtocol.Tests/Transport/InMemoryTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/InMemoryTransportTests.cs index aa9209eb..ec1cd752 100644 --- a/tests/ModelContextProtocol.Tests/Transport/InMemoryTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/InMemoryTransportTests.cs @@ -1,53 +1,53 @@ using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; -using ModelContextProtocol.Server; using ModelContextProtocol.Tests.Utils; namespace ModelContextProtocol.Tests.Transport; public class InMemoryTransportTests(ITestOutputHelper testOutputHelper) : LoggedTest(testOutputHelper) { - private readonly Type[] _toolTypes = [typeof(TestTool)]; [Fact] - public async Task TransportPair_Should_Create_Valid_Transports() + public async Task SendMessageAsync_Throws_Exception_If_Not_Connected() { - // Act - create a transport pair var (clientTransport, serverTransport) = InMemoryTransport.Create(LoggerFactory); - // Assert - Assert.NotNull(clientTransport); - Assert.NotNull(serverTransport); - Assert.False(clientTransport.IsConnected); - Assert.False(serverTransport.IsConnected); + var message = new JsonRpcRequest { Method = "test" }; - // Cleanup - await clientTransport.DisposeAsync(); - await serverTransport.DisposeAsync(); + await Assert.ThrowsAsync(() => serverTransport.SendMessageAsync(message, TestContext.Current.CancellationToken)); + await Assert.ThrowsAsync(() => clientTransport.SendMessageAsync(message, TestContext.Current.CancellationToken)); } [Fact] - public async Task ClientConnect_Should_StartServer_And_SetConnected() + public async Task DisposeAsync_Should_Dispose_Resources() { - // Arrange var (clientTransport, serverTransport) = InMemoryTransport.Create(LoggerFactory); - // Act - await clientTransport.ConnectAsync(TestContext.Current.CancellationToken); + await serverTransport.DisposeAsync(); + await clientTransport.DisposeAsync(); - // Assert - Assert.True(clientTransport.IsConnected); - Assert.True(serverTransport.IsConnected); + Assert.False(serverTransport.IsConnected); + Assert.False(clientTransport.IsConnected); + } + + [Fact] + public async Task TransportPair_Should_Create_Valid_Transports() + { + var (clientTransport, serverTransport) = InMemoryTransport.Create(LoggerFactory); + + Assert.NotNull(clientTransport); + Assert.NotNull(serverTransport); + Assert.False(clientTransport.IsConnected); + Assert.False(serverTransport.IsConnected); - // Cleanup await clientTransport.DisposeAsync(); await serverTransport.DisposeAsync(); } + [Fact] public async Task Message_Should_Flow_From_Client_To_Server() { - // Arrange var (clientTransport, serverTransport) = InMemoryTransport.Create(LoggerFactory); await clientTransport.ConnectAsync(TestContext.Current.CancellationToken); @@ -58,10 +58,11 @@ public async Task Message_Should_Flow_From_Client_To_Server() Params = new Dictionary { ["text"] = "Hello, World!" } }; - // Act + await clientTransport.SendMessageAsync(message, TestContext.Current.CancellationToken); await Task.Delay(1, TestContext.Current.CancellationToken); - // Assert + + Assert.True(serverTransport.MessageReader.TryRead(out var receivedMessage)); Assert.NotNull(receivedMessage); Assert.IsType(receivedMessage); @@ -81,7 +82,6 @@ public async Task Message_Should_Flow_From_Client_To_Server() [Fact] public async Task Message_Should_Flow_From_Server_To_Client() { - // Arrange var (clientTransport, serverTransport) = InMemoryTransport.Create(LoggerFactory); await clientTransport.ConnectAsync(TestContext.Current.CancellationToken); @@ -91,10 +91,10 @@ public async Task Message_Should_Flow_From_Server_To_Client() Result = new Dictionary { ["text"] = "Response from server" } }; - // Act + await serverTransport.SendMessageAsync(message, TestContext.Current.CancellationToken); await Task.Delay(1, TestContext.Current.CancellationToken); - // Assert + Assert.True(clientTransport.MessageReader.TryRead(out var receivedMessage)); Assert.NotNull(receivedMessage); Assert.IsType(receivedMessage); @@ -109,12 +109,4 @@ public async Task Message_Should_Flow_From_Server_To_Client() await clientTransport.DisposeAsync(); await serverTransport.DisposeAsync(); } - - - [McpServerToolType] - private class TestTool - { - [McpServerTool] - public string Echo(string message) => message; - } } From 3aae15c9f8484224a1fd5bb135a3ef62150d307c Mon Sep 17 00:00:00 2001 From: Diego Colombo Date: Wed, 26 Mar 2025 05:00:49 +0000 Subject: [PATCH 09/13] use names from mcpserveroption --- .../McpServerBuilderExtensions.Transports.cs | 16 ++++- .../Transport/InMemoryClientTransport.cs | 52 +++++++------- .../Transport/InMemoryServerTransport.cs | 55 +++++++------- .../Protocol/Transport/InMemoryTransport.cs | 71 +++++++++++++++++-- .../Transport/InMemoryTransportTests.cs | 22 ++++-- .../Transport/StdioServerTransportTests.cs | 55 +++++++------- 6 files changed, 182 insertions(+), 89 deletions(-) diff --git a/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Transports.cs b/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Transports.cs index 763ab87d..06953962 100644 --- a/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Transports.cs +++ b/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Transports.cs @@ -19,9 +19,19 @@ public static partial class McpServerBuilderExtensions public static IMcpServerBuilder WithInMemoryServerTransport(this IMcpServerBuilder builder) { Throw.IfNull(builder); - var (clientTransport, serverTransport) = InMemoryTransport.Create(); - builder.Services.AddSingleton(serverTransport); - builder.Services.AddSingleton(clientTransport); + builder.Services.AddSingleton(); + builder.Services.AddSingleton(s => + { + var transport = s.GetRequiredService(); + return transport.ClientTransport; + }); + + builder.Services.AddSingleton(s => + { + var transport = s.GetRequiredService(); + return transport.ServerTransport; + }); + builder.Services.AddHostedService(); return builder; } diff --git a/src/ModelContextProtocol/Protocol/Transport/InMemoryClientTransport.cs b/src/ModelContextProtocol/Protocol/Transport/InMemoryClientTransport.cs index 3c3a78ca..754309be 100644 --- a/src/ModelContextProtocol/Protocol/Transport/InMemoryClientTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/InMemoryClientTransport.cs @@ -13,16 +13,17 @@ namespace ModelContextProtocol.Protocol.Transport; /// public sealed class InMemoryClientTransport : TransportBase, IClientTransport { - private readonly string _endpointName = "InMemoryClientTransport"; + private string EndpointName => $"Client (in memory) for ({_serverName})"; private readonly ILogger _logger; + private readonly string _serverName; private readonly ChannelWriter _outgoingChannel; private readonly ChannelReader _incomingChannel; private CancellationTokenSource? _cancellationTokenSource; private Task? _readTask; - private SemaphoreSlim _connectLock = new SemaphoreSlim(1, 1); + private readonly SemaphoreSlim _connectLock = new SemaphoreSlim(1, 1); private volatile bool _disposed; - /// + /// /// Gets or sets the server transport this client connects to. /// internal InMemoryServerTransport? ServerTransport { get; set; } @@ -30,10 +31,12 @@ public sealed class InMemoryClientTransport : TransportBase, IClientTransport /// /// Initializes a new instance of the class. /// + /// The name of the server. /// Optional logger factory for logging transport operations. /// Channel for sending messages to the server. /// Channel for receiving messages from the server. internal InMemoryClientTransport( + string serverName, ILoggerFactory? loggerFactory, ChannelWriter outgoingChannel, ChannelReader incomingChannel) @@ -41,6 +44,7 @@ internal InMemoryClientTransport( { _logger = loggerFactory?.CreateLogger() ?? NullLogger.Instance; + _serverName = serverName; _outgoingChannel = outgoingChannel; _incomingChannel = incomingChannel; } @@ -57,11 +61,11 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default) if (IsConnected) { - _logger.TransportAlreadyConnected(_endpointName); + _logger.TransportAlreadyConnected(EndpointName); throw new McpTransportException("Transport is already connected"); } - _logger.TransportConnecting(_endpointName); + _logger.TransportConnecting(EndpointName); try { @@ -78,7 +82,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default) } catch (Exception ex) { - _logger.TransportConnectFailed(_endpointName, ex); + _logger.TransportConnectFailed(EndpointName, ex); await CleanupAsync(cancellationToken).ConfigureAwait(false); throw new McpTransportException("Failed to connect transport", ex); } @@ -96,7 +100,7 @@ public override async Task SendMessageAsync(IJsonRpcMessage message, Cancellatio if (!IsConnected) { - _logger.TransportNotConnected(_endpointName); + _logger.TransportNotConnected(EndpointName); throw new McpTransportException("Transport is not connected"); } @@ -108,13 +112,13 @@ public override async Task SendMessageAsync(IJsonRpcMessage message, Cancellatio try { - _logger.TransportSendingMessage(_endpointName, id); + _logger.TransportSendingMessage(EndpointName, id); await _outgoingChannel.WriteAsync(message, cancellationToken).ConfigureAwait(false); - _logger.TransportSentMessage(_endpointName, id); + _logger.TransportSentMessage(EndpointName, id); } catch (Exception ex) { - _logger.TransportSendFailed(_endpointName, id, ex); + _logger.TransportSendFailed(EndpointName, id, ex); throw new McpTransportException("Failed to send message", ex); } } @@ -130,7 +134,7 @@ private async Task ReadMessagesAsync(CancellationToken cancellationToken) { try { - _logger.TransportEnteringReadMessagesLoop(_endpointName); + _logger.TransportEnteringReadMessagesLoop(EndpointName); await foreach (var message in _incomingChannel.ReadAllAsync(cancellationToken)) { @@ -140,24 +144,24 @@ private async Task ReadMessagesAsync(CancellationToken cancellationToken) id = messageWithId.Id.ToString(); } - _logger.TransportReceivedMessageParsed(_endpointName, id); - + _logger.TransportReceivedMessageParsed(EndpointName, id); + // Write to the base class's message channel that's exposed via MessageReader await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false); - - _logger.TransportMessageWritten(_endpointName, id); + + _logger.TransportMessageWritten(EndpointName, id); } - _logger.TransportExitingReadMessagesLoop(_endpointName); + _logger.TransportExitingReadMessagesLoop(EndpointName); } catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) { - _logger.TransportReadMessagesCancelled(_endpointName); + _logger.TransportReadMessagesCancelled(EndpointName); // Normal shutdown } catch (Exception ex) { - _logger.TransportReadMessagesFailed(_endpointName, ex); + _logger.TransportReadMessagesFailed(EndpointName, ex); } } @@ -169,7 +173,7 @@ private async Task CleanupAsync(CancellationToken cancellationToken) } _disposed = true; - _logger.TransportCleaningUp(_endpointName); + _logger.TransportCleaningUp(EndpointName); try { @@ -184,20 +188,20 @@ private async Task CleanupAsync(CancellationToken cancellationToken) { try { - _logger.TransportWaitingForReadTask(_endpointName); + _logger.TransportWaitingForReadTask(EndpointName); await _readTask.WaitAsync(TimeSpan.FromSeconds(1), cancellationToken).ConfigureAwait(false); } catch (TimeoutException) { - _logger.TransportCleanupReadTaskTimeout(_endpointName); + _logger.TransportCleanupReadTaskTimeout(EndpointName); } catch (OperationCanceledException) { - _logger.TransportCleanupReadTaskCancelled(_endpointName); + _logger.TransportCleanupReadTaskCancelled(EndpointName); } catch (Exception ex) { - _logger.TransportCleanupReadTaskFailed(_endpointName, ex); + _logger.TransportCleanupReadTaskFailed(EndpointName, ex); } finally { @@ -210,7 +214,7 @@ private async Task CleanupAsync(CancellationToken cancellationToken) finally { SetConnected(false); - _logger.TransportCleanedUp(_endpointName); + _logger.TransportCleanedUp(EndpointName); } } diff --git a/src/ModelContextProtocol/Protocol/Transport/InMemoryServerTransport.cs b/src/ModelContextProtocol/Protocol/Transport/InMemoryServerTransport.cs index 627d1dcd..9c82e2b3 100644 --- a/src/ModelContextProtocol/Protocol/Transport/InMemoryServerTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/InMemoryServerTransport.cs @@ -1,10 +1,9 @@ -using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; + using ModelContextProtocol.Logging; using ModelContextProtocol.Protocol.Messages; -using ModelContextProtocol.Server; -using System.Diagnostics.CodeAnalysis; + using System.Threading.Channels; namespace ModelContextProtocol.Protocol.Transport; @@ -14,7 +13,7 @@ namespace ModelContextProtocol.Protocol.Transport; /// public sealed class InMemoryServerTransport : TransportBase, IServerTransport { - private readonly string _endpointName = "InMemoryServerTransport"; + private string EndpointName => $"Server (in memory) for ({_serverName})"; private readonly ILogger _logger; private readonly ChannelReader _incomingChannel; private readonly ChannelWriter _outgoingChannel; @@ -22,23 +21,27 @@ public sealed class InMemoryServerTransport : TransportBase, IServerTransport private Task? _readTask; private SemaphoreSlim _startLock = new SemaphoreSlim(1, 1); private volatile bool _disposed; + private readonly string _serverName; /// /// Initializes a new instance of the class. /// + /// The name of the server. /// Optional logger factory for logging transport operations. /// Channel for receiving messages from the client. /// Channel for sending messages to the client. internal InMemoryServerTransport( + string serverName, ILoggerFactory? loggerFactory, ChannelReader incomingChannel, ChannelWriter outgoingChannel) : base(loggerFactory) { - _logger = loggerFactory?.CreateLogger() + _logger = loggerFactory?.CreateLogger() ?? NullLogger.Instance; _incomingChannel = incomingChannel; _outgoingChannel = outgoingChannel; + _serverName = serverName; } /// @@ -51,11 +54,11 @@ public async Task StartListeningAsync(CancellationToken cancellationToken = defa if (IsConnected) { - _logger.TransportAlreadyConnected(_endpointName); + _logger.TransportAlreadyConnected(EndpointName); throw new McpTransportException("Transport is already connected"); } - _logger.TransportConnecting(_endpointName); + _logger.TransportConnecting(EndpointName); try { @@ -66,7 +69,7 @@ public async Task StartListeningAsync(CancellationToken cancellationToken = defa } catch (Exception ex) { - _logger.TransportConnectFailed(_endpointName, ex); + _logger.TransportConnectFailed(EndpointName, ex); await CleanupAsync(cancellationToken).ConfigureAwait(false); throw new McpTransportException("Failed to connect transport", ex); } @@ -84,7 +87,7 @@ public override async Task SendMessageAsync(IJsonRpcMessage message, Cancellatio if (!IsConnected) { - _logger.TransportNotConnected(_endpointName); + _logger.TransportNotConnected(EndpointName); throw new McpTransportException("Transport is not connected"); } @@ -96,13 +99,13 @@ public override async Task SendMessageAsync(IJsonRpcMessage message, Cancellatio try { - _logger.TransportSendingMessage(_endpointName, id); + _logger.TransportSendingMessage(EndpointName, id); await _outgoingChannel.WriteAsync(message, cancellationToken).ConfigureAwait(false); - _logger.TransportSentMessage(_endpointName, id); + _logger.TransportSentMessage(EndpointName, id); } catch (Exception ex) { - _logger.TransportSendFailed(_endpointName, id, ex); + _logger.TransportSendFailed(EndpointName, id, ex); throw new McpTransportException("Failed to send message", ex); } } @@ -118,7 +121,7 @@ private async Task ReadMessagesAsync(CancellationToken cancellationToken) { try { - _logger.TransportEnteringReadMessagesLoop(_endpointName); + _logger.TransportEnteringReadMessagesLoop(EndpointName); await foreach (var message in _incomingChannel.ReadAllAsync(cancellationToken)) { @@ -128,24 +131,24 @@ private async Task ReadMessagesAsync(CancellationToken cancellationToken) id = messageWithId.Id.ToString(); } - _logger.TransportReceivedMessageParsed(_endpointName, id); - + _logger.TransportReceivedMessageParsed(EndpointName, id); + // Write to the base class's message channel that's exposed via MessageReader await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false); - - _logger.TransportMessageWritten(_endpointName, id); + + _logger.TransportMessageWritten(EndpointName, id); } - _logger.TransportExitingReadMessagesLoop(_endpointName); + _logger.TransportExitingReadMessagesLoop(EndpointName); } catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) { - _logger.TransportReadMessagesCancelled(_endpointName); + _logger.TransportReadMessagesCancelled(EndpointName); // Normal shutdown } catch (Exception ex) { - _logger.TransportReadMessagesFailed(_endpointName, ex); + _logger.TransportReadMessagesFailed(EndpointName, ex); } } @@ -157,7 +160,7 @@ private async Task CleanupAsync(CancellationToken cancellationToken) } _disposed = true; - _logger.TransportCleaningUp(_endpointName); + _logger.TransportCleaningUp(EndpointName); try { @@ -172,20 +175,20 @@ private async Task CleanupAsync(CancellationToken cancellationToken) { try { - _logger.TransportWaitingForReadTask(_endpointName); + _logger.TransportWaitingForReadTask(EndpointName); await _readTask.WaitAsync(TimeSpan.FromSeconds(1), cancellationToken).ConfigureAwait(false); } catch (TimeoutException) { - _logger.TransportCleanupReadTaskTimeout(_endpointName); + _logger.TransportCleanupReadTaskTimeout(EndpointName); } catch (OperationCanceledException) { - _logger.TransportCleanupReadTaskCancelled(_endpointName); + _logger.TransportCleanupReadTaskCancelled(EndpointName); } catch (Exception ex) { - _logger.TransportCleanupReadTaskFailed(_endpointName, ex); + _logger.TransportCleanupReadTaskFailed(EndpointName, ex); } finally { @@ -198,7 +201,7 @@ private async Task CleanupAsync(CancellationToken cancellationToken) finally { SetConnected(false); - _logger.TransportCleanedUp(_endpointName); + _logger.TransportCleanedUp(EndpointName); } } diff --git a/src/ModelContextProtocol/Protocol/Transport/InMemoryTransport.cs b/src/ModelContextProtocol/Protocol/Transport/InMemoryTransport.cs index d6ff0578..3119b868 100644 --- a/src/ModelContextProtocol/Protocol/Transport/InMemoryTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/InMemoryTransport.cs @@ -1,6 +1,9 @@ using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Server; +using ModelContextProtocol.Utils; using System.Threading.Channels; @@ -12,11 +15,61 @@ namespace ModelContextProtocol.Protocol.Transport; public sealed class InMemoryTransport { /// - /// Creates a new pair of in-memory transports for client and server communication. + /// Initializes a new instance of the class. /// - /// Optional logger factory for logging transport operations. - /// A tuple containing client and server transports that communicate with each other. - public static (InMemoryClientTransport ClientTransport, InMemoryServerTransport ServerTransport) Create( + /// The server options. + /// Optional logger factory used for logging employed by the transport. + /// is or contains a null name. + + public InMemoryTransport(McpServerOptions serverOptions, ILoggerFactory? loggerFactory = null) + : this(GetServerName(serverOptions), loggerFactory) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The server options. + /// Optional logger factory used for logging employed by the transport. + /// is or contains a null name. + + public InMemoryTransport(IOptions serverOptions, ILoggerFactory? loggerFactory = null) + : this(GetServerName(serverOptions.Value), loggerFactory) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The name of the server. + /// Optional logger factory used for logging employed by the transport. + /// is . + /// + /// + /// By default, no logging is performed. If a is supplied, it must not log + /// to , as that will interfere with the transport's output. + /// + /// + public InMemoryTransport(string serverName, ILoggerFactory? loggerFactory = null) + { + var (clientTransport, serverTransport) = Create(serverName, loggerFactory); + ServerTransport = serverTransport; + ClientTransport = clientTransport; + } + + /// + /// Gets the client transport. + /// + public IClientTransport ClientTransport { get; } + + /// + /// Gets the server transport. + /// + public IServerTransport ServerTransport { get; } + + + private static (InMemoryClientTransport ClientTransport, InMemoryServerTransport ServerTransport) Create( + string serverName, ILoggerFactory? loggerFactory = null) { // Configure client-to-server channel - this will be used for: @@ -41,11 +94,13 @@ public static (InMemoryClientTransport ClientTransport, InMemoryServerTransport // Create the client and server transports - they directly expose the channels through MessageReader var serverTransport = new InMemoryServerTransport( + serverName, loggerFactory, clientToServerChannel.Reader, // incoming: reads messages from client serverToClientChannel.Writer); // outgoing: writes messages to client var clientTransport = new InMemoryClientTransport( + serverName, loggerFactory, clientToServerChannel.Writer, // outgoing: writes messages to server serverToClientChannel.Reader); // incoming: reads messages from server @@ -55,4 +110,12 @@ public static (InMemoryClientTransport ClientTransport, InMemoryServerTransport return (clientTransport, serverTransport); } + + private static string GetServerName(McpServerOptions serverOptions) + { + Throw.IfNull(serverOptions); + Throw.IfNull(serverOptions.ServerInfo); + + return serverOptions.ServerInfo.Name; + } } \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Transport/InMemoryTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/InMemoryTransportTests.cs index ec1cd752..aa0de3ca 100644 --- a/tests/ModelContextProtocol.Tests/Transport/InMemoryTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/InMemoryTransportTests.cs @@ -10,7 +10,9 @@ public class InMemoryTransportTests(ITestOutputHelper testOutputHelper) : Logged [Fact] public async Task SendMessageAsync_Throws_Exception_If_Not_Connected() { - var (clientTransport, serverTransport) = InMemoryTransport.Create(LoggerFactory); + var transport = new InMemoryTransport("test", LoggerFactory); + var serverTransport = transport.ServerTransport; + var clientTransport = transport.ClientTransport; var message = new JsonRpcRequest { Method = "test" }; @@ -21,7 +23,9 @@ public async Task SendMessageAsync_Throws_Exception_If_Not_Connected() [Fact] public async Task DisposeAsync_Should_Dispose_Resources() { - var (clientTransport, serverTransport) = InMemoryTransport.Create(LoggerFactory); + var transport = new InMemoryTransport("test", LoggerFactory); + var serverTransport = transport.ServerTransport; + var clientTransport = transport.ClientTransport; await serverTransport.DisposeAsync(); await clientTransport.DisposeAsync(); @@ -33,7 +37,9 @@ public async Task DisposeAsync_Should_Dispose_Resources() [Fact] public async Task TransportPair_Should_Create_Valid_Transports() { - var (clientTransport, serverTransport) = InMemoryTransport.Create(LoggerFactory); + var transport = new InMemoryTransport("test", LoggerFactory); + var serverTransport = transport.ServerTransport; + var clientTransport = transport.ClientTransport; Assert.NotNull(clientTransport); Assert.NotNull(serverTransport); @@ -48,7 +54,10 @@ public async Task TransportPair_Should_Create_Valid_Transports() [Fact] public async Task Message_Should_Flow_From_Client_To_Server() { - var (clientTransport, serverTransport) = InMemoryTransport.Create(LoggerFactory); + var transport = new InMemoryTransport("test", LoggerFactory); + var serverTransport = transport.ServerTransport; + var clientTransport = transport.ClientTransport; + await clientTransport.ConnectAsync(TestContext.Current.CancellationToken); var message = new JsonRpcRequest @@ -82,7 +91,10 @@ public async Task Message_Should_Flow_From_Client_To_Server() [Fact] public async Task Message_Should_Flow_From_Server_To_Client() { - var (clientTransport, serverTransport) = InMemoryTransport.Create(LoggerFactory); + var transport = new InMemoryTransport("test", LoggerFactory); + var serverTransport = transport.ServerTransport; + var clientTransport = transport.ClientTransport; + await clientTransport.ConnectAsync(TestContext.Current.CancellationToken); var message = new JsonRpcResponse diff --git a/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs index 94ae6272..4c4941ed 100644 --- a/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs @@ -4,6 +4,7 @@ using ModelContextProtocol.Server; using ModelContextProtocol.Tests.Utils; using ModelContextProtocol.Utils.Json; + using System.IO.Pipelines; using System.Text; using System.Text.Json; @@ -70,12 +71,12 @@ public async Task SendMessageAsync_Should_Send_Message() new Pipe().Reader.AsStream(), output, LoggerFactory); - + await transport.StartListeningAsync(TestContext.Current.CancellationToken); - + // Ensure transport is fully initialized await Task.Delay(100, TestContext.Current.CancellationToken); - + // Verify transport is connected Assert.True(transport.IsConnected, "Transport should be connected after StartListeningAsync"); @@ -124,12 +125,12 @@ public async Task ReadMessagesAsync_Should_Read_Messages() input, Stream.Null, LoggerFactory); - + await transport.StartListeningAsync(TestContext.Current.CancellationToken); - + // Ensure transport is fully initialized await Task.Delay(100, TestContext.Current.CancellationToken); - + // Verify transport is connected Assert.True(transport.IsConnected, "Transport should be connected after StartListeningAsync"); @@ -163,24 +164,24 @@ public async Task SendMessageAsync_Should_Preserve_Unicode_Characters() using var output = new MemoryStream(); await using var transport = new StdioServerTransport( - _serverOptions.ServerInfo.Name, + _serverOptions.ServerInfo.Name, new Pipe().Reader.AsStream(), - output, + output, LoggerFactory); - + await transport.StartListeningAsync(TestContext.Current.CancellationToken); - + // Ensure transport is fully initialized await Task.Delay(100, TestContext.Current.CancellationToken); - + // Verify transport is connected Assert.True(transport.IsConnected, "Transport should be connected after StartListeningAsync"); // Test 1: Chinese characters (BMP Unicode) var chineseText = "上下文伺服器"; // "Context Server" in Chinese - var chineseMessage = new JsonRpcRequest - { - Method = "test", + var chineseMessage = new JsonRpcRequest + { + Method = "test", Id = RequestId.FromNumber(44), Params = new Dictionary { @@ -191,18 +192,18 @@ public async Task SendMessageAsync_Should_Preserve_Unicode_Characters() // Clear output and send message output.SetLength(0); await transport.SendMessageAsync(chineseMessage, TestContext.Current.CancellationToken); - + // Verify Chinese characters preserved but encoded var chineseResult = Encoding.UTF8.GetString(output.ToArray()).Trim(); var expectedChinese = JsonSerializer.Serialize(chineseMessage, McpJsonUtilities.DefaultOptions); Assert.Equal(expectedChinese, chineseResult); Assert.Contains(JsonSerializer.Serialize(chineseText), chineseResult); - + // Test 2: Emoji (non-BMP Unicode using surrogate pairs) var emojiText = "🔍 🚀 👍"; // Magnifying glass, rocket, thumbs up - var emojiMessage = new JsonRpcRequest - { - Method = "test", + var emojiMessage = new JsonRpcRequest + { + Method = "test", Id = RequestId.FromNumber(45), Params = new Dictionary { @@ -213,23 +214,23 @@ public async Task SendMessageAsync_Should_Preserve_Unicode_Characters() // Clear output and send message output.SetLength(0); await transport.SendMessageAsync(emojiMessage, TestContext.Current.CancellationToken); - + // Verify emoji preserved - might be as either direct characters or escape sequences var emojiResult = Encoding.UTF8.GetString(output.ToArray()).Trim(); var expectedEmoji = JsonSerializer.Serialize(emojiMessage, McpJsonUtilities.DefaultOptions); Assert.Equal(expectedEmoji, emojiResult); - + // Verify surrogate pairs in different possible formats // Magnifying glass emoji: 🔍 (U+1F50D) - bool magnifyingGlassFound = - emojiResult.Contains("🔍") || + bool magnifyingGlassFound = + emojiResult.Contains("🔍") || emojiResult.Contains("\\ud83d\\udd0d", StringComparison.OrdinalIgnoreCase); - + // Rocket emoji: 🚀 (U+1F680) - bool rocketFound = - emojiResult.Contains("🚀") || + bool rocketFound = + emojiResult.Contains("🚀") || emojiResult.Contains("\\ud83d\\ude80", StringComparison.OrdinalIgnoreCase); - + Assert.True(magnifyingGlassFound, "Magnifying glass emoji not found in result"); Assert.True(rocketFound, "Rocket emoji not found in result"); } From a4d31e593d7664ded918da44703cccd9b13a5e77 Mon Sep 17 00:00:00 2001 From: Diego Colombo Date: Wed, 26 Mar 2025 05:26:41 +0000 Subject: [PATCH 10/13] work in progress --- .../Protocol/Transport/TransportTypes.cs | 5 ++ .../Server/McpServerExtensions.cs | 49 ++++++++++++++----- .../Transport/InMemoryTransportTests.cs | 31 ++++++++++++ 3 files changed, 74 insertions(+), 11 deletions(-) diff --git a/src/ModelContextProtocol/Protocol/Transport/TransportTypes.cs b/src/ModelContextProtocol/Protocol/Transport/TransportTypes.cs index 02de0a6b..8a8c9f56 100644 --- a/src/ModelContextProtocol/Protocol/Transport/TransportTypes.cs +++ b/src/ModelContextProtocol/Protocol/Transport/TransportTypes.cs @@ -14,4 +14,9 @@ public static class TransportTypes /// The name of the ServerSideEvents transport. /// public const string Sse = "sse"; + + /// + /// The name of the InMemory transport. + /// + public const string InMemory = "inmemory"; } diff --git a/src/ModelContextProtocol/Server/McpServerExtensions.cs b/src/ModelContextProtocol/Server/McpServerExtensions.cs index 73bd528d..b4b7638e 100644 --- a/src/ModelContextProtocol/Server/McpServerExtensions.cs +++ b/src/ModelContextProtocol/Server/McpServerExtensions.cs @@ -1,7 +1,13 @@ -using ModelContextProtocol.Protocol.Messages; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; + +using ModelContextProtocol.Client; +using ModelContextProtocol.Configuration; +using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Utils; -using Microsoft.Extensions.AI; + using System.Runtime.CompilerServices; using System.Text; @@ -10,6 +16,27 @@ namespace ModelContextProtocol.Server; /// public static class McpServerExtensions { + /// + /// Gets an in-memory client for the server. + /// + /// + /// + /// + /// + public static async Task GetInMemoryClientAsync(this IMcpServer server, CancellationToken cancellationToke = default) + { + var client = await McpClientFactory.CreateAsync( + new McpServerConfig + { + Id = server.ServerOptions.ServerInfo.Name, + Name = server.ServerOptions.ServerInfo.Name, + TransportType = TransportTypes.InMemory, + }, + createTransportFunc: (_, _) => server.Services?.GetRequiredService() ?? throw new InvalidOperationException(), + cancellationToken: cancellationToke); + return client; + } + /// /// Requests to sample an LLM via the client. /// @@ -42,7 +69,7 @@ public static Task RequestSamplingAsync( /// is . /// The client does not support sampling. public static async Task RequestSamplingAsync( - this IMcpServer server, + this IMcpServer server, IEnumerable messages, ChatOptions? options = default, CancellationToken cancellationToken = default) { Throw.IfNull(server); @@ -112,14 +139,14 @@ public static async Task RequestSamplingAsync( } var result = await server.RequestSamplingAsync(new() - { - Messages = samplingMessages, - MaxTokens = options?.MaxOutputTokens, - StopSequences = options?.StopSequences?.ToArray(), - SystemPrompt = systemPrompt?.ToString(), - Temperature = options?.Temperature, - ModelPreferences = modelPreferences, - }, cancellationToken).ConfigureAwait(false); + { + Messages = samplingMessages, + MaxTokens = options?.MaxOutputTokens, + StopSequences = options?.StopSequences?.ToArray(), + SystemPrompt = systemPrompt?.ToString(), + Temperature = options?.Temperature, + ModelPreferences = modelPreferences, + }, cancellationToken).ConfigureAwait(false); return new(new ChatMessage(new(result.Role), [result.Content.ToAIContent()])) { diff --git a/tests/ModelContextProtocol.Tests/Transport/InMemoryTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/InMemoryTransportTests.cs index aa0de3ca..92c7f230 100644 --- a/tests/ModelContextProtocol.Tests/Transport/InMemoryTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/InMemoryTransportTests.cs @@ -1,7 +1,14 @@ +using Microsoft.Extensions.DependencyInjection; + +using ModelContextProtocol.Client; using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; +using ModelContextProtocol.Server; +using ModelContextProtocol.Tests.Configuration; using ModelContextProtocol.Tests.Utils; +using System.Text.Json; + namespace ModelContextProtocol.Tests.Transport; public class InMemoryTransportTests(ITestOutputHelper testOutputHelper) : LoggedTest(testOutputHelper) @@ -121,4 +128,28 @@ public async Task Message_Should_Flow_From_Server_To_Client() await clientTransport.DisposeAsync(); await serverTransport.DisposeAsync(); } + + [Fact] + public async Task Can_List_Registered_Tools() + { + ServiceCollection sc = new(); + var builder = sc.AddMcpServer().WithTools().WithInMemoryServerTransport(); + var server = sc.BuildServiceProvider().GetRequiredService(); + IMcpClient client = await server.GetInMemoryClientAsync(TestContext.Current.CancellationToken); + + var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); + Assert.Equal(10, tools.Count); + + McpClientTool echoTool = tools.First(t => t.Name == "Echo"); + Assert.Equal("Echo", echoTool.Name); + Assert.Equal("Echoes the input back to the client.", echoTool.Description); + Assert.Equal("object", echoTool.JsonSchema.GetProperty("type").GetString()); + Assert.Equal(JsonValueKind.Object, echoTool.JsonSchema.GetProperty("properties").GetProperty("message").ValueKind); + Assert.Equal("the echoes message", echoTool.JsonSchema.GetProperty("properties").GetProperty("message").GetProperty("description").GetString()); + Assert.Equal(1, echoTool.JsonSchema.GetProperty("required").GetArrayLength()); + + McpClientTool doubleEchoTool = tools.First(t => t.Name == "double_echo"); + Assert.Equal("double_echo", doubleEchoTool.Name); + Assert.Equal("Echoes the input back to the client.", doubleEchoTool.Description); + } } From 46a2bd5f320e1a1f317834bc599d4d5896523fe0 Mon Sep 17 00:00:00 2001 From: Diego Colombo Date: Wed, 26 Mar 2025 09:20:42 +0000 Subject: [PATCH 11/13] wip --- .../Protocol/Transport/InMemoryClientTransport.cs | 6 +++--- .../Protocol/Transport/InMemoryServerTransport.cs | 7 +++---- .../Transport/InMemoryTransportTests.cs | 4 ++-- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/ModelContextProtocol/Protocol/Transport/InMemoryClientTransport.cs b/src/ModelContextProtocol/Protocol/Transport/InMemoryClientTransport.cs index 754309be..4a476afd 100644 --- a/src/ModelContextProtocol/Protocol/Transport/InMemoryClientTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/InMemoryClientTransport.cs @@ -76,7 +76,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default) } _cancellationTokenSource = new CancellationTokenSource(); - _readTask = Task.Run(() => ReadMessagesAsync(_cancellationTokenSource.Token), _cancellationTokenSource.Token); + _readTask = Task.Run(() => ReadMessagesAsync(_cancellationTokenSource.Token).ConfigureAwait(false), CancellationToken.None); SetConnected(true); } @@ -136,9 +136,9 @@ private async Task ReadMessagesAsync(CancellationToken cancellationToken) { _logger.TransportEnteringReadMessagesLoop(EndpointName); - await foreach (var message in _incomingChannel.ReadAllAsync(cancellationToken)) + await foreach (var message in _incomingChannel.ReadAllAsync(cancellationToken).ConfigureAwait(false)) { - string id = "(no id)"; + var id = "(no id)"; if (message is IJsonRpcMessageWithId messageWithId) { id = messageWithId.Id.ToString(); diff --git a/src/ModelContextProtocol/Protocol/Transport/InMemoryServerTransport.cs b/src/ModelContextProtocol/Protocol/Transport/InMemoryServerTransport.cs index 9c82e2b3..be5af42f 100644 --- a/src/ModelContextProtocol/Protocol/Transport/InMemoryServerTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/InMemoryServerTransport.cs @@ -54,8 +54,7 @@ public async Task StartListeningAsync(CancellationToken cancellationToken = defa if (IsConnected) { - _logger.TransportAlreadyConnected(EndpointName); - throw new McpTransportException("Transport is already connected"); + return; } _logger.TransportConnecting(EndpointName); @@ -63,7 +62,7 @@ public async Task StartListeningAsync(CancellationToken cancellationToken = defa try { _cancellationTokenSource = new CancellationTokenSource(); - _readTask = Task.Run(() => ReadMessagesAsync(_cancellationTokenSource.Token), _cancellationTokenSource.Token); + _readTask = Task.Run(() => ReadMessagesAsync(_cancellationTokenSource.Token).ConfigureAwait(false), CancellationToken.None); SetConnected(true); } @@ -123,7 +122,7 @@ private async Task ReadMessagesAsync(CancellationToken cancellationToken) { _logger.TransportEnteringReadMessagesLoop(EndpointName); - await foreach (var message in _incomingChannel.ReadAllAsync(cancellationToken)) + await foreach (var message in _incomingChannel.ReadAllAsync(cancellationToken).ConfigureAwait(false)) { string id = "(no id)"; if (message is IJsonRpcMessageWithId messageWithId) diff --git a/tests/ModelContextProtocol.Tests/Transport/InMemoryTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/InMemoryTransportTests.cs index 92c7f230..0c9a871c 100644 --- a/tests/ModelContextProtocol.Tests/Transport/InMemoryTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/InMemoryTransportTests.cs @@ -76,7 +76,7 @@ public async Task Message_Should_Flow_From_Client_To_Server() await clientTransport.SendMessageAsync(message, TestContext.Current.CancellationToken); - await Task.Delay(1, TestContext.Current.CancellationToken); + await Task.Delay(2, TestContext.Current.CancellationToken); Assert.True(serverTransport.MessageReader.TryRead(out var receivedMessage)); @@ -112,7 +112,7 @@ public async Task Message_Should_Flow_From_Server_To_Client() await serverTransport.SendMessageAsync(message, TestContext.Current.CancellationToken); - await Task.Delay(1, TestContext.Current.CancellationToken); + await Task.Delay(2, TestContext.Current.CancellationToken); Assert.True(clientTransport.MessageReader.TryRead(out var receivedMessage)); Assert.NotNull(receivedMessage); From 5102f757a8cc4fddea61ba6906040e9bb4d56c9f Mon Sep 17 00:00:00 2001 From: Diego Colombo Date: Wed, 26 Mar 2025 12:06:13 +0000 Subject: [PATCH 12/13] start server on test --- .../Transport/InMemoryTransportTests.cs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/ModelContextProtocol.Tests/Transport/InMemoryTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/InMemoryTransportTests.cs index 0c9a871c..084aa7cb 100644 --- a/tests/ModelContextProtocol.Tests/Transport/InMemoryTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/InMemoryTransportTests.cs @@ -135,8 +135,11 @@ public async Task Can_List_Registered_Tools() ServiceCollection sc = new(); var builder = sc.AddMcpServer().WithTools().WithInMemoryServerTransport(); var server = sc.BuildServiceProvider().GetRequiredService(); + await server.StartAsync(TestContext.Current.CancellationToken); + IMcpClient client = await server.GetInMemoryClientAsync(TestContext.Current.CancellationToken); + var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); Assert.Equal(10, tools.Count); From 82d039d83a0f8d81b8935b182baf39413038bc8e Mon Sep 17 00:00:00 2001 From: Diego Colombo Date: Wed, 26 Mar 2025 19:54:28 +0000 Subject: [PATCH 13/13] fix test post merge --- .../Transport/InMemoryTransportTests.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/ModelContextProtocol.Tests/Transport/InMemoryTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/InMemoryTransportTests.cs index 084aa7cb..73fa3101 100644 --- a/tests/ModelContextProtocol.Tests/Transport/InMemoryTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/InMemoryTransportTests.cs @@ -141,7 +141,7 @@ public async Task Can_List_Registered_Tools() var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); - Assert.Equal(10, tools.Count); + Assert.Equal(11, tools.Count); McpClientTool echoTool = tools.First(t => t.Name == "Echo"); Assert.Equal("Echo", echoTool.Name);