Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(client): Improve initialization state handling in McpAsyncClient #39

Merged
merged 6 commits into from
Mar 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

package io.modelcontextprotocol.client;

import java.time.Duration;

import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport;
import io.modelcontextprotocol.spec.ClientMcpTransport;
import org.junit.jupiter.api.Timeout;
Expand Down Expand Up @@ -46,4 +48,9 @@ public void onClose() {
container.stop();
}

@Override
protected Duration getTimeoutDuration() {
return Duration.ofMillis(300);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

package io.modelcontextprotocol.client;

import java.time.Duration;

import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport;
import io.modelcontextprotocol.spec.ClientMcpTransport;
import org.junit.jupiter.api.Timeout;
Expand Down Expand Up @@ -46,4 +48,9 @@ protected void onClose() {
container.stop();
}

@Override
protected Duration getTimeoutDuration() {
return Duration.ofMillis(300);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.util.function.Function;

import io.modelcontextprotocol.spec.ClientMcpTransport;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.CallToolRequest;
import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities;
Expand Down Expand Up @@ -47,8 +48,6 @@ public abstract class AbstractMcpAsyncClientTests {

protected ClientMcpTransport mcpTransport;

private static final Duration TIMEOUT = Duration.ofSeconds(20);

private static final String ECHO_TEST_MESSAGE = "Hello MCP Spring AI!";

abstract protected ClientMcpTransport createMcpTransport();
Expand All @@ -59,17 +58,20 @@ protected void onStart() {
protected void onClose() {
}

protected Duration getTimeoutDuration() {
return Duration.ofSeconds(2);
}

@BeforeEach
void setUp() {
onStart();
this.mcpTransport = createMcpTransport();

assertThatCode(() -> {
mcpAsyncClient = McpClient.async(mcpTransport)
.requestTimeout(TIMEOUT)
.requestTimeout(getTimeoutDuration())
.capabilities(ClientCapabilities.builder().roots(true).build())
.build();
mcpAsyncClient.initialize().block(Duration.ofSeconds(10));
}).doesNotThrowAnyException();
}

Expand All @@ -92,8 +94,16 @@ void testConstructorWithInvalidArguments() {
.hasMessage("Request timeout must not be null");
}

@Test
void testListToolsWithoutInitialization() {
assertThatThrownBy(() -> mcpAsyncClient.listTools(null).block()).isInstanceOf(McpError.class)
Copy link
Member

@chemicL chemicL Mar 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test will last at least 600ms in case of WebClient and 4s for JDK HttpClient before the exception is observed, correct? If I am correct, it would be useful to consider using StepVerifier.withVirtualTime to emulate time passing by instead of adding more seconds to the time w validate the behaviour.

.hasMessage("Client must be initialized before listing tools");
}

@Test
void testListTools() {
mcpAsyncClient.initialize().block(Duration.ofSeconds(10));

StepVerifier.create(mcpAsyncClient.listTools(null)).consumeNextWith(result -> {
assertThat(result.tools()).isNotNull().isNotEmpty();

Expand All @@ -103,13 +113,30 @@ void testListTools() {
}).verifyComplete();
}

@Test
void testPingWithoutInitialization() {
assertThatThrownBy(() -> mcpAsyncClient.ping().block()).isInstanceOf(McpError.class)
.hasMessage("Client must be initialized before pinging the server");
}

@Test
void testPing() {
mcpAsyncClient.initialize().block(Duration.ofSeconds(10));
assertThatCode(() -> mcpAsyncClient.ping().block()).doesNotThrowAnyException();
}

@Test
void testCallToolWithoutInitialization() {
CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE));

assertThatThrownBy(() -> mcpAsyncClient.callTool(callToolRequest).block()).isInstanceOf(McpError.class)
.hasMessage("Client must be initialized before calling tools");
}

@Test
void testCallTool() {
mcpAsyncClient.initialize().block(Duration.ofSeconds(10));

CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE));

StepVerifier.create(mcpAsyncClient.callTool(callToolRequest)).consumeNextWith(callToolResult -> {
Expand All @@ -122,13 +149,23 @@ void testCallTool() {

@Test
void testCallToolWithInvalidTool() {
mcpAsyncClient.initialize().block(Duration.ofSeconds(10));

CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", ECHO_TEST_MESSAGE));

assertThatThrownBy(() -> mcpAsyncClient.callTool(invalidRequest).block()).isInstanceOf(Exception.class);
}

@Test
void testListResourcesWithoutInitialization() {
assertThatThrownBy(() -> mcpAsyncClient.listResources(null).block()).isInstanceOf(McpError.class)
.hasMessage("Client must be initialized before listing resources");
}

@Test
void testListResources() {
mcpAsyncClient.initialize().block(Duration.ofSeconds(10));

StepVerifier.create(mcpAsyncClient.listResources(null)).consumeNextWith(resources -> {
assertThat(resources).isNotNull().satisfies(result -> {
assertThat(result.resources()).isNotNull();
Expand All @@ -147,8 +184,16 @@ void testMcpAsyncClientState() {
assertThat(mcpAsyncClient).isNotNull();
}

@Test
void testListPromptsWithoutInitialization() {
assertThatThrownBy(() -> mcpAsyncClient.listPrompts(null).block()).isInstanceOf(McpError.class)
.hasMessage("Client must be initialized before listing prompts");
}

@Test
void testListPrompts() {
mcpAsyncClient.initialize().block(Duration.ofSeconds(10));

StepVerifier.create(mcpAsyncClient.listPrompts(null)).consumeNextWith(prompts -> {
assertThat(prompts).isNotNull().satisfies(result -> {
assertThat(result.prompts()).isNotNull();
Expand All @@ -162,8 +207,18 @@ void testListPrompts() {
}).verifyComplete();
}

@Test
void testGetPromptWithoutInitialization() {
GetPromptRequest request = new GetPromptRequest("simple_prompt", Map.of());

assertThatThrownBy(() -> mcpAsyncClient.getPrompt(request).block()).isInstanceOf(McpError.class)
.hasMessage("Client must be initialized before getting prompts");
}

@Test
void testGetPrompt() {
mcpAsyncClient.initialize().block(Duration.ofSeconds(10));

StepVerifier.create(mcpAsyncClient.getPrompt(new GetPromptRequest("simple_prompt", Map.of())))
.consumeNextWith(prompt -> {
assertThat(prompt).isNotNull().satisfies(result -> {
Expand All @@ -174,8 +229,16 @@ void testGetPrompt() {
.verifyComplete();
}

@Test
void testRootsListChangedWithoutInitialization() {
assertThatThrownBy(() -> mcpAsyncClient.rootsListChangedNotification().block()).isInstanceOf(McpError.class)
.hasMessage("Client must be initialized before sending roots list changed notification");
}

@Test
void testRootsListChanged() {
mcpAsyncClient.initialize().block(Duration.ofSeconds(10));

assertThatCode(() -> mcpAsyncClient.rootsListChangedNotification().block()).doesNotThrowAnyException();
}

Expand All @@ -184,7 +247,7 @@ void testInitializeWithRootsListProviders() {
var transport = createMcpTransport();

var client = McpClient.async(transport)
.requestTimeout(TIMEOUT)
.requestTimeout(getTimeoutDuration())
.roots(new Root("file:///test/path", "test-root"))
.build();

Expand Down Expand Up @@ -233,8 +296,16 @@ void testReadResource() {
}).verifyComplete();
}

@Test
void testListResourceTemplatesWithoutInitialization() {
assertThatThrownBy(() -> mcpAsyncClient.listResourceTemplates().block()).isInstanceOf(McpError.class)
.hasMessage("Client must be initialized before listing resource templates");
}

@Test
void testListResourceTemplates() {
mcpAsyncClient.initialize().block(Duration.ofSeconds(10));

StepVerifier.create(mcpAsyncClient.listResourceTemplates()).consumeNextWith(result -> {
assertThat(result).isNotNull();
assertThat(result.resourceTemplates()).isNotNull();
Expand Down Expand Up @@ -266,7 +337,7 @@ void testNotificationHandlers() {

var transport = createMcpTransport();
var client = McpClient.async(transport)
.requestTimeout(TIMEOUT)
.requestTimeout(getTimeoutDuration())
.toolsChangeConsumer(tools -> Mono.fromRunnable(() -> toolsNotificationReceived.set(true)))
.resourcesChangeConsumer(resources -> Mono.fromRunnable(() -> resourcesNotificationReceived.set(true)))
.promptsChangeConsumer(prompts -> Mono.fromRunnable(() -> promptsNotificationReceived.set(true)))
Expand All @@ -285,7 +356,7 @@ void testInitializeWithSamplingCapability() {
var capabilities = ClientCapabilities.builder().sampling().build();

var client = McpClient.async(transport)
.requestTimeout(TIMEOUT)
.requestTimeout(getTimeoutDuration())
.capabilities(capabilities)
.sampling(request -> Mono.just(CreateMessageResult.builder().message("test").model("test-model").build()))
.build();
Expand All @@ -309,7 +380,7 @@ void testInitializeWithAllCapabilities() {
Function<CreateMessageRequest, Mono<CreateMessageResult>> samplingHandler = request -> Mono
.just(CreateMessageResult.builder().message("test").model("test-model").build());
var client = McpClient.async(transport)
.requestTimeout(TIMEOUT)
.requestTimeout(getTimeoutDuration())
.capabilities(capabilities)
.sampling(samplingHandler)
.build();
Expand All @@ -326,8 +397,17 @@ void testInitializeWithAllCapabilities() {
// Logging Tests
// ---------------------------------------

@Test
void testLoggingLevelsWithoutInitialization() {
assertThatThrownBy(() -> mcpAsyncClient.setLoggingLevel(McpSchema.LoggingLevel.DEBUG).block())
.isInstanceOf(McpError.class)
.hasMessage("Client must be initialized before setting logging level");
}

@Test
void testLoggingLevels() {
mcpAsyncClient.initialize().block(Duration.ofSeconds(10));

// Test all logging levels
for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) {
StepVerifier.create(mcpAsyncClient.setLoggingLevel(level)).verifyComplete();
Expand All @@ -340,7 +420,7 @@ void testLoggingConsumer() {
var transport = createMcpTransport();

var client = McpClient.async(transport)
.requestTimeout(TIMEOUT)
.requestTimeout(getTimeoutDuration())
.loggingConsumer(notification -> Mono.fromRunnable(() -> logReceived.set(true)))
.build();

Expand Down
Loading