From da548f9818f79637273683961ce3584ad747c680 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Wed, 19 Mar 2025 15:08:43 +0100 Subject: [PATCH] Improve client test reliability and execution time This change uses VirtualTimeScheduler and pretends enough time has passed to trigger a timeout on the initialization. Another problem with reliability of the tests was that the used testcontainer for the SSE server does not support multiple clients and the existence of both the global client for the entire suite and some customized local clients in some tests caused responses to be delivered to the other client at some racing situations. Now each test creates a dedicated client and performs cleanup locally. While these tests were improved, two other issues were found and fixed. The first one is that the closeGracefully of DefaultMcpSession was not lazy and would trigger connection disposal before the returned Mono was subscribed. The second one was dealing with closing the StdIo client before the process was started. In such a case there should not be an error but rather a warning and successful completion. --- .../client/AbstractMcpAsyncClientTests.java | 513 +++++++++++------- .../client/AbstractMcpSyncClientTests.java | 363 ++++++++----- .../transport/StdioClientTransport.java | 7 +- .../spec/DefaultMcpSession.java | 6 +- .../client/AbstractMcpAsyncClientTests.java | 462 +++++++++------- .../client/AbstractMcpSyncClientTests.java | 363 ++++++++----- .../client/StdioMcpSyncClientTests.java | 20 +- 7 files changed, 1016 insertions(+), 718 deletions(-) diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index a8a59a63..033139ad 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -6,7 +6,10 @@ import java.time.Duration; import java.util.Map; +import java.util.Objects; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; import java.util.function.Function; import io.modelcontextprotocol.spec.ClientMcpTransport; @@ -44,10 +47,6 @@ */ public abstract class AbstractMcpAsyncClientTests { - private McpAsyncClient mcpAsyncClient; - - protected ClientMcpTransport mcpTransport; - private static final String ECHO_TEST_MESSAGE = "Hello MCP Spring AI!"; abstract protected ClientMcpTransport createMcpTransport(); @@ -66,25 +65,47 @@ protected Duration getInitializationTimeout() { return Duration.ofSeconds(2); } - @BeforeEach - void setUp() { - onStart(); - this.mcpTransport = createMcpTransport(); + McpAsyncClient client(ClientMcpTransport transport) { + return client(transport, Function.identity()); + } + + McpAsyncClient client(ClientMcpTransport transport, Function customizer) { + AtomicReference client = new AtomicReference<>(); assertThatCode(() -> { - mcpAsyncClient = McpClient.async(mcpTransport) + McpClient.AsyncSpec builder = McpClient.async(transport) .requestTimeout(getRequestTimeout()) .initializationTimeout(getInitializationTimeout()) - .capabilities(ClientCapabilities.builder().roots(true).build()) - .build(); + .capabilities(ClientCapabilities.builder().roots(true).build()); + builder = customizer.apply(builder); + client.set(builder.build()); }).doesNotThrowAnyException(); + + return client.get(); + } + + void withClient(ClientMcpTransport transport, Consumer c) { + withClient(transport, Function.identity(), c); + } + + void withClient(ClientMcpTransport transport, Function customizer, + Consumer c) { + var client = client(transport, customizer); + try { + c.accept(client); + } + finally { + StepVerifier.create(client.closeGracefully()).expectComplete().verify(Duration.ofSeconds(10)); + } + } + + @BeforeEach + void setUp() { + onStart(); } @AfterEach void tearDown() { - if (mcpAsyncClient != null) { - StepVerifier.create(mcpAsyncClient.closeGracefully()).verifyComplete(); - } onClose(); } @@ -93,258 +114,323 @@ void testConstructorWithInvalidArguments() { assertThatThrownBy(() -> McpClient.async(null).build()).isInstanceOf(IllegalArgumentException.class) .hasMessage("Transport must not be null"); - assertThatThrownBy(() -> McpClient.async(mcpTransport).requestTimeout(null).build()) + assertThatThrownBy(() -> McpClient.async(createMcpTransport()).requestTimeout(null).build()) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Request timeout must not be null"); } @Test void testListToolsWithoutInitialization() { - StepVerifier.create(mcpAsyncClient.listTools(null)).expectErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing tools"); - }).verify(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.withVirtualTime(() -> mcpAsyncClient.listTools(null)) + .expectSubscription() + .thenAwait(getInitializationTimeout()) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized before listing tools")) + .verify(); + }); } @Test void testListTools() { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listTools(null))) - .consumeNextWith(result -> { - assertThat(result.tools()).isNotNull().isNotEmpty(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listTools(null))) + .consumeNextWith(result -> { + assertThat(result.tools()).isNotNull().isNotEmpty(); - Tool firstTool = result.tools().get(0); - assertThat(firstTool.name()).isNotNull(); - assertThat(firstTool.description()).isNotNull(); - }) - .verifyComplete(); + Tool firstTool = result.tools().get(0); + assertThat(firstTool.name()).isNotNull(); + assertThat(firstTool.description()).isNotNull(); + }) + .verifyComplete(); + }); } @Test void testPingWithoutInitialization() { - StepVerifier.create(mcpAsyncClient.ping()) - .expectErrorMatches(error -> error instanceof McpError - && error.getMessage().equals("Client must be initialized before pinging the server")) - .verify(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.withVirtualTime(() -> mcpAsyncClient.ping()) + .expectSubscription() + .thenAwait(getInitializationTimeout()) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized before pinging the " + "server")) + .verify(); + }); } @Test void testPing() { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.ping())).consumeNextWith(callToolResult -> { - }).verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.ping())) + .expectNextCount(1) + .verifyComplete(); + }); } @Test void testCallToolWithoutInitialization() { - CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); + withClient(createMcpTransport(), mcpAsyncClient -> { + CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); - StepVerifier.create(mcpAsyncClient.callTool(callToolRequest)) - .expectErrorMatches(error -> error instanceof McpError - && error.getMessage().equals("Client must be initialized before calling tools")) - .verify(); + StepVerifier.withVirtualTime(() -> mcpAsyncClient.callTool(callToolRequest)) + .expectSubscription() + .thenAwait(getInitializationTimeout()) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized before calling tools")) + .verify(); + }); } @Test void testCallTool() { - CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); + withClient(createMcpTransport(), mcpAsyncClient -> { + CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(callToolRequest))) - .consumeNextWith(callToolResult -> { - assertThat(callToolResult).isNotNull(); - assertThat(callToolResult.content()).isNotNull(); - assertThat(callToolResult.isError()).isNull(); - }) - .verifyComplete(); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(callToolRequest))) + .consumeNextWith(callToolResult -> { + assertThat(callToolResult).isNotNull().satisfies(result -> { + assertThat(result.content()).isNotNull(); + assertThat(result.isError()).isNull(); + }); + }) + .verifyComplete(); + }); } @Test void testCallToolWithInvalidTool() { - CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", ECHO_TEST_MESSAGE)); + withClient(createMcpTransport(), mcpAsyncClient -> { + CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", + Map.of("message", ECHO_TEST_MESSAGE)); - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(invalidRequest))) - .expectError(Exception.class) - .verify(); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(invalidRequest))) + .consumeErrorWith( + e -> assertThat(e).isInstanceOf(McpError.class).hasMessage("Unknown tool: nonexistent_tool")) + .verify(); + }); } @Test void testListResourcesWithoutInitialization() { - StepVerifier.create(mcpAsyncClient.listResources(null)) - .expectErrorMatches(error -> error instanceof McpError - && error.getMessage().equals("Client must be initialized before listing resources")) - .verify(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.withVirtualTime(() -> mcpAsyncClient.listResources(null)) + .expectSubscription() + .thenAwait(getInitializationTimeout()) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized before listing resources")) + .verify(); + }); } @Test void testListResources() { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResources(null))) - .consumeNextWith(resources -> { - assertThat(resources).isNotNull().satisfies(result -> { - assertThat(result.resources()).isNotNull(); - - if (!result.resources().isEmpty()) { - Resource firstResource = result.resources().get(0); - assertThat(firstResource.uri()).isNotNull(); - assertThat(firstResource.name()).isNotNull(); - } - }); - }) - .verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResources(null))) + .consumeNextWith(resources -> { + assertThat(resources).isNotNull().satisfies(result -> { + assertThat(result.resources()).isNotNull(); + + if (!result.resources().isEmpty()) { + Resource firstResource = result.resources().get(0); + assertThat(firstResource.uri()).isNotNull(); + assertThat(firstResource.name()).isNotNull(); + } + }); + }) + .verifyComplete(); + }); } @Test void testMcpAsyncClientState() { - assertThat(mcpAsyncClient).isNotNull(); + withClient(createMcpTransport(), mcpAsyncClient -> { + assertThat(mcpAsyncClient).isNotNull(); + }); } @Test void testListPromptsWithoutInitialization() { - StepVerifier.create(mcpAsyncClient.listPrompts(null)) - .expectErrorMatches(error -> error instanceof McpError - && error.getMessage().equals("Client must be initialized before listing prompts")) - .verify(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.withVirtualTime(() -> mcpAsyncClient.listPrompts(null)) + .expectSubscription() + .thenAwait(getInitializationTimeout()) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized before listing prompts")) + .verify(); + }); } @Test void testListPrompts() { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listPrompts(null))) - .consumeNextWith(prompts -> { - assertThat(prompts).isNotNull().satisfies(result -> { - assertThat(result.prompts()).isNotNull(); - - if (!result.prompts().isEmpty()) { - Prompt firstPrompt = result.prompts().get(0); - assertThat(firstPrompt.name()).isNotNull(); - assertThat(firstPrompt.description()).isNotNull(); - } - }); - }) - .verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listPrompts(null))) + .consumeNextWith(prompts -> { + assertThat(prompts).isNotNull().satisfies(result -> { + assertThat(result.prompts()).isNotNull(); + + if (!result.prompts().isEmpty()) { + Prompt firstPrompt = result.prompts().get(0); + assertThat(firstPrompt.name()).isNotNull(); + assertThat(firstPrompt.description()).isNotNull(); + } + }); + }) + .verifyComplete(); + }); } @Test void testGetPromptWithoutInitialization() { - GetPromptRequest request = new GetPromptRequest("simple_prompt", Map.of()); + withClient(createMcpTransport(), mcpAsyncClient -> { + GetPromptRequest request = new GetPromptRequest("simple_prompt", Map.of()); - StepVerifier.create(mcpAsyncClient.getPrompt(request)) - .expectErrorMatches(error -> error instanceof McpError - && error.getMessage().equals("Client must be initialized before getting prompts")) - .verify(); + StepVerifier.withVirtualTime(() -> mcpAsyncClient.getPrompt(request)) + .expectSubscription() + .thenAwait(getInitializationTimeout()) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized before getting prompts")) + .verify(); + }); } @Test void testGetPrompt() { - GetPromptRequest request = new GetPromptRequest("simple_prompt", Map.of()); - - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.getPrompt(request))) - .consumeNextWith(prompt -> { - assertThat(prompt).isNotNull().satisfies(result -> { - assertThat(result.messages()).isNotEmpty(); - assertThat(result.messages()).hasSize(1); - }); - }) - .verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier + .create(mcpAsyncClient.initialize() + .then(mcpAsyncClient.getPrompt(new GetPromptRequest("simple_prompt", Map.of())))) + .consumeNextWith(prompt -> { + assertThat(prompt).isNotNull().satisfies(result -> { + assertThat(result.messages()).isNotEmpty(); + assertThat(result.messages()).hasSize(1); + }); + }) + .verifyComplete(); + }); } @Test void testRootsListChangedWithoutInitialization() { - StepVerifier.create(mcpAsyncClient.rootsListChangedNotification()) - .expectErrorMatches(error -> error instanceof McpError && error.getMessage() - .equals("Client must be initialized before sending roots list changed notification")) - .verify(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.withVirtualTime(() -> mcpAsyncClient.rootsListChangedNotification()) + .expectSubscription() + .thenAwait(getInitializationTimeout()) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized before sending roots list changed notification")) + .verify(); + }); } @Test void testRootsListChanged() { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.rootsListChangedNotification())) - .verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.rootsListChangedNotification())) + .verifyComplete(); + }); } @Test void testInitializeWithRootsListProviders() { - var transport = createMcpTransport(); - - var client = McpClient.async(transport) - .requestTimeout(getRequestTimeout()) - .roots(new Root("file:///test/path", "test-root")) - .build(); - - StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); + withClient(createMcpTransport(), builder -> builder.roots(new Root("file:///test/path", "test-root")), + client -> { + StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); + }); } @Test void testAddRoot() { - Root newRoot = new Root("file:///new/test/path", "new-test-root"); - - StepVerifier.create(mcpAsyncClient.addRoot(newRoot)).verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + Root newRoot = new Root("file:///new/test/path", "new-test-root"); + StepVerifier.create(mcpAsyncClient.addRoot(newRoot)).verifyComplete(); + }); } @Test void testAddRootWithNullValue() { - StepVerifier.create(mcpAsyncClient.addRoot(null)) - .expectErrorMatches(error -> error.getMessage().contains("Root must not be null")) - .verify(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.addRoot(null)) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class).hasMessage("Root must not be null")) + .verify(); + }); } @Test void testRemoveRoot() { - Root root = new Root("file:///test/path/to/remove", "root-to-remove"); + withClient(createMcpTransport(), mcpAsyncClient -> { + Root root = new Root("file:///test/path/to/remove", "root-to-remove"); + StepVerifier.create(mcpAsyncClient.addRoot(root)).verifyComplete(); - StepVerifier.create(mcpAsyncClient.addRoot(root).then(mcpAsyncClient.removeRoot(root.uri()))).verifyComplete(); + StepVerifier.create(mcpAsyncClient.removeRoot(root.uri())).verifyComplete(); + }); } @Test void testRemoveNonExistentRoot() { - StepVerifier.create(mcpAsyncClient.removeRoot("nonexistent-uri")) - .expectErrorMatches(error -> error.getMessage().contains("Root with uri 'nonexistent-uri' not found")) - .verify(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.removeRoot("nonexistent-uri")) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Root with uri 'nonexistent-uri' not found")) + .verify(); + }); } @Test @Disabled void testReadResource() { - StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> { - if (!resources.resources().isEmpty()) { - Resource firstResource = resources.resources().get(0); - StepVerifier.create(mcpAsyncClient.readResource(firstResource)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.contents()).isNotNull(); - }).verifyComplete(); - } - }).verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> { + if (!resources.resources().isEmpty()) { + Resource firstResource = resources.resources().get(0); + StepVerifier.create(mcpAsyncClient.readResource(firstResource)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.contents()).isNotNull(); + }).verifyComplete(); + } + }).verifyComplete(); + }); } @Test void testListResourceTemplatesWithoutInitialization() { - StepVerifier.create(mcpAsyncClient.listResourceTemplates()) - .expectErrorMatches(error -> error instanceof McpError - && error.getMessage().equals("Client must be initialized before listing resource templates")) - .verify(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.withVirtualTime(() -> mcpAsyncClient.listResourceTemplates()) + .expectSubscription() + .thenAwait(getInitializationTimeout()) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized before listing resource templates")) + .verify(); + }); } @Test void testListResourceTemplates() { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResourceTemplates())) - .consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.resourceTemplates()).isNotNull(); - }) - .verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResourceTemplates())) + .consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.resourceTemplates()).isNotNull(); + }) + .verifyComplete(); + }); } // @Test void testResourceSubscription() { - StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> { - if (!resources.resources().isEmpty()) { - Resource firstResource = resources.resources().get(0); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> { + if (!resources.resources().isEmpty()) { + Resource firstResource = resources.resources().get(0); - // Test subscribe - StepVerifier.create(mcpAsyncClient.subscribeResource(new SubscribeRequest(firstResource.uri()))) - .verifyComplete(); + // Test subscribe + StepVerifier.create(mcpAsyncClient.subscribeResource(new SubscribeRequest(firstResource.uri()))) + .verifyComplete(); - // Test unsubscribe - StepVerifier.create(mcpAsyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri()))) - .verifyComplete(); - } - }).verifyComplete(); + // Test unsubscribe + StepVerifier.create(mcpAsyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri()))) + .verifyComplete(); + } + }).verifyComplete(); + }); } @Test @@ -353,36 +439,44 @@ void testNotificationHandlers() { AtomicBoolean resourcesNotificationReceived = new AtomicBoolean(false); AtomicBoolean promptsNotificationReceived = new AtomicBoolean(false); - var transport = createMcpTransport(); - var client = McpClient.async(transport) - .requestTimeout(getRequestTimeout()) - .toolsChangeConsumer(tools -> Mono.fromRunnable(() -> toolsNotificationReceived.set(true))) - .resourcesChangeConsumer(resources -> Mono.fromRunnable(() -> resourcesNotificationReceived.set(true))) - .promptsChangeConsumer(prompts -> Mono.fromRunnable(() -> promptsNotificationReceived.set(true))) - .build(); - - StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); + withClient(createMcpTransport(), + builder -> builder + .toolsChangeConsumer(tools -> Mono.fromRunnable(() -> toolsNotificationReceived.set(true))) + .resourcesChangeConsumer( + resources -> Mono.fromRunnable(() -> resourcesNotificationReceived.set(true))) + .promptsChangeConsumer(prompts -> Mono.fromRunnable(() -> promptsNotificationReceived.set(true))), + mcpAsyncClient -> { + + var transport = createMcpTransport(); + var client = McpClient.async(transport) + .requestTimeout(getRequestTimeout()) + .toolsChangeConsumer(tools -> Mono.fromRunnable(() -> toolsNotificationReceived.set(true))) + .resourcesChangeConsumer( + resources -> Mono.fromRunnable(() -> resourcesNotificationReceived.set(true))) + .promptsChangeConsumer( + prompts -> Mono.fromRunnable(() -> promptsNotificationReceived.set(true))) + .build(); + + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + }); } @Test void testInitializeWithSamplingCapability() { - var transport = createMcpTransport(); - - var capabilities = ClientCapabilities.builder().sampling().build(); - - var client = McpClient.async(transport) - .requestTimeout(getRequestTimeout()) - .capabilities(capabilities) - .sampling(request -> Mono.just(CreateMessageResult.builder().message("test").model("test-model").build())) + ClientCapabilities capabilities = ClientCapabilities.builder().sampling().build(); + CreateMessageResult createMessageResult = CreateMessageResult.builder() + .message("test") + .model("test-model") .build(); - - StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); + withClient(createMcpTransport(), + builder -> builder.capabilities(capabilities).sampling(request -> Mono.just(createMessageResult)), + client -> { + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + }); } @Test void testInitializeWithAllCapabilities() { - var transport = createMcpTransport(); - var capabilities = ClientCapabilities.builder() .experimental(Map.of("feature", "test")) .roots(true) @@ -391,18 +485,14 @@ void testInitializeWithAllCapabilities() { Function> samplingHandler = request -> Mono .just(CreateMessageResult.builder().message("test").model("test-model").build()); - var client = McpClient.async(transport) - .requestTimeout(getRequestTimeout()) - .capabilities(capabilities) - .sampling(samplingHandler) - .build(); - StepVerifier.create(client.initialize()).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.capabilities()).isNotNull(); - }).verifyComplete(); + withClient(createMcpTransport(), builder -> builder.capabilities(capabilities).sampling(samplingHandler), + client -> - StepVerifier.create(client.closeGracefully()).verifyComplete(); + StepVerifier.create(client.initialize()).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.capabilities()).isNotNull(); + }).verifyComplete()); } // --------------------------------------- @@ -411,43 +501,52 @@ void testInitializeWithAllCapabilities() { @Test void testLoggingLevelsWithoutInitialization() { - StepVerifier.create(mcpAsyncClient.setLoggingLevel(McpSchema.LoggingLevel.DEBUG)) - .expectErrorMatches(error -> error instanceof McpError - && error.getMessage().equals("Client must be initialized before setting logging level")) - .verify(); + withClient(createMcpTransport(), + mcpAsyncClient -> StepVerifier + .withVirtualTime(() -> mcpAsyncClient.setLoggingLevel(McpSchema.LoggingLevel.DEBUG)) + .expectSubscription() + .thenAwait(getInitializationTimeout()) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized before setting logging level")) + .verify()); } @Test void testLoggingLevels() { - Mono testAllLevels = mcpAsyncClient.initialize().then(Mono.defer(() -> { - Mono chain = Mono.empty(); - for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { - chain = chain.then(mcpAsyncClient.setLoggingLevel(level)); - } - return chain; - })); + withClient(createMcpTransport(), mcpAsyncClient -> { + Mono testAllLevels = mcpAsyncClient.initialize().then(Mono.defer(() -> { + Mono chain = Mono.empty(); + for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { + chain = chain.then(mcpAsyncClient.setLoggingLevel(level)); + } + return chain; + })); - StepVerifier.create(testAllLevels).verifyComplete(); + StepVerifier.create(testAllLevels).verifyComplete(); + }); } @Test void testLoggingConsumer() { AtomicBoolean logReceived = new AtomicBoolean(false); - var transport = createMcpTransport(); - var client = McpClient.async(transport) - .requestTimeout(getRequestTimeout()) - .loggingConsumer(notification -> Mono.fromRunnable(() -> logReceived.set(true))) - .build(); + withClient(createMcpTransport(), + builder -> builder.loggingConsumer(notification -> Mono.fromRunnable(() -> logReceived.set(true))), + client -> { + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + StepVerifier.create(client.closeGracefully()).verifyComplete(); + + }); - StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); } @Test void testLoggingWithNullNotification() { - StepVerifier.create(mcpAsyncClient.setLoggingLevel(null)) - .expectErrorMatches(error -> error.getMessage().contains("Logging level must not be null")) - .verify(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.setLoggingLevel(null)) + .expectErrorMatches(error -> error.getMessage().contains("Logging level must not be null")) + .verify(); + }); } } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java index 0f83e31e..032f8684 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -7,6 +7,9 @@ import java.time.Duration; import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Function; import io.modelcontextprotocol.spec.ClientMcpTransport; import io.modelcontextprotocol.spec.McpError; @@ -27,6 +30,10 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; +import reactor.test.StepVerifier; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; @@ -40,12 +47,8 @@ */ public abstract class AbstractMcpSyncClientTests { - private McpSyncClient mcpSyncClient; - private static final String TEST_MESSAGE = "Hello MCP Spring AI!"; - protected ClientMcpTransport mcpTransport; - abstract protected ClientMcpTransport createMcpTransport(); protected void onStart() { @@ -62,254 +65,322 @@ protected Duration getInitializationTimeout() { return Duration.ofSeconds(2); } - @BeforeEach - void setUp() { - onStart(); - this.mcpTransport = createMcpTransport(); + McpSyncClient client(ClientMcpTransport transport) { + return client(transport, Function.identity()); + } + + McpSyncClient client(ClientMcpTransport transport, Function customizer) { + AtomicReference client = new AtomicReference<>(); assertThatCode(() -> { - mcpSyncClient = McpClient.sync(mcpTransport) + McpClient.SyncSpec builder = McpClient.sync(transport) .requestTimeout(getRequestTimeout()) .initializationTimeout(getInitializationTimeout()) - .capabilities(ClientCapabilities.builder().roots(true).build()) - .build(); + .capabilities(ClientCapabilities.builder().roots(true).build()); + builder = customizer.apply(builder); + client.set(builder.build()); }).doesNotThrowAnyException(); + + return client.get(); + } + + void withClient(ClientMcpTransport transport, Consumer c) { + withClient(transport, Function.identity(), c); + } + + void withClient(ClientMcpTransport transport, Function customizer, + Consumer c) { + var client = client(transport, customizer); + try { + c.accept(client); + } + finally { + assertThat(client.closeGracefully()).isTrue(); + } + } + + @BeforeEach + void setUp() { + onStart(); + } @AfterEach void tearDown() { - if (mcpSyncClient != null) { - assertThatCode(() -> mcpSyncClient.close()).doesNotThrowAnyException(); - } onClose(); } + static final Object DUMMY_RETURN_VALUE = new Object(); + + void verifyNotificationTimesOut(Consumer operation, String action) { + verifyCallTimesOut(client -> { + operation.accept(client); + return DUMMY_RETURN_VALUE; + }, action); + } + + void verifyCallTimesOut(Function operation, String action) { + withClient(createMcpTransport(), mcpSyncClient -> { + // This scheduler is not replaced by virtual time scheduler + Scheduler customScheduler = Schedulers.newBoundedElastic(1, 1, "actualBoundedElastic"); + + StepVerifier.withVirtualTime(() -> Mono.fromSupplier(() -> operation.apply(mcpSyncClient)) + // offload the blocking call to the real scheduler + .subscribeOn(customScheduler)) + .expectSubscription() + .thenAwait(getInitializationTimeout()) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized before " + action)) + .verify(); + + customScheduler.dispose(); + }); + } + @Test void testConstructorWithInvalidArguments() { assertThatThrownBy(() -> McpClient.sync(null).build()).isInstanceOf(IllegalArgumentException.class) .hasMessage("Transport must not be null"); - assertThatThrownBy(() -> McpClient.sync(mcpTransport).requestTimeout(null).build()) + assertThatThrownBy(() -> McpClient.sync(createMcpTransport()).requestTimeout(null).build()) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Request timeout must not be null"); } @Test void testListToolsWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.listTools(null)).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing tools"); + verifyCallTimesOut(client -> client.listTools(null), "listing tools"); } @Test void testListTools() { - mcpSyncClient.initialize(); - ListToolsResult tools = mcpSyncClient.listTools(null); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + ListToolsResult tools = mcpSyncClient.listTools(null); - assertThat(tools).isNotNull().satisfies(result -> { - assertThat(result.tools()).isNotNull().isNotEmpty(); + assertThat(tools).isNotNull().satisfies(result -> { + assertThat(result.tools()).isNotNull().isNotEmpty(); - Tool firstTool = result.tools().get(0); - assertThat(firstTool.name()).isNotNull(); - assertThat(firstTool.description()).isNotNull(); + Tool firstTool = result.tools().get(0); + assertThat(firstTool.name()).isNotNull(); + assertThat(firstTool.description()).isNotNull(); + }); }); } @Test void testCallToolsWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4)))) - .isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before calling tools"); + verifyCallTimesOut(client -> client.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))), + "calling tools"); } @Test void testCallTools() { - mcpSyncClient.initialize(); - CallToolResult toolResult = mcpSyncClient.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + CallToolResult toolResult = mcpSyncClient.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))); - assertThat(toolResult).isNotNull().satisfies(result -> { + assertThat(toolResult).isNotNull().satisfies(result -> { - assertThat(result.content()).hasSize(1); + assertThat(result.content()).hasSize(1); - TextContent content = (TextContent) result.content().get(0); + TextContent content = (TextContent) result.content().get(0); - assertThat(content).isNotNull(); - assertThat(content.text()).isNotNull(); - assertThat(content.text()).contains("7"); + assertThat(content).isNotNull(); + assertThat(content.text()).isNotNull(); + assertThat(content.text()).contains("7"); + }); }); } @Test void testPingWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.ping()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before pinging the server"); + verifyCallTimesOut(client -> client.ping(), "pinging the server"); } @Test void testPing() { - mcpSyncClient.initialize(); - assertThatCode(() -> mcpSyncClient.ping()).doesNotThrowAnyException(); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + assertThatCode(() -> mcpSyncClient.ping()).doesNotThrowAnyException(); + }); } @Test void testCallToolWithoutInitialization() { CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE)); - - assertThatThrownBy(() -> mcpSyncClient.callTool(callToolRequest)).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before calling tools"); + verifyCallTimesOut(client -> client.callTool(callToolRequest), "calling tools"); } @Test void testCallTool() { - mcpSyncClient.initialize(); - CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE)); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE)); - CallToolResult callToolResult = mcpSyncClient.callTool(callToolRequest); + CallToolResult callToolResult = mcpSyncClient.callTool(callToolRequest); - assertThat(callToolResult).isNotNull().satisfies(result -> { - assertThat(result.content()).isNotNull(); - assertThat(result.isError()).isNull(); + assertThat(callToolResult).isNotNull().satisfies(result -> { + assertThat(result.content()).isNotNull(); + assertThat(result.isError()).isNull(); + }); }); } @Test void testCallToolWithInvalidTool() { - CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", TEST_MESSAGE)); + withClient(createMcpTransport(), mcpSyncClient -> { + CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", TEST_MESSAGE)); - assertThatThrownBy(() -> mcpSyncClient.callTool(invalidRequest)).isInstanceOf(Exception.class); + assertThatThrownBy(() -> mcpSyncClient.callTool(invalidRequest)).isInstanceOf(Exception.class); + }); } @Test void testRootsListChangedWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.rootsListChangedNotification()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before sending roots list changed notification"); + verifyNotificationTimesOut(client -> client.rootsListChangedNotification(), + "sending roots list changed notification"); } @Test void testRootsListChanged() { - mcpSyncClient.initialize(); - assertThatCode(() -> mcpSyncClient.rootsListChangedNotification()).doesNotThrowAnyException(); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + assertThatCode(() -> mcpSyncClient.rootsListChangedNotification()).doesNotThrowAnyException(); + }); } @Test void testListResourcesWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.listResources(null)).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing resources"); + verifyCallTimesOut(client -> client.listResources(null), "listing resources"); } @Test void testListResources() { - mcpSyncClient.initialize(); - ListResourcesResult resources = mcpSyncClient.listResources(null); - - assertThat(resources).isNotNull().satisfies(result -> { - assertThat(result.resources()).isNotNull(); - - if (!result.resources().isEmpty()) { - Resource firstResource = result.resources().get(0); - assertThat(firstResource.uri()).isNotNull(); - assertThat(firstResource.name()).isNotNull(); - } + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + ListResourcesResult resources = mcpSyncClient.listResources(null); + + assertThat(resources).isNotNull().satisfies(result -> { + assertThat(result.resources()).isNotNull(); + + if (!result.resources().isEmpty()) { + Resource firstResource = result.resources().get(0); + assertThat(firstResource.uri()).isNotNull(); + assertThat(firstResource.name()).isNotNull(); + } + }); }); } @Test void testClientSessionState() { - assertThat(mcpSyncClient).isNotNull(); + withClient(createMcpTransport(), mcpSyncClient -> { + assertThat(mcpSyncClient).isNotNull(); + }); } @Test void testInitializeWithRootsListProviders() { - var transport = createMcpTransport(); - - var client = McpClient.sync(transport) - .requestTimeout(getRequestTimeout()) - .roots(new Root("file:///test/path", "test-root")) - .build(); + withClient(createMcpTransport(), builder -> builder.roots(new Root("file:///test/path", "test-root")), + mcpSyncClient -> { - assertThatCode(() -> { - client.initialize(); - client.close(); - }).doesNotThrowAnyException(); + assertThatCode(() -> { + mcpSyncClient.initialize(); + mcpSyncClient.close(); + }).doesNotThrowAnyException(); + }); } @Test void testAddRoot() { - Root newRoot = new Root("file:///new/test/path", "new-test-root"); - assertThatCode(() -> mcpSyncClient.addRoot(newRoot)).doesNotThrowAnyException(); + withClient(createMcpTransport(), mcpSyncClient -> { + Root newRoot = new Root("file:///new/test/path", "new-test-root"); + assertThatCode(() -> mcpSyncClient.addRoot(newRoot)).doesNotThrowAnyException(); + }); } @Test void testAddRootWithNullValue() { - assertThatThrownBy(() -> mcpSyncClient.addRoot(null)).hasMessageContaining("Root must not be null"); + withClient(createMcpTransport(), mcpSyncClient -> { + assertThatThrownBy(() -> mcpSyncClient.addRoot(null)).hasMessageContaining("Root must not be null"); + }); } @Test void testRemoveRoot() { - Root root = new Root("file:///test/path/to/remove", "root-to-remove"); - assertThatCode(() -> { - mcpSyncClient.addRoot(root); - mcpSyncClient.removeRoot(root.uri()); - }).doesNotThrowAnyException(); + withClient(createMcpTransport(), mcpSyncClient -> { + Root root = new Root("file:///test/path/to/remove", "root-to-remove"); + assertThatCode(() -> { + mcpSyncClient.addRoot(root); + mcpSyncClient.removeRoot(root.uri()); + }).doesNotThrowAnyException(); + }); } @Test void testRemoveNonExistentRoot() { - assertThatThrownBy(() -> mcpSyncClient.removeRoot("nonexistent-uri")) - .hasMessageContaining("Root with uri 'nonexistent-uri' not found"); + withClient(createMcpTransport(), mcpSyncClient -> { + assertThatThrownBy(() -> mcpSyncClient.removeRoot("nonexistent-uri")) + .hasMessageContaining("Root with uri 'nonexistent-uri' not found"); + }); } @Test void testReadResourceWithoutInitialization() { - assertThatThrownBy(() -> { - Resource resource = new Resource("test://uri", "Test Resource", null, null, null); - mcpSyncClient.readResource(resource); - }).isInstanceOf(McpError.class).hasMessage("Client must be initialized before reading resources"); + Resource resource = new Resource("test://uri", "Test Resource", null, null, null); + verifyCallTimesOut(client -> client.readResource(resource), "reading resources"); } @Test void testReadResource() { - mcpSyncClient.initialize(); - ListResourcesResult resources = mcpSyncClient.listResources(null); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + ListResourcesResult resources = mcpSyncClient.listResources(null); - if (!resources.resources().isEmpty()) { - Resource firstResource = resources.resources().get(0); - ReadResourceResult result = mcpSyncClient.readResource(firstResource); + if (!resources.resources().isEmpty()) { + Resource firstResource = resources.resources().get(0); + ReadResourceResult result = mcpSyncClient.readResource(firstResource); - assertThat(result).isNotNull(); - assertThat(result.contents()).isNotNull(); - } + assertThat(result).isNotNull(); + assertThat(result.contents()).isNotNull(); + } + }); } @Test void testListResourceTemplatesWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.listResourceTemplates(null)).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing resource templates"); + verifyCallTimesOut(client -> client.listResourceTemplates(null), "listing resource templates"); } @Test void testListResourceTemplates() { - mcpSyncClient.initialize(); - ListResourceTemplatesResult result = mcpSyncClient.listResourceTemplates(null); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + ListResourceTemplatesResult result = mcpSyncClient.listResourceTemplates(null); - assertThat(result).isNotNull(); - assertThat(result.resourceTemplates()).isNotNull(); + assertThat(result).isNotNull(); + assertThat(result.resourceTemplates()).isNotNull(); + }); } // @Test void testResourceSubscription() { - ListResourcesResult resources = mcpSyncClient.listResources(null); + withClient(createMcpTransport(), mcpSyncClient -> { + ListResourcesResult resources = mcpSyncClient.listResources(null); - if (!resources.resources().isEmpty()) { - Resource firstResource = resources.resources().get(0); + if (!resources.resources().isEmpty()) { + Resource firstResource = resources.resources().get(0); - // Test subscribe - assertThatCode(() -> mcpSyncClient.subscribeResource(new SubscribeRequest(firstResource.uri()))) - .doesNotThrowAnyException(); + // Test subscribe + assertThatCode(() -> mcpSyncClient.subscribeResource(new SubscribeRequest(firstResource.uri()))) + .doesNotThrowAnyException(); - // Test unsubscribe - assertThatCode(() -> mcpSyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri()))) - .doesNotThrowAnyException(); - } + // Test unsubscribe + assertThatCode(() -> mcpSyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri()))) + .doesNotThrowAnyException(); + } + }); } @Test @@ -318,18 +389,17 @@ void testNotificationHandlers() { AtomicBoolean resourcesNotificationReceived = new AtomicBoolean(false); AtomicBoolean promptsNotificationReceived = new AtomicBoolean(false); - var transport = createMcpTransport(); - var client = McpClient.sync(transport) - .requestTimeout(getRequestTimeout()) - .toolsChangeConsumer(tools -> toolsNotificationReceived.set(true)) - .resourcesChangeConsumer(resources -> resourcesNotificationReceived.set(true)) - .promptsChangeConsumer(prompts -> promptsNotificationReceived.set(true)) - .build(); + withClient(createMcpTransport(), + builder -> builder.toolsChangeConsumer(tools -> toolsNotificationReceived.set(true)) + .resourcesChangeConsumer(resources -> resourcesNotificationReceived.set(true)) + .promptsChangeConsumer(prompts -> promptsNotificationReceived.set(true)), + client -> { - assertThatCode(() -> { - client.initialize(); - client.close(); - }).doesNotThrowAnyException(); + assertThatCode(() -> { + client.initialize(); + client.close(); + }).doesNotThrowAnyException(); + }); } // --------------------------------------- @@ -338,40 +408,37 @@ void testNotificationHandlers() { @Test void testLoggingLevelsWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.setLoggingLevel(McpSchema.LoggingLevel.DEBUG)) - .isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before setting logging level"); + verifyNotificationTimesOut(client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), + "setting logging level"); } @Test void testLoggingLevels() { - mcpSyncClient.initialize(); - // Test all logging levels - for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { - assertThatCode(() -> mcpSyncClient.setLoggingLevel(level)).doesNotThrowAnyException(); - } + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + // Test all logging levels + for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { + assertThatCode(() -> mcpSyncClient.setLoggingLevel(level)).doesNotThrowAnyException(); + } + }); } @Test void testLoggingConsumer() { AtomicBoolean logReceived = new AtomicBoolean(false); - var transport = createMcpTransport(); - - var client = McpClient.sync(transport) - .requestTimeout(getRequestTimeout()) - .loggingConsumer(notification -> logReceived.set(true)) - .build(); - - assertThatCode(() -> { - client.initialize(); - client.close(); - }).doesNotThrowAnyException(); + withClient(createMcpTransport(), builder -> builder.requestTimeout(getRequestTimeout()) + .loggingConsumer(notification -> logReceived.set(true)), client -> { + assertThatCode(() -> { + client.initialize(); + client.close(); + }).doesNotThrowAnyException(); + }); } @Test void testLoggingWithNullNotification() { - assertThatThrownBy(() -> mcpSyncClient.setLoggingLevel(null)) - .hasMessageContaining("Logging level must not be null"); + withClient(createMcpTransport(), mcpSyncClient -> assertThatThrownBy(() -> mcpSyncClient.setLoggingLevel(null)) + .hasMessageContaining("Logging level must not be null")); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java index 614c6512..d35db3f8 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java @@ -353,14 +353,15 @@ public Mono closeGracefully() { // Give a short time for any pending messages to be processed return Mono.delay(Duration.ofMillis(100)); - })).then(Mono.fromFuture(() -> { + })).then(Mono.defer(() -> { logger.debug("Sending TERM to process"); if (this.process != null) { this.process.destroy(); - return process.onExit(); + return Mono.fromFuture(process.onExit()); } else { - return CompletableFuture.failedFuture(new RuntimeException("Process not started")); + logger.warn("Process not started"); + return Mono.empty(); } })).doOnNext(process -> { if (process.exitValue() != 0) { diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpSession.java index e2d354f4..46aefafc 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpSession.java @@ -270,8 +270,10 @@ public Mono sendNotification(String method, Map params) { */ @Override public Mono closeGracefully() { - this.connection.dispose(); - return transport.closeGracefully(); + return Mono.defer(() -> { + this.connection.dispose(); + return transport.closeGracefully(); + }); } /** diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index 39bc4995..72038854 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -6,8 +6,12 @@ import java.time.Duration; import java.util.Map; +import java.util.Objects; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; import java.util.function.Function; +import java.util.function.Supplier; import io.modelcontextprotocol.spec.ClientMcpTransport; import io.modelcontextprotocol.spec.McpError; @@ -45,10 +49,6 @@ // KEEP IN SYNC with the class in mcp-test module public abstract class AbstractMcpAsyncClientTests { - private McpAsyncClient mcpAsyncClient; - - protected ClientMcpTransport mcpTransport; - private static final String ECHO_TEST_MESSAGE = "Hello MCP Spring AI!"; abstract protected ClientMcpTransport createMcpTransport(); @@ -67,285 +67,326 @@ protected Duration getInitializationTimeout() { return Duration.ofSeconds(2); } - @BeforeEach - void setUp() { - onStart(); - this.mcpTransport = createMcpTransport(); + McpAsyncClient client(ClientMcpTransport transport) { + return client(transport, Function.identity()); + } + + McpAsyncClient client(ClientMcpTransport transport, Function customizer) { + AtomicReference client = new AtomicReference<>(); assertThatCode(() -> { - mcpAsyncClient = McpClient.async(mcpTransport) + McpClient.AsyncSpec builder = McpClient.async(transport) .requestTimeout(getRequestTimeout()) .initializationTimeout(getInitializationTimeout()) - .capabilities(ClientCapabilities.builder().roots(true).build()) - .build(); + .capabilities(ClientCapabilities.builder().roots(true).build()); + builder = customizer.apply(builder); + client.set(builder.build()); }).doesNotThrowAnyException(); + + return client.get(); + } + + void withClient(ClientMcpTransport transport, Consumer c) { + withClient(transport, Function.identity(), c); + } + + void withClient(ClientMcpTransport transport, Function customizer, + Consumer c) { + var client = client(transport, customizer); + try { + c.accept(client); + } + finally { + StepVerifier.create(client.closeGracefully()).expectComplete().verify(Duration.ofSeconds(10)); + } + } + + @BeforeEach + void setUp() { + onStart(); } @AfterEach void tearDown() { - if (mcpAsyncClient != null) { - StepVerifier.create(mcpAsyncClient.closeGracefully()).verifyComplete(); - } onClose(); } + void verifyInitializationTimeout(Function> operation, String action) { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.withVirtualTime(() -> operation.apply(mcpAsyncClient)) + .expectSubscription() + .thenAwait(getInitializationTimeout()) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized before " + action)) + .verify(); + }); + } + @Test void testConstructorWithInvalidArguments() { assertThatThrownBy(() -> McpClient.async(null).build()).isInstanceOf(IllegalArgumentException.class) .hasMessage("Transport must not be null"); - assertThatThrownBy(() -> McpClient.async(mcpTransport).requestTimeout(null).build()) + assertThatThrownBy(() -> McpClient.async(createMcpTransport()).requestTimeout(null).build()) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Request timeout must not be null"); } @Test void testListToolsWithoutInitialization() { - StepVerifier.create(mcpAsyncClient.listTools(null)).expectErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing tools"); - }).verify(); + verifyInitializationTimeout(client -> client.listTools(null), "listing tools"); } @Test void testListTools() { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listTools(null))) - .consumeNextWith(result -> { - assertThat(result.tools()).isNotNull().isNotEmpty(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listTools(null))) + .consumeNextWith(result -> { + assertThat(result.tools()).isNotNull().isNotEmpty(); - Tool firstTool = result.tools().get(0); - assertThat(firstTool.name()).isNotNull(); - assertThat(firstTool.description()).isNotNull(); - }) - .verifyComplete(); + Tool firstTool = result.tools().get(0); + assertThat(firstTool.name()).isNotNull(); + assertThat(firstTool.description()).isNotNull(); + }) + .verifyComplete(); + }); } @Test void testPingWithoutInitialization() { - StepVerifier.create(mcpAsyncClient.ping()) - .expectErrorMatches(error -> error instanceof McpError - && error.getMessage().equals("Client must be initialized before pinging the server")) - .verify(); + verifyInitializationTimeout(client -> client.ping(), "pinging the server"); } @Test void testPing() { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.ping())).consumeNextWith(callToolResult -> { - }).verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.ping())) + .expectNextCount(1) + .verifyComplete(); + }); } @Test void testCallToolWithoutInitialization() { CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); - - StepVerifier.create(mcpAsyncClient.callTool(callToolRequest)) - .expectErrorMatches(error -> error instanceof McpError - && error.getMessage().equals("Client must be initialized before calling tools")) - .verify(); + verifyInitializationTimeout(client -> client.callTool(callToolRequest), "calling tools"); } @Test void testCallTool() { - CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); + withClient(createMcpTransport(), mcpAsyncClient -> { + CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE)); - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(callToolRequest))) - .consumeNextWith(callToolResult -> { - assertThat(callToolResult).isNotNull(); - assertThat(callToolResult.content()).isNotNull(); - assertThat(callToolResult.isError()).isNull(); - }) - .verifyComplete(); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(callToolRequest))) + .consumeNextWith(callToolResult -> { + assertThat(callToolResult).isNotNull().satisfies(result -> { + assertThat(result.content()).isNotNull(); + assertThat(result.isError()).isNull(); + }); + }) + .verifyComplete(); + }); } @Test void testCallToolWithInvalidTool() { - CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", ECHO_TEST_MESSAGE)); + withClient(createMcpTransport(), mcpAsyncClient -> { + CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", + Map.of("message", ECHO_TEST_MESSAGE)); - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(invalidRequest))) - .expectError(Exception.class) - .verify(); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(invalidRequest))) + .consumeErrorWith( + e -> assertThat(e).isInstanceOf(McpError.class).hasMessage("Unknown tool: nonexistent_tool")) + .verify(); + }); } @Test void testListResourcesWithoutInitialization() { - StepVerifier.create(mcpAsyncClient.listResources(null)) - .expectErrorMatches(error -> error instanceof McpError - && error.getMessage().equals("Client must be initialized before listing resources")) - .verify(); + verifyInitializationTimeout(client -> client.listResources(null), "listing resources"); } @Test void testListResources() { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResources(null))) - .consumeNextWith(resources -> { - assertThat(resources).isNotNull().satisfies(result -> { - assertThat(result.resources()).isNotNull(); - - if (!result.resources().isEmpty()) { - Resource firstResource = result.resources().get(0); - assertThat(firstResource.uri()).isNotNull(); - assertThat(firstResource.name()).isNotNull(); - } - }); - }) - .verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResources(null))) + .consumeNextWith(resources -> { + assertThat(resources).isNotNull().satisfies(result -> { + assertThat(result.resources()).isNotNull(); + + if (!result.resources().isEmpty()) { + Resource firstResource = result.resources().get(0); + assertThat(firstResource.uri()).isNotNull(); + assertThat(firstResource.name()).isNotNull(); + } + }); + }) + .verifyComplete(); + }); } @Test void testMcpAsyncClientState() { - assertThat(mcpAsyncClient).isNotNull(); + withClient(createMcpTransport(), mcpAsyncClient -> { + assertThat(mcpAsyncClient).isNotNull(); + }); } @Test void testListPromptsWithoutInitialization() { - StepVerifier.create(mcpAsyncClient.listPrompts(null)) - .expectErrorMatches(error -> error instanceof McpError - && error.getMessage().equals("Client must be initialized before listing prompts")) - .verify(); + verifyInitializationTimeout(client -> client.listPrompts(null), "listing " + "prompts"); } @Test void testListPrompts() { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listPrompts(null))) - .consumeNextWith(prompts -> { - assertThat(prompts).isNotNull().satisfies(result -> { - assertThat(result.prompts()).isNotNull(); - - if (!result.prompts().isEmpty()) { - Prompt firstPrompt = result.prompts().get(0); - assertThat(firstPrompt.name()).isNotNull(); - assertThat(firstPrompt.description()).isNotNull(); - } - }); - }) - .verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listPrompts(null))) + .consumeNextWith(prompts -> { + assertThat(prompts).isNotNull().satisfies(result -> { + assertThat(result.prompts()).isNotNull(); + + if (!result.prompts().isEmpty()) { + Prompt firstPrompt = result.prompts().get(0); + assertThat(firstPrompt.name()).isNotNull(); + assertThat(firstPrompt.description()).isNotNull(); + } + }); + }) + .verifyComplete(); + }); } @Test void testGetPromptWithoutInitialization() { GetPromptRequest request = new GetPromptRequest("simple_prompt", Map.of()); - - StepVerifier.create(mcpAsyncClient.getPrompt(request)) - .expectErrorMatches(error -> error instanceof McpError - && error.getMessage().equals("Client must be initialized before getting prompts")) - .verify(); + verifyInitializationTimeout(client -> client.getPrompt(request), "getting " + "prompts"); } @Test void testGetPrompt() { - GetPromptRequest request = new GetPromptRequest("simple_prompt", Map.of()); - - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.getPrompt(request))) - .consumeNextWith(prompt -> { - assertThat(prompt).isNotNull().satisfies(result -> { - assertThat(result.messages()).isNotEmpty(); - assertThat(result.messages()).hasSize(1); - }); - }) - .verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier + .create(mcpAsyncClient.initialize() + .then(mcpAsyncClient.getPrompt(new GetPromptRequest("simple_prompt", Map.of())))) + .consumeNextWith(prompt -> { + assertThat(prompt).isNotNull().satisfies(result -> { + assertThat(result.messages()).isNotEmpty(); + assertThat(result.messages()).hasSize(1); + }); + }) + .verifyComplete(); + }); } @Test void testRootsListChangedWithoutInitialization() { - StepVerifier.create(mcpAsyncClient.rootsListChangedNotification()) - .expectErrorMatches(error -> error instanceof McpError && error.getMessage() - .equals("Client must be initialized before sending roots list changed notification")) - .verify(); + verifyInitializationTimeout(client -> client.rootsListChangedNotification(), + "sending roots list changed notification"); } @Test void testRootsListChanged() { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.rootsListChangedNotification())) - .verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.rootsListChangedNotification())) + .verifyComplete(); + }); } @Test void testInitializeWithRootsListProviders() { - var transport = createMcpTransport(); - - var client = McpClient.async(transport) - .requestTimeout(getRequestTimeout()) - .roots(new Root("file:///test/path", "test-root")) - .build(); - - StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); + withClient(createMcpTransport(), builder -> builder.roots(new Root("file:///test/path", "test-root")), + client -> { + StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); + }); } @Test void testAddRoot() { - Root newRoot = new Root("file:///new/test/path", "new-test-root"); - - StepVerifier.create(mcpAsyncClient.addRoot(newRoot)).verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + Root newRoot = new Root("file:///new/test/path", "new-test-root"); + StepVerifier.create(mcpAsyncClient.addRoot(newRoot)).verifyComplete(); + }); } @Test void testAddRootWithNullValue() { - StepVerifier.create(mcpAsyncClient.addRoot(null)) - .expectErrorMatches(error -> error.getMessage().contains("Root must not be null")) - .verify(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.addRoot(null)) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class).hasMessage("Root must not be null")) + .verify(); + }); } @Test void testRemoveRoot() { - Root root = new Root("file:///test/path/to/remove", "root-to-remove"); + withClient(createMcpTransport(), mcpAsyncClient -> { + Root root = new Root("file:///test/path/to/remove", "root-to-remove"); + StepVerifier.create(mcpAsyncClient.addRoot(root)).verifyComplete(); - StepVerifier.create(mcpAsyncClient.addRoot(root).then(mcpAsyncClient.removeRoot(root.uri()))).verifyComplete(); + StepVerifier.create(mcpAsyncClient.removeRoot(root.uri())).verifyComplete(); + }); } @Test void testRemoveNonExistentRoot() { - StepVerifier.create(mcpAsyncClient.removeRoot("nonexistent-uri")) - .expectErrorMatches(error -> error.getMessage().contains("Root with uri 'nonexistent-uri' not found")) - .verify(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.removeRoot("nonexistent-uri")) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Root with uri 'nonexistent-uri' not found")) + .verify(); + }); } @Test @Disabled void testReadResource() { - StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> { - if (!resources.resources().isEmpty()) { - Resource firstResource = resources.resources().get(0); - StepVerifier.create(mcpAsyncClient.readResource(firstResource)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.contents()).isNotNull(); - }).verifyComplete(); - } - }).verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> { + if (!resources.resources().isEmpty()) { + Resource firstResource = resources.resources().get(0); + StepVerifier.create(mcpAsyncClient.readResource(firstResource)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.contents()).isNotNull(); + }).verifyComplete(); + } + }).verifyComplete(); + }); } @Test void testListResourceTemplatesWithoutInitialization() { - StepVerifier.create(mcpAsyncClient.listResourceTemplates()) - .expectErrorMatches(error -> error instanceof McpError - && error.getMessage().equals("Client must be initialized before listing resource templates")) - .verify(); + verifyInitializationTimeout(client -> client.listResourceTemplates(), "listing resource templates"); } @Test void testListResourceTemplates() { - StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResourceTemplates())) - .consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.resourceTemplates()).isNotNull(); - }) - .verifyComplete(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.listResourceTemplates())) + .consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.resourceTemplates()).isNotNull(); + }) + .verifyComplete(); + }); } // @Test void testResourceSubscription() { - StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> { - if (!resources.resources().isEmpty()) { - Resource firstResource = resources.resources().get(0); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> { + if (!resources.resources().isEmpty()) { + Resource firstResource = resources.resources().get(0); - // Test subscribe - StepVerifier.create(mcpAsyncClient.subscribeResource(new SubscribeRequest(firstResource.uri()))) - .verifyComplete(); + // Test subscribe + StepVerifier.create(mcpAsyncClient.subscribeResource(new SubscribeRequest(firstResource.uri()))) + .verifyComplete(); - // Test unsubscribe - StepVerifier.create(mcpAsyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri()))) - .verifyComplete(); - } - }).verifyComplete(); + // Test unsubscribe + StepVerifier.create(mcpAsyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri()))) + .verifyComplete(); + } + }).verifyComplete(); + }); } @Test @@ -354,36 +395,44 @@ void testNotificationHandlers() { AtomicBoolean resourcesNotificationReceived = new AtomicBoolean(false); AtomicBoolean promptsNotificationReceived = new AtomicBoolean(false); - var transport = createMcpTransport(); - var client = McpClient.async(transport) - .requestTimeout(getRequestTimeout()) - .toolsChangeConsumer(tools -> Mono.fromRunnable(() -> toolsNotificationReceived.set(true))) - .resourcesChangeConsumer(resources -> Mono.fromRunnable(() -> resourcesNotificationReceived.set(true))) - .promptsChangeConsumer(prompts -> Mono.fromRunnable(() -> promptsNotificationReceived.set(true))) - .build(); - - StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); + withClient(createMcpTransport(), + builder -> builder + .toolsChangeConsumer(tools -> Mono.fromRunnable(() -> toolsNotificationReceived.set(true))) + .resourcesChangeConsumer( + resources -> Mono.fromRunnable(() -> resourcesNotificationReceived.set(true))) + .promptsChangeConsumer(prompts -> Mono.fromRunnable(() -> promptsNotificationReceived.set(true))), + mcpAsyncClient -> { + + var transport = createMcpTransport(); + var client = McpClient.async(transport) + .requestTimeout(getRequestTimeout()) + .toolsChangeConsumer(tools -> Mono.fromRunnable(() -> toolsNotificationReceived.set(true))) + .resourcesChangeConsumer( + resources -> Mono.fromRunnable(() -> resourcesNotificationReceived.set(true))) + .promptsChangeConsumer( + prompts -> Mono.fromRunnable(() -> promptsNotificationReceived.set(true))) + .build(); + + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + }); } @Test void testInitializeWithSamplingCapability() { - var transport = createMcpTransport(); - - var capabilities = ClientCapabilities.builder().sampling().build(); - - var client = McpClient.async(transport) - .requestTimeout(getRequestTimeout()) - .capabilities(capabilities) - .sampling(request -> Mono.just(CreateMessageResult.builder().message("test").model("test-model").build())) + ClientCapabilities capabilities = ClientCapabilities.builder().sampling().build(); + CreateMessageResult createMessageResult = CreateMessageResult.builder() + .message("test") + .model("test-model") .build(); - - StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); + withClient(createMcpTransport(), + builder -> builder.capabilities(capabilities).sampling(request -> Mono.just(createMessageResult)), + client -> { + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + }); } @Test void testInitializeWithAllCapabilities() { - var transport = createMcpTransport(); - var capabilities = ClientCapabilities.builder() .experimental(Map.of("feature", "test")) .roots(true) @@ -392,18 +441,14 @@ void testInitializeWithAllCapabilities() { Function> samplingHandler = request -> Mono .just(CreateMessageResult.builder().message("test").model("test-model").build()); - var client = McpClient.async(transport) - .requestTimeout(getRequestTimeout()) - .capabilities(capabilities) - .sampling(samplingHandler) - .build(); - StepVerifier.create(client.initialize()).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.capabilities()).isNotNull(); - }).verifyComplete(); + withClient(createMcpTransport(), builder -> builder.capabilities(capabilities).sampling(samplingHandler), + client -> - StepVerifier.create(client.closeGracefully()).verifyComplete(); + StepVerifier.create(client.initialize()).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.capabilities()).isNotNull(); + }).verifyComplete()); } // --------------------------------------- @@ -412,43 +457,46 @@ void testInitializeWithAllCapabilities() { @Test void testLoggingLevelsWithoutInitialization() { - StepVerifier.create(mcpAsyncClient.setLoggingLevel(McpSchema.LoggingLevel.DEBUG)) - .expectErrorMatches(error -> error instanceof McpError - && error.getMessage().equals("Client must be initialized before setting logging level")) - .verify(); + verifyInitializationTimeout(client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), + "setting logging level"); } @Test void testLoggingLevels() { - Mono testAllLevels = mcpAsyncClient.initialize().then(Mono.defer(() -> { - Mono chain = Mono.empty(); - for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { - chain = chain.then(mcpAsyncClient.setLoggingLevel(level)); - } - return chain; - })); + withClient(createMcpTransport(), mcpAsyncClient -> { + Mono testAllLevels = mcpAsyncClient.initialize().then(Mono.defer(() -> { + Mono chain = Mono.empty(); + for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { + chain = chain.then(mcpAsyncClient.setLoggingLevel(level)); + } + return chain; + })); - StepVerifier.create(testAllLevels).verifyComplete(); + StepVerifier.create(testAllLevels).verifyComplete(); + }); } @Test void testLoggingConsumer() { AtomicBoolean logReceived = new AtomicBoolean(false); - var transport = createMcpTransport(); - var client = McpClient.async(transport) - .requestTimeout(getRequestTimeout()) - .loggingConsumer(notification -> Mono.fromRunnable(() -> logReceived.set(true))) - .build(); + withClient(createMcpTransport(), + builder -> builder.loggingConsumer(notification -> Mono.fromRunnable(() -> logReceived.set(true))), + client -> { + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + StepVerifier.create(client.closeGracefully()).verifyComplete(); + + }); - StepVerifier.create(client.initialize().then(client.closeGracefully())).verifyComplete(); } @Test void testLoggingWithNullNotification() { - StepVerifier.create(mcpAsyncClient.setLoggingLevel(null)) - .expectErrorMatches(error -> error.getMessage().contains("Logging level must not be null")) - .verify(); + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.setLoggingLevel(null)) + .expectErrorMatches(error -> error.getMessage().contains("Logging level must not be null")) + .verify(); + }); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java index 52a0138f..1c042bf2 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -7,6 +7,9 @@ import java.time.Duration; import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Function; import io.modelcontextprotocol.spec.ClientMcpTransport; import io.modelcontextprotocol.spec.McpError; @@ -27,6 +30,10 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; +import reactor.test.StepVerifier; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; @@ -41,12 +48,8 @@ // KEEP IN SYNC with the class in mcp-test module public abstract class AbstractMcpSyncClientTests { - private McpSyncClient mcpSyncClient; - private static final String TEST_MESSAGE = "Hello MCP Spring AI!"; - protected ClientMcpTransport mcpTransport; - abstract protected ClientMcpTransport createMcpTransport(); protected void onStart() { @@ -63,254 +66,322 @@ protected Duration getInitializationTimeout() { return Duration.ofSeconds(2); } - @BeforeEach - void setUp() { - onStart(); - this.mcpTransport = createMcpTransport(); + McpSyncClient client(ClientMcpTransport transport) { + return client(transport, Function.identity()); + } + + McpSyncClient client(ClientMcpTransport transport, Function customizer) { + AtomicReference client = new AtomicReference<>(); assertThatCode(() -> { - mcpSyncClient = McpClient.sync(mcpTransport) + McpClient.SyncSpec builder = McpClient.sync(transport) .requestTimeout(getRequestTimeout()) .initializationTimeout(getInitializationTimeout()) - .capabilities(ClientCapabilities.builder().roots(true).build()) - .build(); + .capabilities(ClientCapabilities.builder().roots(true).build()); + builder = customizer.apply(builder); + client.set(builder.build()); }).doesNotThrowAnyException(); + + return client.get(); + } + + void withClient(ClientMcpTransport transport, Consumer c) { + withClient(transport, Function.identity(), c); + } + + void withClient(ClientMcpTransport transport, Function customizer, + Consumer c) { + var client = client(transport, customizer); + try { + c.accept(client); + } + finally { + assertThat(client.closeGracefully()).isTrue(); + } + } + + @BeforeEach + void setUp() { + onStart(); + } @AfterEach void tearDown() { - if (mcpSyncClient != null) { - assertThatCode(() -> mcpSyncClient.close()).doesNotThrowAnyException(); - } onClose(); } + static final Object DUMMY_RETURN_VALUE = new Object(); + + void verifyNotificationTimesOut(Consumer operation, String action) { + verifyCallTimesOut(client -> { + operation.accept(client); + return DUMMY_RETURN_VALUE; + }, action); + } + + void verifyCallTimesOut(Function operation, String action) { + withClient(createMcpTransport(), mcpSyncClient -> { + // This scheduler is not replaced by virtual time scheduler + Scheduler customScheduler = Schedulers.newBoundedElastic(1, 1, "actualBoundedElastic"); + + StepVerifier.withVirtualTime(() -> Mono.fromSupplier(() -> operation.apply(mcpSyncClient)) + // offload the blocking call to the real scheduler + .subscribeOn(customScheduler)) + .expectSubscription() + .thenAwait(getInitializationTimeout()) + .consumeErrorWith(e -> assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized before " + action)) + .verify(); + + customScheduler.dispose(); + }); + } + @Test void testConstructorWithInvalidArguments() { assertThatThrownBy(() -> McpClient.sync(null).build()).isInstanceOf(IllegalArgumentException.class) .hasMessage("Transport must not be null"); - assertThatThrownBy(() -> McpClient.sync(mcpTransport).requestTimeout(null).build()) + assertThatThrownBy(() -> McpClient.sync(createMcpTransport()).requestTimeout(null).build()) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Request timeout must not be null"); } @Test void testListToolsWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.listTools(null)).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing tools"); + verifyCallTimesOut(client -> client.listTools(null), "listing tools"); } @Test void testListTools() { - mcpSyncClient.initialize(); - ListToolsResult tools = mcpSyncClient.listTools(null); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + ListToolsResult tools = mcpSyncClient.listTools(null); - assertThat(tools).isNotNull().satisfies(result -> { - assertThat(result.tools()).isNotNull().isNotEmpty(); + assertThat(tools).isNotNull().satisfies(result -> { + assertThat(result.tools()).isNotNull().isNotEmpty(); - Tool firstTool = result.tools().get(0); - assertThat(firstTool.name()).isNotNull(); - assertThat(firstTool.description()).isNotNull(); + Tool firstTool = result.tools().get(0); + assertThat(firstTool.name()).isNotNull(); + assertThat(firstTool.description()).isNotNull(); + }); }); } @Test void testCallToolsWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4)))) - .isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before calling tools"); + verifyCallTimesOut(client -> client.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))), + "calling tools"); } @Test void testCallTools() { - mcpSyncClient.initialize(); - CallToolResult toolResult = mcpSyncClient.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + CallToolResult toolResult = mcpSyncClient.callTool(new CallToolRequest("add", Map.of("a", 3, "b", 4))); - assertThat(toolResult).isNotNull().satisfies(result -> { + assertThat(toolResult).isNotNull().satisfies(result -> { - assertThat(result.content()).hasSize(1); + assertThat(result.content()).hasSize(1); - TextContent content = (TextContent) result.content().get(0); + TextContent content = (TextContent) result.content().get(0); - assertThat(content).isNotNull(); - assertThat(content.text()).isNotNull(); - assertThat(content.text()).contains("7"); + assertThat(content).isNotNull(); + assertThat(content.text()).isNotNull(); + assertThat(content.text()).contains("7"); + }); }); } @Test void testPingWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.ping()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before pinging the server"); + verifyCallTimesOut(client -> client.ping(), "pinging the server"); } @Test void testPing() { - mcpSyncClient.initialize(); - assertThatCode(() -> mcpSyncClient.ping()).doesNotThrowAnyException(); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + assertThatCode(() -> mcpSyncClient.ping()).doesNotThrowAnyException(); + }); } @Test void testCallToolWithoutInitialization() { CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE)); - - assertThatThrownBy(() -> mcpSyncClient.callTool(callToolRequest)).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before calling tools"); + verifyCallTimesOut(client -> client.callTool(callToolRequest), "calling tools"); } @Test void testCallTool() { - mcpSyncClient.initialize(); - CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE)); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", TEST_MESSAGE)); - CallToolResult callToolResult = mcpSyncClient.callTool(callToolRequest); + CallToolResult callToolResult = mcpSyncClient.callTool(callToolRequest); - assertThat(callToolResult).isNotNull().satisfies(result -> { - assertThat(result.content()).isNotNull(); - assertThat(result.isError()).isNull(); + assertThat(callToolResult).isNotNull().satisfies(result -> { + assertThat(result.content()).isNotNull(); + assertThat(result.isError()).isNull(); + }); }); } @Test void testCallToolWithInvalidTool() { - CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", TEST_MESSAGE)); + withClient(createMcpTransport(), mcpSyncClient -> { + CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", TEST_MESSAGE)); - assertThatThrownBy(() -> mcpSyncClient.callTool(invalidRequest)).isInstanceOf(Exception.class); + assertThatThrownBy(() -> mcpSyncClient.callTool(invalidRequest)).isInstanceOf(Exception.class); + }); } @Test void testRootsListChangedWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.rootsListChangedNotification()).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before sending roots list changed notification"); + verifyNotificationTimesOut(client -> client.rootsListChangedNotification(), + "sending roots list changed notification"); } @Test void testRootsListChanged() { - mcpSyncClient.initialize(); - assertThatCode(() -> mcpSyncClient.rootsListChangedNotification()).doesNotThrowAnyException(); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + assertThatCode(() -> mcpSyncClient.rootsListChangedNotification()).doesNotThrowAnyException(); + }); } @Test void testListResourcesWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.listResources(null)).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing resources"); + verifyCallTimesOut(client -> client.listResources(null), "listing resources"); } @Test void testListResources() { - mcpSyncClient.initialize(); - ListResourcesResult resources = mcpSyncClient.listResources(null); - - assertThat(resources).isNotNull().satisfies(result -> { - assertThat(result.resources()).isNotNull(); - - if (!result.resources().isEmpty()) { - Resource firstResource = result.resources().get(0); - assertThat(firstResource.uri()).isNotNull(); - assertThat(firstResource.name()).isNotNull(); - } + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + ListResourcesResult resources = mcpSyncClient.listResources(null); + + assertThat(resources).isNotNull().satisfies(result -> { + assertThat(result.resources()).isNotNull(); + + if (!result.resources().isEmpty()) { + Resource firstResource = result.resources().get(0); + assertThat(firstResource.uri()).isNotNull(); + assertThat(firstResource.name()).isNotNull(); + } + }); }); } @Test void testClientSessionState() { - assertThat(mcpSyncClient).isNotNull(); + withClient(createMcpTransport(), mcpSyncClient -> { + assertThat(mcpSyncClient).isNotNull(); + }); } @Test void testInitializeWithRootsListProviders() { - var transport = createMcpTransport(); - - var client = McpClient.sync(transport) - .requestTimeout(getRequestTimeout()) - .roots(new Root("file:///test/path", "test-root")) - .build(); + withClient(createMcpTransport(), builder -> builder.roots(new Root("file:///test/path", "test-root")), + mcpSyncClient -> { - assertThatCode(() -> { - client.initialize(); - client.close(); - }).doesNotThrowAnyException(); + assertThatCode(() -> { + mcpSyncClient.initialize(); + mcpSyncClient.close(); + }).doesNotThrowAnyException(); + }); } @Test void testAddRoot() { - Root newRoot = new Root("file:///new/test/path", "new-test-root"); - assertThatCode(() -> mcpSyncClient.addRoot(newRoot)).doesNotThrowAnyException(); + withClient(createMcpTransport(), mcpSyncClient -> { + Root newRoot = new Root("file:///new/test/path", "new-test-root"); + assertThatCode(() -> mcpSyncClient.addRoot(newRoot)).doesNotThrowAnyException(); + }); } @Test void testAddRootWithNullValue() { - assertThatThrownBy(() -> mcpSyncClient.addRoot(null)).hasMessageContaining("Root must not be null"); + withClient(createMcpTransport(), mcpSyncClient -> { + assertThatThrownBy(() -> mcpSyncClient.addRoot(null)).hasMessageContaining("Root must not be null"); + }); } @Test void testRemoveRoot() { - Root root = new Root("file:///test/path/to/remove", "root-to-remove"); - assertThatCode(() -> { - mcpSyncClient.addRoot(root); - mcpSyncClient.removeRoot(root.uri()); - }).doesNotThrowAnyException(); + withClient(createMcpTransport(), mcpSyncClient -> { + Root root = new Root("file:///test/path/to/remove", "root-to-remove"); + assertThatCode(() -> { + mcpSyncClient.addRoot(root); + mcpSyncClient.removeRoot(root.uri()); + }).doesNotThrowAnyException(); + }); } @Test void testRemoveNonExistentRoot() { - assertThatThrownBy(() -> mcpSyncClient.removeRoot("nonexistent-uri")) - .hasMessageContaining("Root with uri 'nonexistent-uri' not found"); + withClient(createMcpTransport(), mcpSyncClient -> { + assertThatThrownBy(() -> mcpSyncClient.removeRoot("nonexistent-uri")) + .hasMessageContaining("Root with uri 'nonexistent-uri' not found"); + }); } @Test void testReadResourceWithoutInitialization() { - assertThatThrownBy(() -> { - Resource resource = new Resource("test://uri", "Test Resource", null, null, null); - mcpSyncClient.readResource(resource); - }).isInstanceOf(McpError.class).hasMessage("Client must be initialized before reading resources"); + Resource resource = new Resource("test://uri", "Test Resource", null, null, null); + verifyCallTimesOut(client -> client.readResource(resource), "reading resources"); } @Test void testReadResource() { - mcpSyncClient.initialize(); - ListResourcesResult resources = mcpSyncClient.listResources(null); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + ListResourcesResult resources = mcpSyncClient.listResources(null); - if (!resources.resources().isEmpty()) { - Resource firstResource = resources.resources().get(0); - ReadResourceResult result = mcpSyncClient.readResource(firstResource); + if (!resources.resources().isEmpty()) { + Resource firstResource = resources.resources().get(0); + ReadResourceResult result = mcpSyncClient.readResource(firstResource); - assertThat(result).isNotNull(); - assertThat(result.contents()).isNotNull(); - } + assertThat(result).isNotNull(); + assertThat(result.contents()).isNotNull(); + } + }); } @Test void testListResourceTemplatesWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.listResourceTemplates(null)).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before listing resource templates"); + verifyCallTimesOut(client -> client.listResourceTemplates(null), "listing resource templates"); } @Test void testListResourceTemplates() { - mcpSyncClient.initialize(); - ListResourceTemplatesResult result = mcpSyncClient.listResourceTemplates(null); + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + ListResourceTemplatesResult result = mcpSyncClient.listResourceTemplates(null); - assertThat(result).isNotNull(); - assertThat(result.resourceTemplates()).isNotNull(); + assertThat(result).isNotNull(); + assertThat(result.resourceTemplates()).isNotNull(); + }); } // @Test void testResourceSubscription() { - ListResourcesResult resources = mcpSyncClient.listResources(null); + withClient(createMcpTransport(), mcpSyncClient -> { + ListResourcesResult resources = mcpSyncClient.listResources(null); - if (!resources.resources().isEmpty()) { - Resource firstResource = resources.resources().get(0); + if (!resources.resources().isEmpty()) { + Resource firstResource = resources.resources().get(0); - // Test subscribe - assertThatCode(() -> mcpSyncClient.subscribeResource(new SubscribeRequest(firstResource.uri()))) - .doesNotThrowAnyException(); + // Test subscribe + assertThatCode(() -> mcpSyncClient.subscribeResource(new SubscribeRequest(firstResource.uri()))) + .doesNotThrowAnyException(); - // Test unsubscribe - assertThatCode(() -> mcpSyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri()))) - .doesNotThrowAnyException(); - } + // Test unsubscribe + assertThatCode(() -> mcpSyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri()))) + .doesNotThrowAnyException(); + } + }); } @Test @@ -319,18 +390,17 @@ void testNotificationHandlers() { AtomicBoolean resourcesNotificationReceived = new AtomicBoolean(false); AtomicBoolean promptsNotificationReceived = new AtomicBoolean(false); - var transport = createMcpTransport(); - var client = McpClient.sync(transport) - .requestTimeout(getRequestTimeout()) - .toolsChangeConsumer(tools -> toolsNotificationReceived.set(true)) - .resourcesChangeConsumer(resources -> resourcesNotificationReceived.set(true)) - .promptsChangeConsumer(prompts -> promptsNotificationReceived.set(true)) - .build(); + withClient(createMcpTransport(), + builder -> builder.toolsChangeConsumer(tools -> toolsNotificationReceived.set(true)) + .resourcesChangeConsumer(resources -> resourcesNotificationReceived.set(true)) + .promptsChangeConsumer(prompts -> promptsNotificationReceived.set(true)), + client -> { - assertThatCode(() -> { - client.initialize(); - client.close(); - }).doesNotThrowAnyException(); + assertThatCode(() -> { + client.initialize(); + client.close(); + }).doesNotThrowAnyException(); + }); } // --------------------------------------- @@ -339,40 +409,37 @@ void testNotificationHandlers() { @Test void testLoggingLevelsWithoutInitialization() { - assertThatThrownBy(() -> mcpSyncClient.setLoggingLevel(McpSchema.LoggingLevel.DEBUG)) - .isInstanceOf(McpError.class) - .hasMessage("Client must be initialized before setting logging level"); + verifyNotificationTimesOut(client -> client.setLoggingLevel(McpSchema.LoggingLevel.DEBUG), + "setting logging level"); } @Test void testLoggingLevels() { - mcpSyncClient.initialize(); - // Test all logging levels - for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { - assertThatCode(() -> mcpSyncClient.setLoggingLevel(level)).doesNotThrowAnyException(); - } + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + // Test all logging levels + for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { + assertThatCode(() -> mcpSyncClient.setLoggingLevel(level)).doesNotThrowAnyException(); + } + }); } @Test void testLoggingConsumer() { AtomicBoolean logReceived = new AtomicBoolean(false); - var transport = createMcpTransport(); - - var client = McpClient.sync(transport) - .requestTimeout(getRequestTimeout()) - .loggingConsumer(notification -> logReceived.set(true)) - .build(); - - assertThatCode(() -> { - client.initialize(); - client.close(); - }).doesNotThrowAnyException(); + withClient(createMcpTransport(), builder -> builder.requestTimeout(getRequestTimeout()) + .loggingConsumer(notification -> logReceived.set(true)), client -> { + assertThatCode(() -> { + client.initialize(); + client.close(); + }).doesNotThrowAnyException(); + }); } @Test void testLoggingWithNullNotification() { - assertThatThrownBy(() -> mcpSyncClient.setLoggingLevel(null)) - .hasMessageContaining("Logging level must not be null"); + withClient(createMcpTransport(), mcpSyncClient -> assertThatThrownBy(() -> mcpSyncClient.setLoggingLevel(null)) + .hasMessageContaining("Logging level must not be null")); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java index 6d759b4b..ebf10b9a 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java @@ -5,6 +5,8 @@ package io.modelcontextprotocol.client; import java.time.Duration; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import io.modelcontextprotocol.client.transport.ServerParameters; @@ -13,6 +15,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; import reactor.core.publisher.Sinks; +import reactor.test.StepVerifier; import static org.assertj.core.api.Assertions.assertThat; @@ -35,15 +38,26 @@ protected ClientMcpTransport createMcpTransport() { } @Test - void customErrorHandlerShouldReceiveErrors() { + void customErrorHandlerShouldReceiveErrors() throws InterruptedException { + CountDownLatch latch = new CountDownLatch(1); AtomicReference receivedError = new AtomicReference<>(); - ((StdioClientTransport) mcpTransport).setStdErrorHandler(error -> receivedError.set(error)); + ClientMcpTransport transport = createMcpTransport(); + StepVerifier.create(transport.connect(msg -> msg)).verifyComplete(); + + ((StdioClientTransport) transport).setStdErrorHandler(error -> { + receivedError.set(error); + latch.countDown(); + }); String errorMessage = "Test error"; - ((StdioClientTransport) mcpTransport).getErrorSink().emitNext(errorMessage, Sinks.EmitFailureHandler.FAIL_FAST); + ((StdioClientTransport) transport).getErrorSink().emitNext(errorMessage, Sinks.EmitFailureHandler.FAIL_FAST); + + assertThat(latch.await(5, TimeUnit.SECONDS)).isTrue(); assertThat(receivedError.get()).isNotNull().isEqualTo(errorMessage); + + StepVerifier.create(transport.closeGracefully()).expectComplete().verify(Duration.ofSeconds(5)); } protected Duration getInitializationTimeout() {