Skip to content

Commit 1f51437

Browse files
committed
Add InMemoryTransport
1 parent d47c834 commit 1f51437

File tree

4 files changed

+180
-80
lines changed

4 files changed

+180
-80
lines changed

Diff for: src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Transports.cs

+52-6
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
using ModelContextProtocol.Configuration;
1+
using System.Diagnostics.CodeAnalysis;
2+
using Microsoft.Extensions.DependencyInjection;
3+
using ModelContextProtocol.Configuration;
24
using ModelContextProtocol.Hosting;
5+
using ModelContextProtocol.Protocol.Messages;
36
using ModelContextProtocol.Protocol.Transport;
47
using ModelContextProtocol.Utils;
5-
using Microsoft.Extensions.DependencyInjection;
68

79
namespace ModelContextProtocol;
810

@@ -16,23 +18,67 @@ public static partial class McpServerBuilderExtensions
1618
/// </summary>
1719
/// <param name="builder">The builder instance.</param>
1820
public static IMcpServerBuilder WithStdioServerTransport(this IMcpServerBuilder builder)
21+
{
22+
return builder.WithServerTransport<StdioServerTransport>();
23+
}
24+
25+
/// <summary>
26+
/// Adds a server transport that uses SSE via a HttpListener for communication.
27+
/// </summary>
28+
/// <param name="builder">The builder instance.</param>
29+
public static IMcpServerBuilder WithHttpListenerSseServerTransport(this IMcpServerBuilder builder)
30+
{
31+
return builder.WithServerTransport<HttpListenerSseServerTransport>();
32+
}
33+
34+
/// <summary>
35+
/// Adds a server transport for in-memory communication.
36+
/// </summary>
37+
/// <param name="builder">The builder instance.</param>
38+
public static IMcpServerBuilder WithInMemoryServerTransport(this IMcpServerBuilder builder)
39+
{
40+
return builder.WithServerTransport<InMemoryServerTransport>();
41+
}
42+
43+
/// <summary>
44+
/// Adds a server transport for in-memory communication.
45+
/// </summary>
46+
/// <param name="builder">The builder instance.</param>
47+
/// <param name="handleMessageDelegate">Delegate to handle messages.</param>
48+
public static IMcpServerBuilder WithInMemoryServerTransport(this IMcpServerBuilder builder, Func<IJsonRpcMessage, CancellationToken, Task<IJsonRpcMessage?>> handleMessageDelegate)
49+
{
50+
var transport = new InMemoryServerTransport
51+
{
52+
HandleMessage = handleMessageDelegate
53+
};
54+
55+
return builder.WithServerTransport(transport);
56+
}
57+
58+
/// <summary>
59+
/// Adds a server transport for communication.
60+
/// </summary>
61+
/// <typeparam name="TTransport">The type of the server transport to use.</typeparam>
62+
/// <param name="builder">The builder instance.</param>
63+
public static IMcpServerBuilder WithServerTransport<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] TTransport>(this IMcpServerBuilder builder) where TTransport : class, IServerTransport
1964
{
2065
Throw.IfNull(builder);
2166

22-
builder.Services.AddSingleton<IServerTransport, StdioServerTransport>();
67+
builder.Services.AddSingleton<IServerTransport, TTransport>();
2368
builder.Services.AddHostedService<McpServerHostedService>();
2469
return builder;
2570
}
2671

2772
/// <summary>
28-
/// Adds a server transport that uses SSE via a HttpListener for communication.
73+
/// Adds a server transport for communication.
2974
/// </summary>
75+
/// <param name="serverTransport">Instance of the server transport.</param>
3076
/// <param name="builder">The builder instance.</param>
31-
public static IMcpServerBuilder WithHttpListenerSseServerTransport(this IMcpServerBuilder builder)
77+
public static IMcpServerBuilder WithServerTransport(this IMcpServerBuilder builder, IServerTransport serverTransport)
3278
{
3379
Throw.IfNull(builder);
3480

35-
builder.Services.AddSingleton<IServerTransport, HttpListenerSseServerTransport>();
81+
builder.Services.AddSingleton<IServerTransport>(serverTransport);
3682
builder.Services.AddHostedService<McpServerHostedService>();
3783
return builder;
3884
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
using System.Threading.Channels;
2+
using ModelContextProtocol.Protocol.Messages;
3+
4+
namespace ModelContextProtocol.Protocol.Transport;
5+
6+
/// <summary>
7+
/// InMemory server transport for special scenarios or testing.
8+
/// </summary>
9+
public class InMemoryServerTransport : IServerTransport
10+
{
11+
private readonly Channel<IJsonRpcMessage> _messageChannel;
12+
private bool _isStarted;
13+
14+
/// <inheritdoc/>
15+
public bool IsConnected => _isStarted;
16+
17+
/// <inheritdoc/>
18+
public ChannelReader<IJsonRpcMessage> MessageReader => _messageChannel;
19+
20+
/// <summary>
21+
/// Delegate to handle messages before sending them.
22+
/// </summary>
23+
public Func<IJsonRpcMessage, CancellationToken, Task<IJsonRpcMessage?>>? HandleMessage { get; set; }
24+
25+
/// <summary>
26+
/// Initializes a new instance of the <see cref="InMemoryServerTransport"/> class.
27+
/// </summary>
28+
public InMemoryServerTransport()
29+
{
30+
_messageChannel = Channel.CreateUnbounded<IJsonRpcMessage>(new UnboundedChannelOptions
31+
{
32+
SingleReader = true,
33+
SingleWriter = true,
34+
});
35+
36+
// default message handler
37+
HandleMessage = (m, _) => Task.FromResult(CreateResponseMessage(m));
38+
}
39+
40+
/// <inheritdoc/>
41+
#if NET8_0_OR_GREATER
42+
public ValueTask DisposeAsync() => ValueTask.CompletedTask;
43+
#else
44+
public ValueTask DisposeAsync() => new ValueTask(Task.CompletedTask);
45+
#endif
46+
47+
/// <inheritdoc/>
48+
public virtual async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
49+
{
50+
IJsonRpcMessage? response = message;
51+
52+
if (HandleMessage != null)
53+
response = await HandleMessage(message, cancellationToken);
54+
55+
if (response != null)
56+
await WriteMessageAsync(response, cancellationToken);
57+
}
58+
59+
/// <inheritdoc/>
60+
public virtual Task StartListeningAsync(CancellationToken cancellationToken = default)
61+
{
62+
_isStarted = true;
63+
return Task.CompletedTask;
64+
}
65+
66+
/// <summary>
67+
/// Writes a message to the channel.
68+
/// </summary>
69+
protected virtual async Task WriteMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
70+
{
71+
await _messageChannel.Writer.WriteAsync(message, cancellationToken);
72+
}
73+
74+
/// <summary>
75+
/// Creates a response message for the given request.
76+
/// </summary>
77+
/// <param name="message"></param>
78+
/// <returns></returns>
79+
protected virtual IJsonRpcMessage? CreateResponseMessage(IJsonRpcMessage message)
80+
{
81+
if (message is JsonRpcRequest request)
82+
{
83+
return new JsonRpcResponse
84+
{
85+
Id = request.Id,
86+
Result = CreateMessageResult(request)
87+
};
88+
}
89+
90+
return message;
91+
}
92+
93+
/// <summary>
94+
/// Creates a result object for the given request.
95+
/// </summary>
96+
/// <param name="request"></param>
97+
/// <returns></returns>
98+
protected virtual object? CreateMessageResult(JsonRpcRequest request)
99+
{
100+
return null;
101+
}
102+
}

Diff for: tests/ModelContextProtocol.Tests/Server/McpServerTests.cs

+14-12
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using Microsoft.Extensions.AI;
1+
using System.Reflection;
2+
using Microsoft.Extensions.AI;
23
using Microsoft.Extensions.DependencyInjection;
34
using Microsoft.Extensions.Logging;
45
using ModelContextProtocol.Client;
@@ -8,7 +9,6 @@
89
using ModelContextProtocol.Server;
910
using ModelContextProtocol.Tests.Utils;
1011
using Moq;
11-
using System.Reflection;
1212

1313
namespace ModelContextProtocol.Tests.Server;
1414

@@ -132,9 +132,9 @@ public async Task StartAsync_Sets_Initialized_After_Transport_Responses_Initiali
132132

133133
// Send initialized notification
134134
await transport.SendMessageAsync(new JsonRpcNotification
135-
{
136-
Method = "notifications/initialized"
137-
}, TestContext.Current.CancellationToken);
135+
{
136+
Method = "notifications/initialized"
137+
}, TestContext.Current.CancellationToken);
138138

139139
await Task.Delay(50, TestContext.Current.CancellationToken);
140140

@@ -389,7 +389,7 @@ await Can_Handle_Requests(
389389
},
390390
ListResourcesHandler = (request, ct) => throw new NotImplementedException(),
391391
}
392-
},
392+
},
393393
method: "resources/read",
394394
configureOptions: null,
395395
assertResult: response =>
@@ -450,7 +450,7 @@ public async Task Can_Handle_List_Prompts_Requests_Throws_Exception_If_No_Handle
450450
public async Task Can_Handle_Get_Prompts_Requests()
451451
{
452452
await Can_Handle_Requests(
453-
new ServerCapabilities
453+
new ServerCapabilities
454454
{
455455
Prompts = new()
456456
{
@@ -479,7 +479,7 @@ public async Task Can_Handle_Get_Prompts_Requests_Throws_Exception_If_No_Handler
479479
public async Task Can_Handle_List_Tools_Requests()
480480
{
481481
await Can_Handle_Requests(
482-
new ServerCapabilities
482+
new ServerCapabilities
483483
{
484484
Tools = new()
485485
{
@@ -528,7 +528,7 @@ await Can_Handle_Requests(
528528
},
529529
ListToolsHandler = (request, ct) => throw new NotImplementedException(),
530530
}
531-
},
531+
},
532532
method: "tools/call",
533533
configureOptions: null,
534534
assertResult: response =>
@@ -559,10 +559,12 @@ private async Task Can_Handle_Requests(ServerCapabilities? serverCapabilities, s
559559

560560
var receivedMessage = new TaskCompletionSource<JsonRpcResponse>();
561561

562-
transport.OnMessageSent = (message) =>
562+
transport.HandleMessage = (message, _) =>
563563
{
564564
if (message is JsonRpcResponse response && response.Id.AsNumber == 55)
565565
receivedMessage.SetResult(response);
566+
567+
return Task.FromResult((IJsonRpcMessage?)message);
566568
};
567569

568570
await transport.SendMessageAsync(
@@ -582,7 +584,7 @@ await transport.SendMessageAsync(
582584

583585
private async Task Throws_Exception_If_No_Handler_Assigned(ServerCapabilities serverCapabilities, string method, string expectedError)
584586
{
585-
await using var transport = new TestServerTransport();
587+
await using var transport = new InMemoryServerTransport();
586588
var options = CreateOptions(serverCapabilities);
587589

588590
Assert.Throws<McpServerException>(() => McpServerFactory.Create(transport, options, LoggerFactory, _serviceProvider));
@@ -680,7 +682,7 @@ public Task<T> SendRequestAsync<T>(JsonRpcRequest request, CancellationToken can
680682
public Implementation? ClientInfo => throw new NotImplementedException();
681683
public McpServerOptions ServerOptions => throw new NotImplementedException();
682684
public IServiceProvider? Services => throw new NotImplementedException();
683-
public void AddNotificationHandler(string method, Func<JsonRpcNotification, Task> handler) =>
685+
public void AddNotificationHandler(string method, Func<JsonRpcNotification, Task> handler) =>
684686
throw new NotImplementedException();
685687
public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) =>
686688
throw new NotImplementedException();
+12-62
Original file line numberDiff line numberDiff line change
@@ -1,83 +1,33 @@
1-
using System.Threading.Channels;
2-
using ModelContextProtocol.Protocol.Messages;
1+
using ModelContextProtocol.Protocol.Messages;
32
using ModelContextProtocol.Protocol.Transport;
43
using ModelContextProtocol.Protocol.Types;
54

65
namespace ModelContextProtocol.Tests.Utils;
76

8-
public class TestServerTransport : IServerTransport
7+
public class TestServerTransport : InMemoryServerTransport
98
{
10-
private readonly Channel<IJsonRpcMessage> _messageChannel;
11-
private bool _isStarted;
12-
13-
public bool IsConnected => _isStarted;
14-
15-
public ChannelReader<IJsonRpcMessage> MessageReader => _messageChannel;
16-
179
public List<IJsonRpcMessage> SentMessages { get; } = [];
1810

19-
public Action<IJsonRpcMessage>? OnMessageSent { get; set; }
20-
21-
public TestServerTransport()
22-
{
23-
_messageChannel = Channel.CreateUnbounded<IJsonRpcMessage>(new UnboundedChannelOptions
24-
{
25-
SingleReader = true,
26-
SingleWriter = true,
27-
});
28-
}
29-
30-
public ValueTask DisposeAsync() => ValueTask.CompletedTask;
31-
32-
public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
11+
public override Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
3312
{
3413
SentMessages.Add(message);
35-
if (message is JsonRpcRequest request)
36-
{
37-
if (request.Method == "roots/list")
38-
await ListRoots(request, cancellationToken);
39-
else if (request.Method == "sampling/createMessage")
40-
await Sampling(request, cancellationToken);
41-
else
42-
await WriteMessageAsync(request, cancellationToken);
43-
}
44-
else if (message is JsonRpcNotification notification)
45-
{
46-
await WriteMessageAsync(notification, cancellationToken);
47-
}
48-
49-
OnMessageSent?.Invoke(message);
50-
}
5114

52-
public Task StartListeningAsync(CancellationToken cancellationToken = default)
53-
{
54-
_isStarted = true;
55-
return Task.CompletedTask;
15+
return base.SendMessageAsync(message, cancellationToken);
5616
}
5717

58-
private async Task ListRoots(JsonRpcRequest request, CancellationToken cancellationToken)
18+
protected override object? CreateMessageResult(JsonRpcRequest request)
5919
{
60-
await WriteMessageAsync(new JsonRpcResponse
20+
if (request.Method == "roots/list")
6121
{
62-
Id = request.Id,
63-
Result = new ModelContextProtocol.Protocol.Types.ListRootsResult
22+
return new ModelContextProtocol.Protocol.Types.ListRootsResult
6423
{
6524
Roots = []
66-
}
67-
}, cancellationToken);
68-
}
25+
};
26+
}
6927

70-
private async Task Sampling(JsonRpcRequest request, CancellationToken cancellationToken)
71-
{
72-
await WriteMessageAsync(new JsonRpcResponse
73-
{
74-
Id = request.Id,
75-
Result = new CreateMessageResult { Content = new(), Model = "model", Role = "role" }
76-
}, cancellationToken);
77-
}
28+
if (request.Method == "sampling/createMessage")
29+
return new CreateMessageResult { Content = new(), Model = "model", Role = "role" };
7830

79-
protected async Task WriteMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
80-
{
81-
await _messageChannel.Writer.WriteAsync(message, cancellationToken);
31+
return base.CreateMessageResult(request);
8232
}
8333
}

0 commit comments

Comments
 (0)