diff --git a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs index 0749eed7..53d41cc1 100644 --- a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs @@ -107,12 +107,53 @@ public async Task ConnectAsync_Should_Connect_Successfully() [Fact] public async Task ConnectAsync_Throws_If_Already_Connected() { - await using var transport = new SseClientTransport(_transportOptions, _serverConfig, NullLoggerFactory.Instance); - transport.GetType().BaseType!.GetField("_isConnected", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance)?.SetValue(transport, true); + using var mockHttpHandler = new MockHttpHandler(); + using var httpClient = new HttpClient(mockHttpHandler); + await using var transport = new SseClientTransport(_transportOptions, _serverConfig, httpClient, NullLoggerFactory.Instance); + using var mreConnected = new ManualResetEventSlim(false); + using var mreDone = new ManualResetEventSlim(false); + var callIndex = 0; + + mockHttpHandler.RequestHandler = (request) => + { + switch (callIndex++) + { + case 0: + return Task.FromResult(new HttpResponseMessage + { + StatusCode = HttpStatusCode.OK, + Content = new StringContent("event: endpoint\r\ndata: http://localhost\r\n\r\n") + }); + case 1: + mreConnected.Set(); + mreDone.Wait(); + return Task.FromResult(new HttpResponseMessage + { + StatusCode = HttpStatusCode.OK, + Content = new StringContent("") + }); + default: + return Task.FromResult(new HttpResponseMessage + { + StatusCode = HttpStatusCode.OK, + Content = new StringContent("") + }); + } + }; + var task = Task.Run(async () => + { + await transport.ConnectAsync(TestContext.Current.CancellationToken); + }, TestContext.Current.CancellationToken); + + mreConnected.Wait(TestContext.Current.CancellationToken); + Assert.True(transport.IsConnected); var action = async () => await transport.ConnectAsync(); var exception = await Assert.ThrowsAsync(action); Assert.Equal("Transport is already connected", exception.Message); + mreDone.Set(); + await transport.CloseAsync(); + await task; } [Fact]