Skip to content

Commit 109aab2

Browse files
authored
feat(client): Improve initialization state handling in McpAsyncClient (#39)
- Add proper initialization state tracking using AtomicBoolean and Sinks - Implement timeout handling for requests requiring initialization - Ensure all client methods verify initialization state before proceeding - Fix rootsListChangedNotification to check initialization state - Improve error messages for uninitialized client operations - improve JavaDoc - Add tests to verify proper error handling for uninitialized clients - Replace hardcoded timeout constants with configurable getTimeoutDuration() method - Remove automatic initialization in setUp methods to allow explicit testing Signed-off-by: Christian Tzolov <[email protected]>
1 parent 4eff00c commit 109aab2

File tree

9 files changed

+613
-242
lines changed

9 files changed

+613
-242
lines changed

mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java

+7
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
package io.modelcontextprotocol.client;
66

7+
import java.time.Duration;
8+
79
import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport;
810
import io.modelcontextprotocol.spec.ClientMcpTransport;
911
import org.junit.jupiter.api.Timeout;
@@ -46,4 +48,9 @@ public void onClose() {
4648
container.stop();
4749
}
4850

51+
@Override
52+
protected Duration getTimeoutDuration() {
53+
return Duration.ofMillis(300);
54+
}
55+
4956
}

mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java

+7
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
package io.modelcontextprotocol.client;
66

7+
import java.time.Duration;
8+
79
import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport;
810
import io.modelcontextprotocol.spec.ClientMcpTransport;
911
import org.junit.jupiter.api.Timeout;
@@ -46,4 +48,9 @@ protected void onClose() {
4648
container.stop();
4749
}
4850

51+
@Override
52+
protected Duration getTimeoutDuration() {
53+
return Duration.ofMillis(300);
54+
}
55+
4956
}

mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java

+89-9
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import java.util.function.Function;
1111

1212
import io.modelcontextprotocol.spec.ClientMcpTransport;
13+
import io.modelcontextprotocol.spec.McpError;
1314
import io.modelcontextprotocol.spec.McpSchema;
1415
import io.modelcontextprotocol.spec.McpSchema.CallToolRequest;
1516
import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities;
@@ -47,8 +48,6 @@ public abstract class AbstractMcpAsyncClientTests {
4748

4849
protected ClientMcpTransport mcpTransport;
4950

50-
private static final Duration TIMEOUT = Duration.ofSeconds(20);
51-
5251
private static final String ECHO_TEST_MESSAGE = "Hello MCP Spring AI!";
5352

5453
abstract protected ClientMcpTransport createMcpTransport();
@@ -59,17 +58,20 @@ protected void onStart() {
5958
protected void onClose() {
6059
}
6160

61+
protected Duration getTimeoutDuration() {
62+
return Duration.ofSeconds(2);
63+
}
64+
6265
@BeforeEach
6366
void setUp() {
6467
onStart();
6568
this.mcpTransport = createMcpTransport();
6669

6770
assertThatCode(() -> {
6871
mcpAsyncClient = McpClient.async(mcpTransport)
69-
.requestTimeout(TIMEOUT)
72+
.requestTimeout(getTimeoutDuration())
7073
.capabilities(ClientCapabilities.builder().roots(true).build())
7174
.build();
72-
mcpAsyncClient.initialize().block(Duration.ofSeconds(10));
7375
}).doesNotThrowAnyException();
7476
}
7577

@@ -92,8 +94,16 @@ void testConstructorWithInvalidArguments() {
9294
.hasMessage("Request timeout must not be null");
9395
}
9496

97+
@Test
98+
void testListToolsWithoutInitialization() {
99+
assertThatThrownBy(() -> mcpAsyncClient.listTools(null).block()).isInstanceOf(McpError.class)
100+
.hasMessage("Client must be initialized before listing tools");
101+
}
102+
95103
@Test
96104
void testListTools() {
105+
mcpAsyncClient.initialize().block(Duration.ofSeconds(10));
106+
97107
StepVerifier.create(mcpAsyncClient.listTools(null)).consumeNextWith(result -> {
98108
assertThat(result.tools()).isNotNull().isNotEmpty();
99109

@@ -103,13 +113,30 @@ void testListTools() {
103113
}).verifyComplete();
104114
}
105115

116+
@Test
117+
void testPingWithoutInitialization() {
118+
assertThatThrownBy(() -> mcpAsyncClient.ping().block()).isInstanceOf(McpError.class)
119+
.hasMessage("Client must be initialized before pinging the server");
120+
}
121+
106122
@Test
107123
void testPing() {
124+
mcpAsyncClient.initialize().block(Duration.ofSeconds(10));
108125
assertThatCode(() -> mcpAsyncClient.ping().block()).doesNotThrowAnyException();
109126
}
110127

128+
@Test
129+
void testCallToolWithoutInitialization() {
130+
CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE));
131+
132+
assertThatThrownBy(() -> mcpAsyncClient.callTool(callToolRequest).block()).isInstanceOf(McpError.class)
133+
.hasMessage("Client must be initialized before calling tools");
134+
}
135+
111136
@Test
112137
void testCallTool() {
138+
mcpAsyncClient.initialize().block(Duration.ofSeconds(10));
139+
113140
CallToolRequest callToolRequest = new CallToolRequest("echo", Map.of("message", ECHO_TEST_MESSAGE));
114141

115142
StepVerifier.create(mcpAsyncClient.callTool(callToolRequest)).consumeNextWith(callToolResult -> {
@@ -122,13 +149,23 @@ void testCallTool() {
122149

123150
@Test
124151
void testCallToolWithInvalidTool() {
152+
mcpAsyncClient.initialize().block(Duration.ofSeconds(10));
153+
125154
CallToolRequest invalidRequest = new CallToolRequest("nonexistent_tool", Map.of("message", ECHO_TEST_MESSAGE));
126155

127156
assertThatThrownBy(() -> mcpAsyncClient.callTool(invalidRequest).block()).isInstanceOf(Exception.class);
128157
}
129158

159+
@Test
160+
void testListResourcesWithoutInitialization() {
161+
assertThatThrownBy(() -> mcpAsyncClient.listResources(null).block()).isInstanceOf(McpError.class)
162+
.hasMessage("Client must be initialized before listing resources");
163+
}
164+
130165
@Test
131166
void testListResources() {
167+
mcpAsyncClient.initialize().block(Duration.ofSeconds(10));
168+
132169
StepVerifier.create(mcpAsyncClient.listResources(null)).consumeNextWith(resources -> {
133170
assertThat(resources).isNotNull().satisfies(result -> {
134171
assertThat(result.resources()).isNotNull();
@@ -147,8 +184,16 @@ void testMcpAsyncClientState() {
147184
assertThat(mcpAsyncClient).isNotNull();
148185
}
149186

187+
@Test
188+
void testListPromptsWithoutInitialization() {
189+
assertThatThrownBy(() -> mcpAsyncClient.listPrompts(null).block()).isInstanceOf(McpError.class)
190+
.hasMessage("Client must be initialized before listing prompts");
191+
}
192+
150193
@Test
151194
void testListPrompts() {
195+
mcpAsyncClient.initialize().block(Duration.ofSeconds(10));
196+
152197
StepVerifier.create(mcpAsyncClient.listPrompts(null)).consumeNextWith(prompts -> {
153198
assertThat(prompts).isNotNull().satisfies(result -> {
154199
assertThat(result.prompts()).isNotNull();
@@ -162,8 +207,18 @@ void testListPrompts() {
162207
}).verifyComplete();
163208
}
164209

210+
@Test
211+
void testGetPromptWithoutInitialization() {
212+
GetPromptRequest request = new GetPromptRequest("simple_prompt", Map.of());
213+
214+
assertThatThrownBy(() -> mcpAsyncClient.getPrompt(request).block()).isInstanceOf(McpError.class)
215+
.hasMessage("Client must be initialized before getting prompts");
216+
}
217+
165218
@Test
166219
void testGetPrompt() {
220+
mcpAsyncClient.initialize().block(Duration.ofSeconds(10));
221+
167222
StepVerifier.create(mcpAsyncClient.getPrompt(new GetPromptRequest("simple_prompt", Map.of())))
168223
.consumeNextWith(prompt -> {
169224
assertThat(prompt).isNotNull().satisfies(result -> {
@@ -174,8 +229,16 @@ void testGetPrompt() {
174229
.verifyComplete();
175230
}
176231

232+
@Test
233+
void testRootsListChangedWithoutInitialization() {
234+
assertThatThrownBy(() -> mcpAsyncClient.rootsListChangedNotification().block()).isInstanceOf(McpError.class)
235+
.hasMessage("Client must be initialized before sending roots list changed notification");
236+
}
237+
177238
@Test
178239
void testRootsListChanged() {
240+
mcpAsyncClient.initialize().block(Duration.ofSeconds(10));
241+
179242
assertThatCode(() -> mcpAsyncClient.rootsListChangedNotification().block()).doesNotThrowAnyException();
180243
}
181244

@@ -184,7 +247,7 @@ void testInitializeWithRootsListProviders() {
184247
var transport = createMcpTransport();
185248

186249
var client = McpClient.async(transport)
187-
.requestTimeout(TIMEOUT)
250+
.requestTimeout(getTimeoutDuration())
188251
.roots(new Root("file:///test/path", "test-root"))
189252
.build();
190253

@@ -233,8 +296,16 @@ void testReadResource() {
233296
}).verifyComplete();
234297
}
235298

299+
@Test
300+
void testListResourceTemplatesWithoutInitialization() {
301+
assertThatThrownBy(() -> mcpAsyncClient.listResourceTemplates().block()).isInstanceOf(McpError.class)
302+
.hasMessage("Client must be initialized before listing resource templates");
303+
}
304+
236305
@Test
237306
void testListResourceTemplates() {
307+
mcpAsyncClient.initialize().block(Duration.ofSeconds(10));
308+
238309
StepVerifier.create(mcpAsyncClient.listResourceTemplates()).consumeNextWith(result -> {
239310
assertThat(result).isNotNull();
240311
assertThat(result.resourceTemplates()).isNotNull();
@@ -266,7 +337,7 @@ void testNotificationHandlers() {
266337

267338
var transport = createMcpTransport();
268339
var client = McpClient.async(transport)
269-
.requestTimeout(TIMEOUT)
340+
.requestTimeout(getTimeoutDuration())
270341
.toolsChangeConsumer(tools -> Mono.fromRunnable(() -> toolsNotificationReceived.set(true)))
271342
.resourcesChangeConsumer(resources -> Mono.fromRunnable(() -> resourcesNotificationReceived.set(true)))
272343
.promptsChangeConsumer(prompts -> Mono.fromRunnable(() -> promptsNotificationReceived.set(true)))
@@ -285,7 +356,7 @@ void testInitializeWithSamplingCapability() {
285356
var capabilities = ClientCapabilities.builder().sampling().build();
286357

287358
var client = McpClient.async(transport)
288-
.requestTimeout(TIMEOUT)
359+
.requestTimeout(getTimeoutDuration())
289360
.capabilities(capabilities)
290361
.sampling(request -> Mono.just(CreateMessageResult.builder().message("test").model("test-model").build()))
291362
.build();
@@ -309,7 +380,7 @@ void testInitializeWithAllCapabilities() {
309380
Function<CreateMessageRequest, Mono<CreateMessageResult>> samplingHandler = request -> Mono
310381
.just(CreateMessageResult.builder().message("test").model("test-model").build());
311382
var client = McpClient.async(transport)
312-
.requestTimeout(TIMEOUT)
383+
.requestTimeout(getTimeoutDuration())
313384
.capabilities(capabilities)
314385
.sampling(samplingHandler)
315386
.build();
@@ -326,8 +397,17 @@ void testInitializeWithAllCapabilities() {
326397
// Logging Tests
327398
// ---------------------------------------
328399

400+
@Test
401+
void testLoggingLevelsWithoutInitialization() {
402+
assertThatThrownBy(() -> mcpAsyncClient.setLoggingLevel(McpSchema.LoggingLevel.DEBUG).block())
403+
.isInstanceOf(McpError.class)
404+
.hasMessage("Client must be initialized before setting logging level");
405+
}
406+
329407
@Test
330408
void testLoggingLevels() {
409+
mcpAsyncClient.initialize().block(Duration.ofSeconds(10));
410+
331411
// Test all logging levels
332412
for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) {
333413
StepVerifier.create(mcpAsyncClient.setLoggingLevel(level)).verifyComplete();
@@ -340,7 +420,7 @@ void testLoggingConsumer() {
340420
var transport = createMcpTransport();
341421

342422
var client = McpClient.async(transport)
343-
.requestTimeout(TIMEOUT)
423+
.requestTimeout(getTimeoutDuration())
344424
.loggingConsumer(notification -> Mono.fromRunnable(() -> logReceived.set(true)))
345425
.build();
346426

0 commit comments

Comments
 (0)