diff --git a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs index 96259baf..bb96356b 100644 --- a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs +++ b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs @@ -47,11 +47,11 @@ public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpo { throw new Exception($"Unreachable given good entropy! Session with ID '{sessionId}' has already been created."); } - await using var server = McpServerFactory.Create(transport, mcpServerOptions.Value, loggerFactory, endpoints.ServiceProvider); try { var transportTask = transport.RunAsync(cancellationToken: requestAborted); + await using var server = McpServerFactory.Create(transport, mcpServerOptions.Value, loggerFactory, endpoints.ServiceProvider); try { @@ -85,7 +85,7 @@ public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpo if (!_sessions.TryGetValue(sessionId.ToString(), out var transport)) { - await Results.BadRequest($"Session {sessionId} not found.").ExecuteAsync(context); + await Results.BadRequest($"Session ID not found.").ExecuteAsync(context); return; } diff --git a/src/ModelContextProtocol/Client/McpClient.cs b/src/ModelContextProtocol/Client/McpClient.cs index 773301c0..2b35e515 100644 --- a/src/ModelContextProtocol/Client/McpClient.cs +++ b/src/ModelContextProtocol/Client/McpClient.cs @@ -82,29 +82,27 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default) { // Connect transport _sessionTransport = await _clientTransport.ConnectAsync(cancellationToken).ConfigureAwait(false); - // We don't want the ConnectAsync token to cancel the session after we've successfully connected. - // The base class handles cleaning up the session in DisposeAsync without our help. - StartSession(_sessionTransport, fullSessionCancellationToken: CancellationToken.None); + StartSession(_sessionTransport); // Perform initialization sequence using var initializationCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); initializationCts.CancelAfter(_options.InitializationTimeout); - try - { - // Send initialize request - var initializeResponse = await SendRequestAsync( - new JsonRpcRequest - { - Method = RequestMethods.Initialize, - Params = new InitializeRequestParams() + try + { + // Send initialize request + var initializeResponse = await SendRequestAsync( + new JsonRpcRequest { - ProtocolVersion = _options.ProtocolVersion, - Capabilities = _options.Capabilities ?? new ClientCapabilities(), - ClientInfo = _options.ClientInfo - } - }, - initializationCts.Token).ConfigureAwait(false); + Method = RequestMethods.Initialize, + Params = new InitializeRequestParams() + { + ProtocolVersion = _options.ProtocolVersion, + Capabilities = _options.Capabilities ?? new ClientCapabilities(), + ClientInfo = _options.ClientInfo + } + }, + initializationCts.Token).ConfigureAwait(false); // Store server information _logger.ServerCapabilitiesReceived(EndpointName, diff --git a/src/ModelContextProtocol/Protocol/Transport/SseResponseStreamTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseResponseStreamTransport.cs index eafd1f61..d375aa66 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseResponseStreamTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseResponseStreamTransport.cs @@ -32,19 +32,6 @@ public sealed class SseResponseStreamTransport(Stream sseResponseStream, string /// A task representing the send loop that writes JSON-RPC messages to the SSE response stream. public Task RunAsync(CancellationToken cancellationToken) { - void WriteJsonRpcMessageToBuffer(SseItem item, IBufferWriter writer) - { - if (item.EventType == "endpoint") - { - writer.Write(Encoding.UTF8.GetBytes(messageEndpoint)); - return; - } - - JsonSerializer.Serialize(GetUtf8JsonWriter(writer), item.Data, McpJsonUtilities.DefaultOptions.GetTypeInfo()); - } - - IsConnected = true; - // The very first SSE event isn't really an IJsonRpcMessage, but there's no API to write a single item of a different type, // so we fib and special-case the "endpoint" event type in the formatter. if (!_outgoingSseChannel.Writer.TryWrite(new SseItem(null, "endpoint"))) @@ -52,10 +39,23 @@ void WriteJsonRpcMessageToBuffer(SseItem item, IBufferWriter item, IBufferWriter writer) + { + if (item.EventType == "endpoint") + { + writer.Write(Encoding.UTF8.GetBytes(messageEndpoint)); + return; + } + + JsonSerializer.Serialize(GetUtf8JsonWriter(writer), item.Data, McpJsonUtilities.DefaultOptions.GetTypeInfo()); + } + /// public ChannelReader MessageReader => _incomingChannel.Reader; diff --git a/src/ModelContextProtocol/Server/McpServer.cs b/src/ModelContextProtocol/Server/McpServer.cs index aa35ac33..547c2237 100644 --- a/src/ModelContextProtocol/Server/McpServer.cs +++ b/src/ModelContextProtocol/Server/McpServer.cs @@ -14,8 +14,8 @@ internal sealed class McpServer : McpJsonRpcEndpoint, IMcpServer private readonly EventHandler? _toolsChangedDelegate; private readonly EventHandler? _promptsChangedDelegate; - private ITransport _sessionTransport; private string _endpointName; + private int _started; /// /// Creates a new instance of . @@ -32,7 +32,6 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? Throw.IfNull(transport); Throw.IfNull(options); - _sessionTransport = transport; ServerOptions = options; Services = serviceProvider; _endpointName = $"Server ({options.ServerInfo.Name} {options.ServerInfo.Version})"; @@ -74,6 +73,8 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? SetPromptsHandler(options); SetResourcesHandler(options); SetSetLoggingLevelHandler(options); + + StartSession(transport); } public ServerCapabilities? ServerCapabilities { get; set; } @@ -96,11 +97,16 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? /// public async Task RunAsync(CancellationToken cancellationToken = default) { + if (Interlocked.Exchange(ref _started, 1) != 0) + { + throw new InvalidOperationException($"{nameof(RunAsync)} must only be called once."); + } + try { - // Start processing messages - StartSession(_sessionTransport, fullSessionCancellationToken: cancellationToken); - await MessageProcessingTask.ConfigureAwait(false); + using var _ = cancellationToken.Register(static s => ((McpServer)s!).CancelSession(), this); + // The McpServer ctor always calls StartSession, so MessageProcessingTask is always set. + await MessageProcessingTask!.ConfigureAwait(false); } finally { diff --git a/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs b/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs index 3f7417af..915dfa18 100644 --- a/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs +++ b/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs @@ -22,7 +22,6 @@ internal abstract class McpJsonRpcEndpoint : IAsyncDisposable private McpSession? _session; private CancellationTokenSource? _sessionCts; - private int _started; private readonly SemaphoreSlim _disposeLock = new(1, 1); private bool _disposed; @@ -61,18 +60,15 @@ public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancella protected Task? MessageProcessingTask { get; set; } [MemberNotNull(nameof(MessageProcessingTask))] - protected void StartSession(ITransport sessionTransport, CancellationToken fullSessionCancellationToken = default) + protected void StartSession(ITransport sessionTransport) { - if (Interlocked.Exchange(ref _started, 1) != 0) - { - throw new InvalidOperationException("The MCP session has already stared."); - } - - _sessionCts = CancellationTokenSource.CreateLinkedTokenSource(fullSessionCancellationToken); + _sessionCts = new CancellationTokenSource(); _session = new McpSession(sessionTransport, EndpointName, _requestHandlers, _notificationHandlers, _logger); MessageProcessingTask = _session.ProcessMessagesAsync(_sessionCts.Token); } + protected void CancelSession() => _sessionCts?.Cancel(); + public async ValueTask DisposeAsync() { using var _ = await _disposeLock.LockAsync().ConfigureAwait(false); diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerFactoryTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerFactoryTests.cs index 25c5123e..ae640ecc 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerFactoryTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerFactoryTests.cs @@ -1,8 +1,6 @@ -using ModelContextProtocol.Protocol.Transport; -using ModelContextProtocol.Protocol.Types; +using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server; using ModelContextProtocol.Tests.Utils; -using Moq; namespace ModelContextProtocol.Tests.Server; @@ -25,7 +23,8 @@ public McpServerFactoryTests(ITestOutputHelper testOutputHelper) public async Task Create_Should_Initialize_With_Valid_Parameters() { // Arrange & Act - await using IMcpServer server = McpServerFactory.Create(Mock.Of(), _options, LoggerFactory); + await using var transport = new TestServerTransport(); + await using IMcpServer server = McpServerFactory.Create(transport, _options, LoggerFactory); // Assert Assert.NotNull(server); @@ -39,9 +38,10 @@ public void Create_Throws_For_Null_ServerTransport() } [Fact] - public void Create_Throws_For_Null_Options() + public async Task Create_Throws_For_Null_Options() { // Arrange, Act & Assert - Assert.Throws("serverOptions", () => McpServerFactory.Create(Mock.Of(), null!, LoggerFactory)); + await using var transport = new TestServerTransport(); + Assert.Throws("serverOptions", () => McpServerFactory.Create(transport, null!, LoggerFactory)); } } diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs index bded8e99..97f6b28b 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs @@ -1,27 +1,21 @@ using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using ModelContextProtocol.Protocol.Messages; -using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server; using ModelContextProtocol.Tests.Utils; -using Moq; using System.Reflection; namespace ModelContextProtocol.Tests.Server; public class McpServerTests : LoggedTest { - private readonly Mock _serverTransport; private readonly McpServerOptions _options; - private readonly IServiceProvider _serviceProvider; public McpServerTests(ITestOutputHelper testOutputHelper) : base(testOutputHelper) { - _serverTransport = new Mock(); _options = CreateOptions(); - _serviceProvider = new ServiceCollection().BuildServiceProvider(); } private static McpServerOptions CreateOptions(ServerCapabilities? capabilities = null) @@ -39,7 +33,8 @@ private static McpServerOptions CreateOptions(ServerCapabilities? capabilities = public async Task Constructor_Should_Initialize_With_Valid_Parameters() { // Arrange & Act - await using var server = McpServerFactory.Create(_serverTransport.Object, _options, LoggerFactory, _serviceProvider); + await using var transport = new TestServerTransport(); + await using var server = McpServerFactory.Create(transport, _options, LoggerFactory); // Assert Assert.NotNull(server); @@ -49,21 +44,23 @@ public async Task Constructor_Should_Initialize_With_Valid_Parameters() public void Constructor_Throws_For_Null_Transport() { // Arrange, Act & Assert - Assert.Throws(() => McpServerFactory.Create(null!, _options, LoggerFactory, _serviceProvider)); + Assert.Throws(() => McpServerFactory.Create(null!, _options, LoggerFactory)); } [Fact] - public void Constructor_Throws_For_Null_Options() + public async Task Constructor_Throws_For_Null_Options() { // Arrange, Act & Assert - Assert.Throws(() => McpServerFactory.Create(_serverTransport.Object, null!, LoggerFactory, _serviceProvider)); + await using var transport = new TestServerTransport(); + Assert.Throws(() => McpServerFactory.Create(transport, null!, LoggerFactory)); } [Fact] public async Task Constructor_Does_Not_Throw_For_Null_Logger() { // Arrange & Act - await using var server = McpServerFactory.Create(_serverTransport.Object, _options, null, _serviceProvider); + await using var transport = new TestServerTransport(); + await using var server = McpServerFactory.Create(transport, _options, null); // Assert Assert.NotNull(server); @@ -73,7 +70,8 @@ public async Task Constructor_Does_Not_Throw_For_Null_Logger() public async Task Constructor_Does_Not_Throw_For_Null_ServiceProvider() { // Arrange & Act - await using var server = McpServerFactory.Create(_serverTransport.Object, _options, LoggerFactory, null); + await using var transport = new TestServerTransport(); + await using var server = McpServerFactory.Create(transport, _options, LoggerFactory, null); // Assert Assert.NotNull(server); @@ -83,27 +81,23 @@ public async Task Constructor_Does_Not_Throw_For_Null_ServiceProvider() public async Task RunAsync_Should_Throw_InvalidOperationException_If_Already_Running() { // Arrange - await using var server = McpServerFactory.Create(_serverTransport.Object, _options, LoggerFactory, _serviceProvider); + await using var transport = new TestServerTransport(); + await using var server = McpServerFactory.Create(transport, _options, LoggerFactory); var runTask = server.RunAsync(TestContext.Current.CancellationToken); // Act & Assert await Assert.ThrowsAsync(() => server.RunAsync(TestContext.Current.CancellationToken)); - try - { - await runTask; - } - catch (NullReferenceException) - { - // _serverTransport.Object returns a null MessageReader - } + await transport.DisposeAsync(); + await runTask; } [Fact] public async Task RequestSamplingAsync_Should_Throw_McpServerException_If_Client_Does_Not_Support_Sampling() { // Arrange - await using var server = McpServerFactory.Create(_serverTransport.Object, _options, LoggerFactory, _serviceProvider); + await using var transport = new TestServerTransport(); + await using var server = McpServerFactory.Create(transport, _options, LoggerFactory); SetClientCapabilities(server, new ClientCapabilities()); var action = () => server.RequestSamplingAsync(new CreateMessageRequestParams { Messages = [] }, CancellationToken.None); @@ -117,7 +111,7 @@ public async Task RequestSamplingAsync_Should_SendRequest() { // Arrange await using var transport = new TestServerTransport(); - await using var server = McpServerFactory.Create(transport, _options, LoggerFactory, _serviceProvider); + await using var server = McpServerFactory.Create(transport, _options, LoggerFactory); SetClientCapabilities(server, new ClientCapabilities { Sampling = new SamplingCapability() }); var runTask = server.RunAsync(TestContext.Current.CancellationToken); @@ -138,7 +132,8 @@ public async Task RequestSamplingAsync_Should_SendRequest() public async Task RequestRootsAsync_Should_Throw_McpServerException_If_Client_Does_Not_Support_Roots() { // Arrange - await using var server = McpServerFactory.Create(_serverTransport.Object, _options, LoggerFactory, _serviceProvider); + await using var transport = new TestServerTransport(); + await using var server = McpServerFactory.Create(transport, _options, LoggerFactory); SetClientCapabilities(server, new ClientCapabilities()); // Act & Assert @@ -150,7 +145,7 @@ public async Task RequestRootsAsync_Should_SendRequest() { // Arrange await using var transport = new TestServerTransport(); - await using var server = McpServerFactory.Create(transport, _options, LoggerFactory, _serviceProvider); + await using var server = McpServerFactory.Create(transport, _options, LoggerFactory); SetClientCapabilities(server, new ClientCapabilities { Roots = new RootsCapability() }); var runTask = server.RunAsync(TestContext.Current.CancellationToken); @@ -507,7 +502,7 @@ private async Task Can_Handle_Requests(ServerCapabilities? serverCapabilities, s var options = CreateOptions(serverCapabilities); configureOptions?.Invoke(options); - await using var server = McpServerFactory.Create(transport, options, LoggerFactory, _serviceProvider); + await using var server = McpServerFactory.Create(transport, options, LoggerFactory); var runTask = server.RunAsync(TestContext.Current.CancellationToken); @@ -542,7 +537,7 @@ private async Task Throws_Exception_If_No_Handler_Assigned(ServerCapabilities se await using var transport = new TestServerTransport(); var options = CreateOptions(serverCapabilities); - Assert.Throws(() => McpServerFactory.Create(transport, options, LoggerFactory, _serviceProvider)); + Assert.Throws(() => McpServerFactory.Create(transport, options, LoggerFactory)); } [Fact] @@ -553,7 +548,6 @@ public async Task AsSamplingChatClient_NoSamplingSupport_Throws() Assert.Throws("server", () => server.AsSamplingChatClient()); } - [Fact] public async Task AsSamplingChatClient_HandlesRequestResponse() { @@ -583,6 +577,26 @@ public async Task AsSamplingChatClient_HandlesRequestResponse() Assert.Equal(ChatRole.Assistant, response.Messages[0].Role); } + [Fact] + public async Task Can_SendMessage_Before_RunAsync() + { + await using var transport = new TestServerTransport(); + await using var server = McpServerFactory.Create(transport, _options, LoggerFactory); + + var logNotification = new JsonRpcNotification() + { + Method = NotificationMethods.LoggingMessageNotification + }; + await server.SendMessageAsync(logNotification, TestContext.Current.CancellationToken); + + var runTask = server.RunAsync(TestContext.Current.CancellationToken); + await transport.DisposeAsync(); + await runTask; + + Assert.NotEmpty(transport.SentMessages); + Assert.Same(logNotification, transport.SentMessages[0]); + } + private static void SetClientCapabilities(IMcpServer server, ClientCapabilities capabilities) { PropertyInfo? property = server.GetType().GetProperty("ClientCapabilities", BindingFlags.Public | BindingFlags.Instance); @@ -644,7 +658,7 @@ public async Task NotifyProgress_Should_Be_Handled() var notificationReceived = new TaskCompletionSource(); - var server = McpServerFactory.Create(transport, options, LoggerFactory, _serviceProvider); + var server = McpServerFactory.Create(transport, options, LoggerFactory); server.AddNotificationHandler(NotificationMethods.ProgressNotification, notification => { notificationReceived.SetResult(notification);