Skip to content

Make supposedly unreachable code less reachable #178

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@ public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpo
{
throw new Exception($"Unreachable given good entropy! Session with ID '{sessionId}' has already been created.");
}
await using var server = McpServerFactory.Create(transport, mcpServerOptions.Value, loggerFactory, endpoints.ServiceProvider);

try
{
var transportTask = transport.RunAsync(cancellationToken: requestAborted);
await using var server = McpServerFactory.Create(transport, mcpServerOptions.Value, loggerFactory, endpoints.ServiceProvider);

try
{
Expand Down Expand Up @@ -85,7 +85,7 @@ public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpo

if (!_sessions.TryGetValue(sessionId.ToString(), out var transport))
{
await Results.BadRequest($"Session {sessionId} not found.").ExecuteAsync(context);
await Results.BadRequest($"Session ID not found.").ExecuteAsync(context);
return;
}

Expand Down
32 changes: 15 additions & 17 deletions src/ModelContextProtocol/Client/McpClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -82,29 +82,27 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default)
{
// Connect transport
_sessionTransport = await _clientTransport.ConnectAsync(cancellationToken).ConfigureAwait(false);
// We don't want the ConnectAsync token to cancel the session after we've successfully connected.
// The base class handles cleaning up the session in DisposeAsync without our help.
StartSession(_sessionTransport, fullSessionCancellationToken: CancellationToken.None);
StartSession(_sessionTransport);

// Perform initialization sequence
using var initializationCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
initializationCts.CancelAfter(_options.InitializationTimeout);

try
{
// Send initialize request
var initializeResponse = await SendRequestAsync<InitializeResult>(
new JsonRpcRequest
{
Method = RequestMethods.Initialize,
Params = new InitializeRequestParams()
try
{
// Send initialize request
var initializeResponse = await SendRequestAsync<InitializeResult>(
new JsonRpcRequest
{
ProtocolVersion = _options.ProtocolVersion,
Capabilities = _options.Capabilities ?? new ClientCapabilities(),
ClientInfo = _options.ClientInfo
}
},
initializationCts.Token).ConfigureAwait(false);
Method = RequestMethods.Initialize,
Params = new InitializeRequestParams()
{
ProtocolVersion = _options.ProtocolVersion,
Capabilities = _options.Capabilities ?? new ClientCapabilities(),
ClientInfo = _options.ClientInfo
}
},
initializationCts.Token).ConfigureAwait(false);

// Store server information
_logger.ServerCapabilitiesReceived(EndpointName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,30 +32,30 @@ public sealed class SseResponseStreamTransport(Stream sseResponseStream, string
/// <returns>A task representing the send loop that writes JSON-RPC messages to the SSE response stream.</returns>
public Task RunAsync(CancellationToken cancellationToken)
{
void WriteJsonRpcMessageToBuffer(SseItem<IJsonRpcMessage?> item, IBufferWriter<byte> writer)
{
if (item.EventType == "endpoint")
{
writer.Write(Encoding.UTF8.GetBytes(messageEndpoint));
return;
}

JsonSerializer.Serialize(GetUtf8JsonWriter(writer), item.Data, McpJsonUtilities.DefaultOptions.GetTypeInfo<IJsonRpcMessage?>());
}

IsConnected = true;

// The very first SSE event isn't really an IJsonRpcMessage, but there's no API to write a single item of a different type,
// so we fib and special-case the "endpoint" event type in the formatter.
if (!_outgoingSseChannel.Writer.TryWrite(new SseItem<IJsonRpcMessage?>(null, "endpoint")))
{
throw new InvalidOperationException($"You must call ${nameof(RunAsync)} before calling ${nameof(SendMessageAsync)}.");
}

IsConnected = true;

var sseItems = _outgoingSseChannel.Reader.ReadAllAsync(cancellationToken);
return _sseWriteTask = SseFormatter.WriteAsync(sseItems, sseResponseStream, WriteJsonRpcMessageToBuffer, cancellationToken);
}

private void WriteJsonRpcMessageToBuffer(SseItem<IJsonRpcMessage?> item, IBufferWriter<byte> writer)
{
if (item.EventType == "endpoint")
{
writer.Write(Encoding.UTF8.GetBytes(messageEndpoint));
return;
}

JsonSerializer.Serialize(GetUtf8JsonWriter(writer), item.Data, McpJsonUtilities.DefaultOptions.GetTypeInfo<IJsonRpcMessage?>());
}

/// <inheritdoc/>
public ChannelReader<IJsonRpcMessage> MessageReader => _incomingChannel.Reader;

Expand Down
16 changes: 11 additions & 5 deletions src/ModelContextProtocol/Server/McpServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ internal sealed class McpServer : McpJsonRpcEndpoint, IMcpServer
private readonly EventHandler? _toolsChangedDelegate;
private readonly EventHandler? _promptsChangedDelegate;

private ITransport _sessionTransport;
private string _endpointName;
private int _started;

/// <summary>
/// Creates a new instance of <see cref="McpServer"/>.
Expand All @@ -32,7 +32,6 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory?
Throw.IfNull(transport);
Throw.IfNull(options);

_sessionTransport = transport;
ServerOptions = options;
Services = serviceProvider;
_endpointName = $"Server ({options.ServerInfo.Name} {options.ServerInfo.Version})";
Expand Down Expand Up @@ -74,6 +73,8 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory?
SetPromptsHandler(options);
SetResourcesHandler(options);
SetSetLoggingLevelHandler(options);

StartSession(transport);
}

public ServerCapabilities? ServerCapabilities { get; set; }
Expand All @@ -96,11 +97,16 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory?
/// <inheritdoc />
public async Task RunAsync(CancellationToken cancellationToken = default)
{
if (Interlocked.Exchange(ref _started, 1) != 0)
{
throw new InvalidOperationException($"{nameof(RunAsync)} must only be called once.");
}

try
{
// Start processing messages
StartSession(_sessionTransport, fullSessionCancellationToken: cancellationToken);
await MessageProcessingTask.ConfigureAwait(false);
using var _ = cancellationToken.Register(static s => ((McpServer)s!).CancelSession(), this);
// The McpServer ctor always calls StartSession, so MessageProcessingTask is always set.
await MessageProcessingTask!.ConfigureAwait(false);
}
finally
{
Expand Down
12 changes: 4 additions & 8 deletions src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ internal abstract class McpJsonRpcEndpoint : IAsyncDisposable

private McpSession? _session;
private CancellationTokenSource? _sessionCts;
private int _started;

private readonly SemaphoreSlim _disposeLock = new(1, 1);
private bool _disposed;
Expand Down Expand Up @@ -61,18 +60,15 @@ public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancella
protected Task? MessageProcessingTask { get; set; }

[MemberNotNull(nameof(MessageProcessingTask))]
protected void StartSession(ITransport sessionTransport, CancellationToken fullSessionCancellationToken = default)
protected void StartSession(ITransport sessionTransport)
{
if (Interlocked.Exchange(ref _started, 1) != 0)
{
throw new InvalidOperationException("The MCP session has already stared.");
}

_sessionCts = CancellationTokenSource.CreateLinkedTokenSource(fullSessionCancellationToken);
_sessionCts = new CancellationTokenSource();
_session = new McpSession(sessionTransport, EndpointName, _requestHandlers, _notificationHandlers, _logger);
MessageProcessingTask = _session.ProcessMessagesAsync(_sessionCts.Token);
}

protected void CancelSession() => _sessionCts?.Cancel();

public async ValueTask DisposeAsync()
{
using var _ = await _disposeLock.LockAsync().ConfigureAwait(false);
Expand Down
12 changes: 6 additions & 6 deletions tests/ModelContextProtocol.Tests/Server/McpServerFactoryTests.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
using ModelContextProtocol.Protocol.Transport;
using ModelContextProtocol.Protocol.Types;
using ModelContextProtocol.Protocol.Types;
using ModelContextProtocol.Server;
using ModelContextProtocol.Tests.Utils;
using Moq;

namespace ModelContextProtocol.Tests.Server;

Expand All @@ -25,7 +23,8 @@ public McpServerFactoryTests(ITestOutputHelper testOutputHelper)
public async Task Create_Should_Initialize_With_Valid_Parameters()
{
// Arrange & Act
await using IMcpServer server = McpServerFactory.Create(Mock.Of<ITransport>(), _options, LoggerFactory);
await using var transport = new TestServerTransport();
await using IMcpServer server = McpServerFactory.Create(transport, _options, LoggerFactory);

// Assert
Assert.NotNull(server);
Expand All @@ -39,9 +38,10 @@ public void Create_Throws_For_Null_ServerTransport()
}

[Fact]
public void Create_Throws_For_Null_Options()
public async Task Create_Throws_For_Null_Options()
{
// Arrange, Act & Assert
Assert.Throws<ArgumentNullException>("serverOptions", () => McpServerFactory.Create(Mock.Of<ITransport>(), null!, LoggerFactory));
await using var transport = new TestServerTransport();
Assert.Throws<ArgumentNullException>("serverOptions", () => McpServerFactory.Create(transport, null!, LoggerFactory));
}
}
Loading