Skip to content

Add InMemoryTransport #140

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
using ModelContextProtocol.Configuration;
using System.Diagnostics.CodeAnalysis;
using Microsoft.Extensions.DependencyInjection;
using ModelContextProtocol.Configuration;
using ModelContextProtocol.Hosting;
using ModelContextProtocol.Protocol.Messages;
using ModelContextProtocol.Protocol.Transport;
using ModelContextProtocol.Utils;
using Microsoft.Extensions.DependencyInjection;

namespace ModelContextProtocol;

Expand All @@ -16,23 +18,67 @@ public static partial class McpServerBuilderExtensions
/// </summary>
/// <param name="builder">The builder instance.</param>
public static IMcpServerBuilder WithStdioServerTransport(this IMcpServerBuilder builder)
{
return builder.WithServerTransport<StdioServerTransport>();
}

/// <summary>
/// Adds a server transport that uses SSE via a HttpListener for communication.
/// </summary>
/// <param name="builder">The builder instance.</param>
public static IMcpServerBuilder WithHttpListenerSseServerTransport(this IMcpServerBuilder builder)
{
return builder.WithServerTransport<HttpListenerSseServerTransport>();
}

/// <summary>
/// Adds a server transport for in-memory communication.
/// </summary>
/// <param name="builder">The builder instance.</param>
public static IMcpServerBuilder WithInMemoryServerTransport(this IMcpServerBuilder builder)
{
return builder.WithServerTransport<InMemoryServerTransport>();
}

/// <summary>
/// Adds a server transport for in-memory communication.
/// </summary>
/// <param name="builder">The builder instance.</param>
/// <param name="handleMessageDelegate">Delegate to handle messages.</param>
public static IMcpServerBuilder WithInMemoryServerTransport(this IMcpServerBuilder builder, Func<IJsonRpcMessage, CancellationToken, Task<IJsonRpcMessage?>> handleMessageDelegate)
{
var transport = new InMemoryServerTransport
{
HandleMessage = handleMessageDelegate
};

return builder.WithServerTransport(transport);
}

/// <summary>
/// Adds a server transport for communication.
/// </summary>
/// <typeparam name="TTransport">The type of the server transport to use.</typeparam>
/// <param name="builder">The builder instance.</param>
public static IMcpServerBuilder WithServerTransport<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] TTransport>(this IMcpServerBuilder builder) where TTransport : class, IServerTransport
{
Throw.IfNull(builder);

builder.Services.AddSingleton<IServerTransport, StdioServerTransport>();
builder.Services.AddSingleton<IServerTransport, TTransport>();
builder.Services.AddHostedService<McpServerHostedService>();
return builder;
}

/// <summary>
/// Adds a server transport that uses SSE via a HttpListener for communication.
/// Adds a server transport for communication.
/// </summary>
/// <param name="serverTransport">Instance of the server transport.</param>
/// <param name="builder">The builder instance.</param>
public static IMcpServerBuilder WithHttpListenerSseServerTransport(this IMcpServerBuilder builder)
public static IMcpServerBuilder WithServerTransport(this IMcpServerBuilder builder, IServerTransport serverTransport)
{
Throw.IfNull(builder);

builder.Services.AddSingleton<IServerTransport, HttpListenerSseServerTransport>();
builder.Services.AddSingleton<IServerTransport>(serverTransport);
builder.Services.AddHostedService<McpServerHostedService>();
return builder;
}
Expand Down
102 changes: 102 additions & 0 deletions src/ModelContextProtocol/Protocol/Transport/InMemoryServerTransport.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
using System.Threading.Channels;
using ModelContextProtocol.Protocol.Messages;

namespace ModelContextProtocol.Protocol.Transport;

/// <summary>
/// InMemory server transport for special scenarios or testing.
/// </summary>
public class InMemoryServerTransport : IServerTransport
{
private readonly Channel<IJsonRpcMessage> _messageChannel;
private bool _isStarted;

/// <inheritdoc/>
public bool IsConnected => _isStarted;

/// <inheritdoc/>
public ChannelReader<IJsonRpcMessage> MessageReader => _messageChannel;

/// <summary>
/// Delegate to handle messages before sending them.
/// </summary>
public Func<IJsonRpcMessage, CancellationToken, Task<IJsonRpcMessage?>>? HandleMessage { get; set; }

/// <summary>
/// Initializes a new instance of the <see cref="InMemoryServerTransport"/> class.
/// </summary>
public InMemoryServerTransport()
{
_messageChannel = Channel.CreateUnbounded<IJsonRpcMessage>(new UnboundedChannelOptions
{
SingleReader = true,
SingleWriter = true,
});

// default message handler
HandleMessage = (m, _) => Task.FromResult(CreateResponseMessage(m));
}

/// <inheritdoc/>
#if NET8_0_OR_GREATER
public ValueTask DisposeAsync() => ValueTask.CompletedTask;
#else
public ValueTask DisposeAsync() => new ValueTask(Task.CompletedTask);
#endif

/// <inheritdoc/>
public virtual async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
{
IJsonRpcMessage? response = message;

if (HandleMessage != null)
response = await HandleMessage(message, cancellationToken);

if (response != null)
await WriteMessageAsync(response, cancellationToken);
}

/// <inheritdoc/>
public virtual Task StartListeningAsync(CancellationToken cancellationToken = default)
{
_isStarted = true;
return Task.CompletedTask;
}

/// <summary>
/// Writes a message to the channel.
/// </summary>
protected virtual async Task WriteMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
{
await _messageChannel.Writer.WriteAsync(message, cancellationToken);
}

/// <summary>
/// Creates a response message for the given request.
/// </summary>
/// <param name="message"></param>
/// <returns></returns>
protected virtual IJsonRpcMessage? CreateResponseMessage(IJsonRpcMessage message)
{
if (message is JsonRpcRequest request)
{
return new JsonRpcResponse
{
Id = request.Id,
Result = CreateMessageResult(request)
};
}

return message;
}

/// <summary>
/// Creates a result object for the given request.
/// </summary>
/// <param name="request"></param>
/// <returns></returns>
protected virtual object? CreateMessageResult(JsonRpcRequest request)
{
return null;
}
}
26 changes: 14 additions & 12 deletions tests/ModelContextProtocol.Tests/Server/McpServerTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Microsoft.Extensions.AI;
using System.Reflection;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using ModelContextProtocol.Client;
Expand All @@ -8,7 +9,6 @@
using ModelContextProtocol.Server;
using ModelContextProtocol.Tests.Utils;
using Moq;
using System.Reflection;

namespace ModelContextProtocol.Tests.Server;

Expand Down Expand Up @@ -132,9 +132,9 @@ public async Task StartAsync_Sets_Initialized_After_Transport_Responses_Initiali

// Send initialized notification
await transport.SendMessageAsync(new JsonRpcNotification
{
Method = "notifications/initialized"
}, TestContext.Current.CancellationToken);
{
Method = "notifications/initialized"
}, TestContext.Current.CancellationToken);

await Task.Delay(50, TestContext.Current.CancellationToken);

Expand Down Expand Up @@ -389,7 +389,7 @@ await Can_Handle_Requests(
},
ListResourcesHandler = (request, ct) => throw new NotImplementedException(),
}
},
},
method: "resources/read",
configureOptions: null,
assertResult: response =>
Expand Down Expand Up @@ -450,7 +450,7 @@ public async Task Can_Handle_List_Prompts_Requests_Throws_Exception_If_No_Handle
public async Task Can_Handle_Get_Prompts_Requests()
{
await Can_Handle_Requests(
new ServerCapabilities
new ServerCapabilities
{
Prompts = new()
{
Expand Down Expand Up @@ -479,7 +479,7 @@ public async Task Can_Handle_Get_Prompts_Requests_Throws_Exception_If_No_Handler
public async Task Can_Handle_List_Tools_Requests()
{
await Can_Handle_Requests(
new ServerCapabilities
new ServerCapabilities
{
Tools = new()
{
Expand Down Expand Up @@ -528,7 +528,7 @@ await Can_Handle_Requests(
},
ListToolsHandler = (request, ct) => throw new NotImplementedException(),
}
},
},
method: "tools/call",
configureOptions: null,
assertResult: response =>
Expand Down Expand Up @@ -559,10 +559,12 @@ private async Task Can_Handle_Requests(ServerCapabilities? serverCapabilities, s

var receivedMessage = new TaskCompletionSource<JsonRpcResponse>();

transport.OnMessageSent = (message) =>
transport.HandleMessage = (message, _) =>
{
if (message is JsonRpcResponse response && response.Id.AsNumber == 55)
receivedMessage.SetResult(response);

return Task.FromResult((IJsonRpcMessage?)message);
};

await transport.SendMessageAsync(
Expand All @@ -582,7 +584,7 @@ await transport.SendMessageAsync(

private async Task Throws_Exception_If_No_Handler_Assigned(ServerCapabilities serverCapabilities, string method, string expectedError)
{
await using var transport = new TestServerTransport();
await using var transport = new InMemoryServerTransport();
var options = CreateOptions(serverCapabilities);

Assert.Throws<McpServerException>(() => McpServerFactory.Create(transport, options, LoggerFactory, _serviceProvider));
Expand Down Expand Up @@ -680,7 +682,7 @@ public Task<T> SendRequestAsync<T>(JsonRpcRequest request, CancellationToken can
public Implementation? ClientInfo => throw new NotImplementedException();
public McpServerOptions ServerOptions => throw new NotImplementedException();
public IServiceProvider? Services => throw new NotImplementedException();
public void AddNotificationHandler(string method, Func<JsonRpcNotification, Task> handler) =>
public void AddNotificationHandler(string method, Func<JsonRpcNotification, Task> handler) =>
throw new NotImplementedException();
public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) =>
throw new NotImplementedException();
Expand Down
74 changes: 12 additions & 62 deletions tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs
Original file line number Diff line number Diff line change
@@ -1,83 +1,33 @@
using System.Threading.Channels;
using ModelContextProtocol.Protocol.Messages;
using ModelContextProtocol.Protocol.Messages;
using ModelContextProtocol.Protocol.Transport;
using ModelContextProtocol.Protocol.Types;

namespace ModelContextProtocol.Tests.Utils;

public class TestServerTransport : IServerTransport
public class TestServerTransport : InMemoryServerTransport
{
private readonly Channel<IJsonRpcMessage> _messageChannel;
private bool _isStarted;

public bool IsConnected => _isStarted;

public ChannelReader<IJsonRpcMessage> MessageReader => _messageChannel;

public List<IJsonRpcMessage> SentMessages { get; } = [];

public Action<IJsonRpcMessage>? OnMessageSent { get; set; }

public TestServerTransport()
{
_messageChannel = Channel.CreateUnbounded<IJsonRpcMessage>(new UnboundedChannelOptions
{
SingleReader = true,
SingleWriter = true,
});
}

public ValueTask DisposeAsync() => ValueTask.CompletedTask;

public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
public override Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
{
SentMessages.Add(message);
if (message is JsonRpcRequest request)
{
if (request.Method == "roots/list")
await ListRoots(request, cancellationToken);
else if (request.Method == "sampling/createMessage")
await Sampling(request, cancellationToken);
else
await WriteMessageAsync(request, cancellationToken);
}
else if (message is JsonRpcNotification notification)
{
await WriteMessageAsync(notification, cancellationToken);
}

OnMessageSent?.Invoke(message);
}

public Task StartListeningAsync(CancellationToken cancellationToken = default)
{
_isStarted = true;
return Task.CompletedTask;
return base.SendMessageAsync(message, cancellationToken);
}

private async Task ListRoots(JsonRpcRequest request, CancellationToken cancellationToken)
protected override object? CreateMessageResult(JsonRpcRequest request)
{
await WriteMessageAsync(new JsonRpcResponse
if (request.Method == "roots/list")
{
Id = request.Id,
Result = new ModelContextProtocol.Protocol.Types.ListRootsResult
return new ModelContextProtocol.Protocol.Types.ListRootsResult
{
Roots = []
}
}, cancellationToken);
}
};
}

private async Task Sampling(JsonRpcRequest request, CancellationToken cancellationToken)
{
await WriteMessageAsync(new JsonRpcResponse
{
Id = request.Id,
Result = new CreateMessageResult { Content = new(), Model = "model", Role = "role" }
}, cancellationToken);
}
if (request.Method == "sampling/createMessage")
return new CreateMessageResult { Content = new(), Model = "model", Role = "role" };

protected async Task WriteMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
{
await _messageChannel.Writer.WriteAsync(message, cancellationToken);
return base.CreateMessageResult(request);
}
}