Skip to content

Added TcpServerTransport #197

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -350,5 +350,37 @@ public static IMcpServerBuilder WithStdioServerTransport(this IMcpServerBuilder

return builder;
}
/// <summary>
/// Adds a server transport that uses TCP connection for communication.
/// </summary>
/// <param name="builder">The builder instance.</param>
/// <param name="configureOptions">The options to configure the transport.</param>
public static IMcpServerBuilder WithTcpServerTransport(this IMcpServerBuilder builder, Action<McpServerTcpTransportOptions>? configureOptions = null)
{
Throw.IfNull(builder);

if (configureOptions != null)
{
builder.Services.Configure(configureOptions);
}

builder.Services.AddSingleton<ITransport, TcpServerTransport>(services =>
{
McpServerTcpTransportOptions transportOptions = services.GetRequiredService<IOptions<McpServerTcpTransportOptions>>().Value;
return new TcpServerTransport(transportOptions);
});
builder.Services.AddHostedService<TcpMcpServerHostedService>();

builder.Services.AddSingleton(services =>
{
ITransport serverTransport = services.GetRequiredService<ITransport>();
IOptions<McpServerOptions> options = services.GetRequiredService<IOptions<McpServerOptions>>();
ILoggerFactory? loggerFactory = services.GetService<ILoggerFactory>();

return McpServerFactory.Create(serverTransport, options.Value, loggerFactory, services);
});

return builder;
}
#endregion
}
13 changes: 13 additions & 0 deletions src/ModelContextProtocol/Hosting/TcpMcpServerHostedService.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
using Microsoft.Extensions.Hosting;
using ModelContextProtocol.Server;

namespace ModelContextProtocol.Hosting;

/// <summary>
/// Hosted service for a single-session (i.e TCP) MCP server.
/// </summary>
internal sealed class TcpMcpServerHostedService(IMcpServer session) : BackgroundService
{
/// <inheritdoc />
protected override Task ExecuteAsync(CancellationToken stoppingToken) => session.RunAsync(stoppingToken);
}
217 changes: 217 additions & 0 deletions src/ModelContextProtocol/Protocol/Transport/TcpServerTransport.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using ModelContextProtocol.Logging;
using ModelContextProtocol.Protocol.Messages;
using ModelContextProtocol.Server;
using ModelContextProtocol.Utils;
using ModelContextProtocol.Utils.Json;
using System.Net;
using System.Net.Sockets;
using System.Text;
using System.Text.Json;

namespace ModelContextProtocol.Protocol.Transport;

/// <summary>
/// Provides a server MCP transport implemented around a TCP connection.
/// </summary>
public class TcpServerTransport : TransportBase, ITransport
{
private static readonly byte[] s_newlineBytes = "\n"u8.ToArray();

private readonly ILogger _logger;
private readonly TcpListener _tcpListener;
private readonly string _endpointName;

private readonly SemaphoreSlim _sendLock = new(1, 1);
private readonly CancellationTokenSource _shutdownCts = new();

private readonly Task _readLoopCompleted;
private NetworkStream? _networkStream;
private StreamReader? _inputReader;
private Stream? _outputStream;
private int _disposed = 0;

/// <summary>
/// Initializes a new instance of the <see cref="TcpServerTransport"/> class.
/// </summary>
/// <param name="options">Configuration options for the transport.</param>
/// <param name="loggerFactory">Optional logger factory used for logging employed by the transport.</param>
public TcpServerTransport(McpServerTcpTransportOptions options, ILoggerFactory? loggerFactory = null)
: base(loggerFactory)
{
_logger = loggerFactory?.CreateLogger(GetType()) ?? NullLogger.Instance;

_tcpListener = new TcpListener(options.IpAddress, options.Port);
_tcpListener.Start();

_endpointName = $"Server (TCP) ({options.IpAddress})";
_readLoopCompleted = Task.Run(AcceptAndReadMessagesAsync, _shutdownCts.Token);
}

/// <inheritdoc/>
public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
{
if (!IsConnected)
{
_logger.TransportNotConnected(_endpointName);
throw new McpTransportException("Transport is not connected");
}

using var _ = await _sendLock.LockAsync(cancellationToken).ConfigureAwait(false);

string id = "(no id)";
if (message is IJsonRpcMessageWithId messageWithId)
{
id = messageWithId.Id.ToString();
}

try
{
_logger.TransportSendingMessage(_endpointName, id);

await JsonSerializer.SerializeAsync(_outputStream!, message, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IJsonRpcMessage)), cancellationToken).ConfigureAwait(false);
await _outputStream!.WriteAsync(s_newlineBytes, cancellationToken).ConfigureAwait(false);
await _outputStream!.FlushAsync(cancellationToken).ConfigureAwait(false);

_logger.TransportSentMessage(_endpointName, id);
}
catch (Exception ex)
{
_logger.TransportSendFailed(_endpointName, id, ex);
throw new McpTransportException("Failed to send message", ex);
}
}

private async Task AcceptAndReadMessagesAsync()
{
CancellationToken shutdownToken = _shutdownCts.Token;
try
{
_logger.TransportEnteringReadMessagesLoop(_endpointName);

while (!shutdownToken.IsCancellationRequested)
{
_logger.TransportReadingMessages(_endpointName);

var client = await _tcpListener.AcceptTcpClientAsync().ConfigureAwait(false);
_networkStream = client.GetStream();
_inputReader = new StreamReader(_networkStream, Encoding.UTF8);
_outputStream = _networkStream;

SetConnected(true);
_logger.TransportAlreadyConnected(_endpointName);

while (!shutdownToken.IsCancellationRequested)
{
var line = await _inputReader.ReadLineAsync(shutdownToken).ConfigureAwait(false);
if (string.IsNullOrWhiteSpace(line))
{
if (line is null)
{
_logger.TransportEndOfStream(_endpointName);
break;
}

continue;
}

_logger.TransportReceivedMessage(_endpointName, line);

try
{
if (JsonSerializer.Deserialize(line, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IJsonRpcMessage))) is IJsonRpcMessage message)
{
string messageId = "(no id)";
if (message is IJsonRpcMessageWithId messageWithId)
{
messageId = messageWithId.Id.ToString();
}
_logger.TransportReceivedMessageParsed(_endpointName, messageId);

await WriteMessageAsync(message, shutdownToken).ConfigureAwait(false);
_logger.TransportMessageWritten(_endpointName, messageId);
}
else
{
_logger.TransportMessageParseUnexpectedType(_endpointName, line);
}
}
catch (JsonException ex)
{
_logger.TransportMessageParseFailed(_endpointName, line, ex);
// Continue reading even if we fail to parse a message
}
}

client.Close();
SetConnected(false);
}

_logger.TransportExitingReadMessagesLoop(_endpointName);
}
catch (OperationCanceledException)
{
_logger.TransportReadMessagesCancelled(_endpointName);
}
catch (Exception ex)
{
_logger.TransportReadMessagesFailed(_endpointName, ex);
}
finally
{
SetConnected(false);
}
}

/// <inheritdoc />
public override async ValueTask DisposeAsync()
{
if (Interlocked.Exchange(ref _disposed, 1) != 0)
{
return;
}

try
{
_logger.TransportCleaningUp(_endpointName);

// Signal to the read loop to stop.
await _shutdownCts.CancelAsync().ConfigureAwait(false);
_shutdownCts.Dispose();

// Dispose of network resources.
_inputReader?.Dispose();
_outputStream?.Dispose();
_networkStream?.Dispose();
_tcpListener.Stop();

// Make sure the work has quiesced.
try
{
_logger.TransportWaitingForReadTask(_endpointName);
await _readLoopCompleted.ConfigureAwait(false);
_logger.TransportReadTaskCleanedUp(_endpointName);
}
catch (TimeoutException)
{
_logger.TransportCleanupReadTaskTimeout(_endpointName);
}
catch (OperationCanceledException)
{
_logger.TransportCleanupReadTaskCancelled(_endpointName);
}
catch (Exception ex)
{
_logger.TransportCleanupReadTaskFailed(_endpointName, ex);
}
}
finally
{
SetConnected(false);
_logger.TransportCleanedUp(_endpointName);
}

GC.SuppressFinalize(this);
}
}
19 changes: 19 additions & 0 deletions src/ModelContextProtocol/Server/McpServerTcpTransportOptions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using System.Net;

namespace ModelContextProtocol.Server;

/// <summary>
/// Configuration options for the TcpServerTransport.
/// </summary>
public class McpServerTcpTransportOptions
{
/// <summary>
/// The TCP port to listen on.
/// </summary>
public required int Port { get; set; } = 60606;

/// <summary>
/// The TCP host to listen on. This is typically the IP address of the server. If not specified, the server will listen on all available network interfaces.
/// </summary>
public required IPAddress IpAddress { get; set; } = IPAddress.Any;
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using ModelContextProtocol.Protocol.Transport;
using Microsoft.Extensions.DependencyInjection;
using Moq;
using System.Net;

namespace ModelContextProtocol.Tests.Configuration;

Expand All @@ -19,4 +20,34 @@ public void WithStdioServerTransport_Sets_Transport()
Assert.NotNull(transportType);
Assert.Equal(typeof(StdioServerTransport), transportType.ImplementationType);
}
[Fact]
public void WithTcpServerTransport_Sets_Transport()
{
var services = new ServiceCollection();
var builder = new Mock<IMcpServerBuilder>();
builder.SetupGet(b => b.Services).Returns(services);

builder.Object.WithTcpServerTransport();

var transportType = services.FirstOrDefault(s => s.ServiceType == typeof(ITransport));
Assert.NotNull(transportType);
Assert.Equal(typeof(TcpServerTransport), transportType.ImplementationType);
}
[Fact]
public void WithTcpServerTransport_Sets_Transport_Options()
{
var services = new ServiceCollection();
var builder = new Mock<IMcpServerBuilder>();
builder.SetupGet(b => b.Services).Returns(services);

builder.Object.WithTcpServerTransport(options =>
{
options.Port = 12345;
options.IpAddress = IPAddress.Parse("127.0.0.1");
});

var transportType = services.FirstOrDefault(s => s.ServiceType == typeof(ITransport));
Assert.NotNull(transportType);
Assert.Equal(typeof(TcpServerTransport), transportType.ImplementationType);
}
}