Skip to content

Creating support for in memory transport #97

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 22 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,14 @@ nCrunchTemp_*

*.orig

*.ncrunchsolution

*.lutconfig

.NCrunch_ModelContextProtocol/

*.ncrunchproject

# Auto-generated documentation
docs/_site
docs/api
1 change: 1 addition & 0 deletions Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
<PackageVersion Include="Microsoft.Bcl.Memory" Version="$(SystemVersion)" />
<PackageVersion Include="Microsoft.Extensions.AI" Version="$(MicrosoftExtensionsAIVersion)" />
<PackageVersion Include="Microsoft.Extensions.AI.Abstractions" Version="$(MicrosoftExtensionsAIVersion)" />
<PackageVersion Include="Microsoft.Extensions.DependencyInjection" Version="$(MicrosoftExtensionsVersion)" />
<PackageVersion Include="Microsoft.Extensions.Hosting.Abstractions" Version="$(MicrosoftExtensionsVersion)" />
<PackageVersion Include="Microsoft.Extensions.Logging.Abstractions" Version="$(MicrosoftExtensionsVersion)" />
<PackageVersion Include="System.Net.ServerSentEvents" Version="$(System10Version)" />
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
using ModelContextProtocol.Configuration;
using Microsoft.Extensions.DependencyInjection;

using ModelContextProtocol.Configuration;
using ModelContextProtocol.Hosting;
using ModelContextProtocol.Protocol.Transport;
using ModelContextProtocol.Utils;
using Microsoft.Extensions.DependencyInjection;

namespace ModelContextProtocol;

Expand All @@ -11,6 +12,30 @@ namespace ModelContextProtocol;
/// </summary>
public static partial class McpServerBuilderExtensions
{
/// <summary>
/// Adds a server transport that uses in memory communication.
/// </summary>
/// <param name="builder">The builder instance.</param>
public static IMcpServerBuilder WithInMemoryServerTransport(this IMcpServerBuilder builder)
{
Throw.IfNull(builder);
builder.Services.AddSingleton<InMemoryTransport>();
builder.Services.AddSingleton<IClientTransport>(s =>
{
var transport = s.GetRequiredService<InMemoryTransport>();
return transport.ClientTransport;
});

builder.Services.AddSingleton<IServerTransport>(s =>
{
var transport = s.GetRequiredService<InMemoryTransport>();
return transport.ServerTransport;
});

builder.Services.AddHostedService<McpServerHostedService>();
return builder;
}

/// <summary>
/// Adds a server transport that uses stdin/stdout for communication.
/// </summary>
Expand Down
3 changes: 2 additions & 1 deletion src/ModelContextProtocol/ModelContextProtocol.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Microsoft.Extensions.AI.Abstractions"/>
<PackageReference Include="Microsoft.Extensions.AI.Abstractions" />
<PackageReference Include="Microsoft.Extensions.AI" />
<PackageReference Include="Microsoft.Extensions.DependencyInjection" />
<PackageReference Include="Microsoft.Extensions.Hosting.Abstractions" />
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" />
<PackageReference Include="System.Net.ServerSentEvents" />
Expand Down
228 changes: 228 additions & 0 deletions src/ModelContextProtocol/Protocol/Transport/InMemoryClientTransport.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;

using ModelContextProtocol.Logging;
using ModelContextProtocol.Protocol.Messages;

using System.Threading.Channels;

namespace ModelContextProtocol.Protocol.Transport;

/// <summary>
/// Provides an in-memory implementation of the MCP client transport.
/// </summary>
public sealed class InMemoryClientTransport : TransportBase, IClientTransport
{
private string EndpointName => $"Client (in memory) for ({_serverName})";
private readonly ILogger _logger;
private readonly string _serverName;
private readonly ChannelWriter<IJsonRpcMessage> _outgoingChannel;
private readonly ChannelReader<IJsonRpcMessage> _incomingChannel;
private CancellationTokenSource? _cancellationTokenSource;
private Task? _readTask;
private readonly SemaphoreSlim _connectLock = new SemaphoreSlim(1, 1);
private volatile bool _disposed;

/// <summary>
/// Gets or sets the server transport this client connects to.
/// </summary>
internal InMemoryServerTransport? ServerTransport { get; set; }

/// <summary>
/// Initializes a new instance of the <see cref="InMemoryClientTransport"/> class.
/// </summary>
/// <param name="serverName">The name of the server.</param>
/// <param name="loggerFactory">Optional logger factory for logging transport operations.</param>
/// <param name="outgoingChannel">Channel for sending messages to the server.</param>
/// <param name="incomingChannel">Channel for receiving messages from the server.</param>
internal InMemoryClientTransport(
string serverName,
ILoggerFactory? loggerFactory,
ChannelWriter<IJsonRpcMessage> outgoingChannel,
ChannelReader<IJsonRpcMessage> incomingChannel)
: base(loggerFactory)
{
_logger = loggerFactory?.CreateLogger<InMemoryClientTransport>()
?? NullLogger<InMemoryClientTransport>.Instance;
_serverName = serverName;
_outgoingChannel = outgoingChannel;
_incomingChannel = incomingChannel;
}



/// <inheritdoc/>
public async Task ConnectAsync(CancellationToken cancellationToken = default)
{
await _connectLock.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
ThrowIfDisposed();

if (IsConnected)
{
_logger.TransportAlreadyConnected(EndpointName);
throw new McpTransportException("Transport is already connected");
}

_logger.TransportConnecting(EndpointName);

try
{
// Start the server if it exists and is not already connected
if (ServerTransport != null && !ServerTransport.IsConnected)
{
await ServerTransport.StartListeningAsync(cancellationToken).ConfigureAwait(false);
}

_cancellationTokenSource = new CancellationTokenSource();
_readTask = Task.Run(() => ReadMessagesAsync(_cancellationTokenSource.Token).ConfigureAwait(false), CancellationToken.None);

SetConnected(true);
}
catch (Exception ex)
{
_logger.TransportConnectFailed(EndpointName, ex);
await CleanupAsync(cancellationToken).ConfigureAwait(false);
throw new McpTransportException("Failed to connect transport", ex);
}
}
finally
{
_connectLock.Release();
}
}

/// <inheritdoc/>
public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
{
ThrowIfDisposed();

if (!IsConnected)
{
_logger.TransportNotConnected(EndpointName);
throw new McpTransportException("Transport is not connected");
}

string id = "(no id)";
if (message is IJsonRpcMessageWithId messageWithId)
{
id = messageWithId.Id.ToString();
}

try
{
_logger.TransportSendingMessage(EndpointName, id);
await _outgoingChannel.WriteAsync(message, cancellationToken).ConfigureAwait(false);
_logger.TransportSentMessage(EndpointName, id);
}
catch (Exception ex)
{
_logger.TransportSendFailed(EndpointName, id, ex);
throw new McpTransportException("Failed to send message", ex);
}
}

/// <inheritdoc/>
public override async ValueTask DisposeAsync()
{
await CleanupAsync(CancellationToken.None).ConfigureAwait(false);
GC.SuppressFinalize(this);
}

private async Task ReadMessagesAsync(CancellationToken cancellationToken)
{
try
{
_logger.TransportEnteringReadMessagesLoop(EndpointName);

await foreach (var message in _incomingChannel.ReadAllAsync(cancellationToken).ConfigureAwait(false))
{
var id = "(no id)";
if (message is IJsonRpcMessageWithId messageWithId)
{
id = messageWithId.Id.ToString();
}

_logger.TransportReceivedMessageParsed(EndpointName, id);

// Write to the base class's message channel that's exposed via MessageReader
await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false);

_logger.TransportMessageWritten(EndpointName, id);
}

_logger.TransportExitingReadMessagesLoop(EndpointName);
}
catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
{
_logger.TransportReadMessagesCancelled(EndpointName);
// Normal shutdown
}
catch (Exception ex)
{
_logger.TransportReadMessagesFailed(EndpointName, ex);
}
}

private async Task CleanupAsync(CancellationToken cancellationToken)
{
if (_disposed)
{
return;
}

_disposed = true;
_logger.TransportCleaningUp(EndpointName);

try
{
if (_cancellationTokenSource != null)
{
await _cancellationTokenSource.CancelAsync().ConfigureAwait(false);
_cancellationTokenSource.Dispose();
_cancellationTokenSource = null;
}

if (_readTask != null)
{
try
{
_logger.TransportWaitingForReadTask(EndpointName);
await _readTask.WaitAsync(TimeSpan.FromSeconds(1), cancellationToken).ConfigureAwait(false);
}
catch (TimeoutException)
{
_logger.TransportCleanupReadTaskTimeout(EndpointName);
}
catch (OperationCanceledException)
{
_logger.TransportCleanupReadTaskCancelled(EndpointName);
}
catch (Exception ex)
{
_logger.TransportCleanupReadTaskFailed(EndpointName, ex);
}
finally
{
_readTask = null;
}
}

_connectLock.Dispose();
}
finally
{
SetConnected(false);
_logger.TransportCleanedUp(EndpointName);
}
}

private void ThrowIfDisposed()
{
if (_disposed)
{
throw new ObjectDisposedException(nameof(InMemoryClientTransport));
}
}
}
Loading