diff --git a/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.cs b/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.cs index a59de8ce..85357c7b 100644 --- a/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.cs +++ b/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.cs @@ -350,5 +350,37 @@ public static IMcpServerBuilder WithStdioServerTransport(this IMcpServerBuilder return builder; } + /// + /// Adds a server transport that uses TCP connection for communication. + /// + /// The builder instance. + /// The options to configure the transport. + public static IMcpServerBuilder WithTcpServerTransport(this IMcpServerBuilder builder, Action? configureOptions = null) + { + Throw.IfNull(builder); + + if (configureOptions != null) + { + builder.Services.Configure(configureOptions); + } + + builder.Services.AddSingleton(services => + { + McpServerTcpTransportOptions transportOptions = services.GetRequiredService>().Value; + return new TcpServerTransport(transportOptions); + }); + builder.Services.AddHostedService(); + + builder.Services.AddSingleton(services => + { + ITransport serverTransport = services.GetRequiredService(); + IOptions options = services.GetRequiredService>(); + ILoggerFactory? loggerFactory = services.GetService(); + + return McpServerFactory.Create(serverTransport, options.Value, loggerFactory, services); + }); + + return builder; + } #endregion } diff --git a/src/ModelContextProtocol/Hosting/TcpMcpServerHostedService.cs b/src/ModelContextProtocol/Hosting/TcpMcpServerHostedService.cs new file mode 100644 index 00000000..22c62c33 --- /dev/null +++ b/src/ModelContextProtocol/Hosting/TcpMcpServerHostedService.cs @@ -0,0 +1,13 @@ +using Microsoft.Extensions.Hosting; +using ModelContextProtocol.Server; + +namespace ModelContextProtocol.Hosting; + +/// +/// Hosted service for a single-session (i.e TCP) MCP server. +/// +internal sealed class TcpMcpServerHostedService(IMcpServer session) : BackgroundService +{ + /// + protected override Task ExecuteAsync(CancellationToken stoppingToken) => session.RunAsync(stoppingToken); +} diff --git a/src/ModelContextProtocol/Protocol/Transport/TcpServerTransport.cs b/src/ModelContextProtocol/Protocol/Transport/TcpServerTransport.cs new file mode 100644 index 00000000..7cd78521 --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Transport/TcpServerTransport.cs @@ -0,0 +1,217 @@ +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using ModelContextProtocol.Logging; +using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Server; +using ModelContextProtocol.Utils; +using ModelContextProtocol.Utils.Json; +using System.Net; +using System.Net.Sockets; +using System.Text; +using System.Text.Json; + +namespace ModelContextProtocol.Protocol.Transport; + +/// +/// Provides a server MCP transport implemented around a TCP connection. +/// +public class TcpServerTransport : TransportBase, ITransport +{ + private static readonly byte[] s_newlineBytes = "\n"u8.ToArray(); + + private readonly ILogger _logger; + private readonly TcpListener _tcpListener; + private readonly string _endpointName; + + private readonly SemaphoreSlim _sendLock = new(1, 1); + private readonly CancellationTokenSource _shutdownCts = new(); + + private readonly Task _readLoopCompleted; + private NetworkStream? _networkStream; + private StreamReader? _inputReader; + private Stream? _outputStream; + private int _disposed = 0; + + /// + /// Initializes a new instance of the class. + /// + /// Configuration options for the transport. + /// Optional logger factory used for logging employed by the transport. + public TcpServerTransport(McpServerTcpTransportOptions options, ILoggerFactory? loggerFactory = null) + : base(loggerFactory) + { + _logger = loggerFactory?.CreateLogger(GetType()) ?? NullLogger.Instance; + + _tcpListener = new TcpListener(options.IpAddress, options.Port); + _tcpListener.Start(); + + _endpointName = $"Server (TCP) ({options.IpAddress})"; + _readLoopCompleted = Task.Run(AcceptAndReadMessagesAsync, _shutdownCts.Token); + } + + /// + public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) + { + if (!IsConnected) + { + _logger.TransportNotConnected(_endpointName); + throw new McpTransportException("Transport is not connected"); + } + + using var _ = await _sendLock.LockAsync(cancellationToken).ConfigureAwait(false); + + string id = "(no id)"; + if (message is IJsonRpcMessageWithId messageWithId) + { + id = messageWithId.Id.ToString(); + } + + try + { + _logger.TransportSendingMessage(_endpointName, id); + + await JsonSerializer.SerializeAsync(_outputStream!, message, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IJsonRpcMessage)), cancellationToken).ConfigureAwait(false); + await _outputStream!.WriteAsync(s_newlineBytes, cancellationToken).ConfigureAwait(false); + await _outputStream!.FlushAsync(cancellationToken).ConfigureAwait(false); + + _logger.TransportSentMessage(_endpointName, id); + } + catch (Exception ex) + { + _logger.TransportSendFailed(_endpointName, id, ex); + throw new McpTransportException("Failed to send message", ex); + } + } + + private async Task AcceptAndReadMessagesAsync() + { + CancellationToken shutdownToken = _shutdownCts.Token; + try + { + _logger.TransportEnteringReadMessagesLoop(_endpointName); + + while (!shutdownToken.IsCancellationRequested) + { + _logger.TransportReadingMessages(_endpointName); + + var client = await _tcpListener.AcceptTcpClientAsync().ConfigureAwait(false); + _networkStream = client.GetStream(); + _inputReader = new StreamReader(_networkStream, Encoding.UTF8); + _outputStream = _networkStream; + + SetConnected(true); + _logger.TransportAlreadyConnected(_endpointName); + + while (!shutdownToken.IsCancellationRequested) + { + var line = await _inputReader.ReadLineAsync(shutdownToken).ConfigureAwait(false); + if (string.IsNullOrWhiteSpace(line)) + { + if (line is null) + { + _logger.TransportEndOfStream(_endpointName); + break; + } + + continue; + } + + _logger.TransportReceivedMessage(_endpointName, line); + + try + { + if (JsonSerializer.Deserialize(line, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IJsonRpcMessage))) is IJsonRpcMessage message) + { + string messageId = "(no id)"; + if (message is IJsonRpcMessageWithId messageWithId) + { + messageId = messageWithId.Id.ToString(); + } + _logger.TransportReceivedMessageParsed(_endpointName, messageId); + + await WriteMessageAsync(message, shutdownToken).ConfigureAwait(false); + _logger.TransportMessageWritten(_endpointName, messageId); + } + else + { + _logger.TransportMessageParseUnexpectedType(_endpointName, line); + } + } + catch (JsonException ex) + { + _logger.TransportMessageParseFailed(_endpointName, line, ex); + // Continue reading even if we fail to parse a message + } + } + + client.Close(); + SetConnected(false); + } + + _logger.TransportExitingReadMessagesLoop(_endpointName); + } + catch (OperationCanceledException) + { + _logger.TransportReadMessagesCancelled(_endpointName); + } + catch (Exception ex) + { + _logger.TransportReadMessagesFailed(_endpointName, ex); + } + finally + { + SetConnected(false); + } + } + + /// + public override async ValueTask DisposeAsync() + { + if (Interlocked.Exchange(ref _disposed, 1) != 0) + { + return; + } + + try + { + _logger.TransportCleaningUp(_endpointName); + + // Signal to the read loop to stop. + await _shutdownCts.CancelAsync().ConfigureAwait(false); + _shutdownCts.Dispose(); + + // Dispose of network resources. + _inputReader?.Dispose(); + _outputStream?.Dispose(); + _networkStream?.Dispose(); + _tcpListener.Stop(); + + // Make sure the work has quiesced. + try + { + _logger.TransportWaitingForReadTask(_endpointName); + await _readLoopCompleted.ConfigureAwait(false); + _logger.TransportReadTaskCleanedUp(_endpointName); + } + catch (TimeoutException) + { + _logger.TransportCleanupReadTaskTimeout(_endpointName); + } + catch (OperationCanceledException) + { + _logger.TransportCleanupReadTaskCancelled(_endpointName); + } + catch (Exception ex) + { + _logger.TransportCleanupReadTaskFailed(_endpointName, ex); + } + } + finally + { + SetConnected(false); + _logger.TransportCleanedUp(_endpointName); + } + + GC.SuppressFinalize(this); + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Server/McpServerTcpTransportOptions.cs b/src/ModelContextProtocol/Server/McpServerTcpTransportOptions.cs new file mode 100644 index 00000000..a3b94935 --- /dev/null +++ b/src/ModelContextProtocol/Server/McpServerTcpTransportOptions.cs @@ -0,0 +1,19 @@ +using System.Net; + +namespace ModelContextProtocol.Server; + +/// +/// Configuration options for the TcpServerTransport. +/// +public class McpServerTcpTransportOptions +{ + /// + /// The TCP port to listen on. + /// + public required int Port { get; set; } = 60606; + + /// + /// The TCP host to listen on. This is typically the IP address of the server. If not specified, the server will listen on all available network interfaces. + /// + public required IPAddress IpAddress { get; set; } = IPAddress.Any; +} diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsTransportsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsTransportsTests.cs index 93d546b0..0d738aa3 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsTransportsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsTransportsTests.cs @@ -1,6 +1,7 @@ using ModelContextProtocol.Protocol.Transport; using Microsoft.Extensions.DependencyInjection; using Moq; +using System.Net; namespace ModelContextProtocol.Tests.Configuration; @@ -19,4 +20,34 @@ public void WithStdioServerTransport_Sets_Transport() Assert.NotNull(transportType); Assert.Equal(typeof(StdioServerTransport), transportType.ImplementationType); } + [Fact] + public void WithTcpServerTransport_Sets_Transport() + { + var services = new ServiceCollection(); + var builder = new Mock(); + builder.SetupGet(b => b.Services).Returns(services); + + builder.Object.WithTcpServerTransport(); + + var transportType = services.FirstOrDefault(s => s.ServiceType == typeof(ITransport)); + Assert.NotNull(transportType); + Assert.Equal(typeof(TcpServerTransport), transportType.ImplementationType); + } + [Fact] + public void WithTcpServerTransport_Sets_Transport_Options() + { + var services = new ServiceCollection(); + var builder = new Mock(); + builder.SetupGet(b => b.Services).Returns(services); + + builder.Object.WithTcpServerTransport(options => + { + options.Port = 12345; + options.IpAddress = IPAddress.Parse("127.0.0.1"); + }); + + var transportType = services.FirstOrDefault(s => s.ServiceType == typeof(ITransport)); + Assert.NotNull(transportType); + Assert.Equal(typeof(TcpServerTransport), transportType.ImplementationType); + } }