diff --git a/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Transports.cs b/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Transports.cs
index 1f357d32..2659066d 100644
--- a/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Transports.cs
+++ b/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Transports.cs
@@ -1,8 +1,10 @@
-using ModelContextProtocol.Configuration;
+using System.Diagnostics.CodeAnalysis;
+using Microsoft.Extensions.DependencyInjection;
+using ModelContextProtocol.Configuration;
using ModelContextProtocol.Hosting;
+using ModelContextProtocol.Protocol.Messages;
using ModelContextProtocol.Protocol.Transport;
using ModelContextProtocol.Utils;
-using Microsoft.Extensions.DependencyInjection;
namespace ModelContextProtocol;
@@ -16,23 +18,67 @@ public static partial class McpServerBuilderExtensions
///
/// The builder instance.
public static IMcpServerBuilder WithStdioServerTransport(this IMcpServerBuilder builder)
+ {
+ return builder.WithServerTransport();
+ }
+
+ ///
+ /// Adds a server transport that uses SSE via a HttpListener for communication.
+ ///
+ /// The builder instance.
+ public static IMcpServerBuilder WithHttpListenerSseServerTransport(this IMcpServerBuilder builder)
+ {
+ return builder.WithServerTransport();
+ }
+
+ ///
+ /// Adds a server transport for in-memory communication.
+ ///
+ /// The builder instance.
+ public static IMcpServerBuilder WithInMemoryServerTransport(this IMcpServerBuilder builder)
+ {
+ return builder.WithServerTransport();
+ }
+
+ ///
+ /// Adds a server transport for in-memory communication.
+ ///
+ /// The builder instance.
+ /// Delegate to handle messages.
+ public static IMcpServerBuilder WithInMemoryServerTransport(this IMcpServerBuilder builder, Func> handleMessageDelegate)
+ {
+ var transport = new InMemoryServerTransport
+ {
+ HandleMessage = handleMessageDelegate
+ };
+
+ return builder.WithServerTransport(transport);
+ }
+
+ ///
+ /// Adds a server transport for communication.
+ ///
+ /// The type of the server transport to use.
+ /// The builder instance.
+ public static IMcpServerBuilder WithServerTransport<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] TTransport>(this IMcpServerBuilder builder) where TTransport : class, IServerTransport
{
Throw.IfNull(builder);
- builder.Services.AddSingleton();
+ builder.Services.AddSingleton();
builder.Services.AddHostedService();
return builder;
}
///
- /// Adds a server transport that uses SSE via a HttpListener for communication.
+ /// Adds a server transport for communication.
///
+ /// Instance of the server transport.
/// The builder instance.
- public static IMcpServerBuilder WithHttpListenerSseServerTransport(this IMcpServerBuilder builder)
+ public static IMcpServerBuilder WithServerTransport(this IMcpServerBuilder builder, IServerTransport serverTransport)
{
Throw.IfNull(builder);
- builder.Services.AddSingleton();
+ builder.Services.AddSingleton(serverTransport);
builder.Services.AddHostedService();
return builder;
}
diff --git a/src/ModelContextProtocol/Protocol/Transport/InMemoryServerTransport.cs b/src/ModelContextProtocol/Protocol/Transport/InMemoryServerTransport.cs
new file mode 100644
index 00000000..ba60558a
--- /dev/null
+++ b/src/ModelContextProtocol/Protocol/Transport/InMemoryServerTransport.cs
@@ -0,0 +1,102 @@
+using System.Threading.Channels;
+using ModelContextProtocol.Protocol.Messages;
+
+namespace ModelContextProtocol.Protocol.Transport;
+
+///
+/// InMemory server transport for special scenarios or testing.
+///
+public class InMemoryServerTransport : IServerTransport
+{
+ private readonly Channel _messageChannel;
+ private bool _isStarted;
+
+ ///
+ public bool IsConnected => _isStarted;
+
+ ///
+ public ChannelReader MessageReader => _messageChannel;
+
+ ///
+ /// Delegate to handle messages before sending them.
+ ///
+ public Func>? HandleMessage { get; set; }
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ public InMemoryServerTransport()
+ {
+ _messageChannel = Channel.CreateUnbounded(new UnboundedChannelOptions
+ {
+ SingleReader = true,
+ SingleWriter = true,
+ });
+
+ // default message handler
+ HandleMessage = (m, _) => Task.FromResult(CreateResponseMessage(m));
+ }
+
+ ///
+#if NET8_0_OR_GREATER
+ public ValueTask DisposeAsync() => ValueTask.CompletedTask;
+#else
+ public ValueTask DisposeAsync() => new ValueTask(Task.CompletedTask);
+#endif
+
+ ///
+ public virtual async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
+ {
+ IJsonRpcMessage? response = message;
+
+ if (HandleMessage != null)
+ response = await HandleMessage(message, cancellationToken);
+
+ if (response != null)
+ await WriteMessageAsync(response, cancellationToken);
+ }
+
+ ///
+ public virtual Task StartListeningAsync(CancellationToken cancellationToken = default)
+ {
+ _isStarted = true;
+ return Task.CompletedTask;
+ }
+
+ ///
+ /// Writes a message to the channel.
+ ///
+ protected virtual async Task WriteMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
+ {
+ await _messageChannel.Writer.WriteAsync(message, cancellationToken);
+ }
+
+ ///
+ /// Creates a response message for the given request.
+ ///
+ ///
+ ///
+ protected virtual IJsonRpcMessage? CreateResponseMessage(IJsonRpcMessage message)
+ {
+ if (message is JsonRpcRequest request)
+ {
+ return new JsonRpcResponse
+ {
+ Id = request.Id,
+ Result = CreateMessageResult(request)
+ };
+ }
+
+ return message;
+ }
+
+ ///
+ /// Creates a result object for the given request.
+ ///
+ ///
+ ///
+ protected virtual object? CreateMessageResult(JsonRpcRequest request)
+ {
+ return null;
+ }
+}
diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs
index 3f4dd1d8..82b8e883 100644
--- a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs
+++ b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs
@@ -1,4 +1,5 @@
-using Microsoft.Extensions.AI;
+using System.Reflection;
+using Microsoft.Extensions.AI;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using ModelContextProtocol.Client;
@@ -8,7 +9,6 @@
using ModelContextProtocol.Server;
using ModelContextProtocol.Tests.Utils;
using Moq;
-using System.Reflection;
namespace ModelContextProtocol.Tests.Server;
@@ -132,9 +132,9 @@ public async Task StartAsync_Sets_Initialized_After_Transport_Responses_Initiali
// Send initialized notification
await transport.SendMessageAsync(new JsonRpcNotification
- {
- Method = "notifications/initialized"
- }, TestContext.Current.CancellationToken);
+ {
+ Method = "notifications/initialized"
+ }, TestContext.Current.CancellationToken);
await Task.Delay(50, TestContext.Current.CancellationToken);
@@ -389,7 +389,7 @@ await Can_Handle_Requests(
},
ListResourcesHandler = (request, ct) => throw new NotImplementedException(),
}
- },
+ },
method: "resources/read",
configureOptions: null,
assertResult: response =>
@@ -450,7 +450,7 @@ public async Task Can_Handle_List_Prompts_Requests_Throws_Exception_If_No_Handle
public async Task Can_Handle_Get_Prompts_Requests()
{
await Can_Handle_Requests(
- new ServerCapabilities
+ new ServerCapabilities
{
Prompts = new()
{
@@ -479,7 +479,7 @@ public async Task Can_Handle_Get_Prompts_Requests_Throws_Exception_If_No_Handler
public async Task Can_Handle_List_Tools_Requests()
{
await Can_Handle_Requests(
- new ServerCapabilities
+ new ServerCapabilities
{
Tools = new()
{
@@ -528,7 +528,7 @@ await Can_Handle_Requests(
},
ListToolsHandler = (request, ct) => throw new NotImplementedException(),
}
- },
+ },
method: "tools/call",
configureOptions: null,
assertResult: response =>
@@ -559,10 +559,12 @@ private async Task Can_Handle_Requests(ServerCapabilities? serverCapabilities, s
var receivedMessage = new TaskCompletionSource();
- transport.OnMessageSent = (message) =>
+ transport.HandleMessage = (message, _) =>
{
if (message is JsonRpcResponse response && response.Id.AsNumber == 55)
receivedMessage.SetResult(response);
+
+ return Task.FromResult((IJsonRpcMessage?)message);
};
await transport.SendMessageAsync(
@@ -582,7 +584,7 @@ await transport.SendMessageAsync(
private async Task Throws_Exception_If_No_Handler_Assigned(ServerCapabilities serverCapabilities, string method, string expectedError)
{
- await using var transport = new TestServerTransport();
+ await using var transport = new InMemoryServerTransport();
var options = CreateOptions(serverCapabilities);
Assert.Throws(() => McpServerFactory.Create(transport, options, LoggerFactory, _serviceProvider));
@@ -680,7 +682,7 @@ public Task SendRequestAsync(JsonRpcRequest request, CancellationToken can
public Implementation? ClientInfo => throw new NotImplementedException();
public McpServerOptions ServerOptions => throw new NotImplementedException();
public IServiceProvider? Services => throw new NotImplementedException();
- public void AddNotificationHandler(string method, Func handler) =>
+ public void AddNotificationHandler(string method, Func handler) =>
throw new NotImplementedException();
public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) =>
throw new NotImplementedException();
diff --git a/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs b/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs
index f38d933e..17167332 100644
--- a/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs
+++ b/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs
@@ -1,83 +1,33 @@
-using System.Threading.Channels;
-using ModelContextProtocol.Protocol.Messages;
+using ModelContextProtocol.Protocol.Messages;
using ModelContextProtocol.Protocol.Transport;
using ModelContextProtocol.Protocol.Types;
namespace ModelContextProtocol.Tests.Utils;
-public class TestServerTransport : IServerTransport
+public class TestServerTransport : InMemoryServerTransport
{
- private readonly Channel _messageChannel;
- private bool _isStarted;
-
- public bool IsConnected => _isStarted;
-
- public ChannelReader MessageReader => _messageChannel;
-
public List SentMessages { get; } = [];
- public Action? OnMessageSent { get; set; }
-
- public TestServerTransport()
- {
- _messageChannel = Channel.CreateUnbounded(new UnboundedChannelOptions
- {
- SingleReader = true,
- SingleWriter = true,
- });
- }
-
- public ValueTask DisposeAsync() => ValueTask.CompletedTask;
-
- public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
+ public override Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
{
SentMessages.Add(message);
- if (message is JsonRpcRequest request)
- {
- if (request.Method == "roots/list")
- await ListRoots(request, cancellationToken);
- else if (request.Method == "sampling/createMessage")
- await Sampling(request, cancellationToken);
- else
- await WriteMessageAsync(request, cancellationToken);
- }
- else if (message is JsonRpcNotification notification)
- {
- await WriteMessageAsync(notification, cancellationToken);
- }
-
- OnMessageSent?.Invoke(message);
- }
- public Task StartListeningAsync(CancellationToken cancellationToken = default)
- {
- _isStarted = true;
- return Task.CompletedTask;
+ return base.SendMessageAsync(message, cancellationToken);
}
- private async Task ListRoots(JsonRpcRequest request, CancellationToken cancellationToken)
+ protected override object? CreateMessageResult(JsonRpcRequest request)
{
- await WriteMessageAsync(new JsonRpcResponse
+ if (request.Method == "roots/list")
{
- Id = request.Id,
- Result = new ModelContextProtocol.Protocol.Types.ListRootsResult
+ return new ModelContextProtocol.Protocol.Types.ListRootsResult
{
Roots = []
- }
- }, cancellationToken);
- }
+ };
+ }
- private async Task Sampling(JsonRpcRequest request, CancellationToken cancellationToken)
- {
- await WriteMessageAsync(new JsonRpcResponse
- {
- Id = request.Id,
- Result = new CreateMessageResult { Content = new(), Model = "model", Role = "role" }
- }, cancellationToken);
- }
+ if (request.Method == "sampling/createMessage")
+ return new CreateMessageResult { Content = new(), Model = "model", Role = "role" };
- protected async Task WriteMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
- {
- await _messageChannel.Writer.WriteAsync(message, cancellationToken);
+ return base.CreateMessageResult(request);
}
}