Skip to content

Commit f5fed84

Browse files
committed
refactor(webflux): refactor WebFluxSseIntegrationTests
Signed-off-by: Christian Tzolov <[email protected]>
1 parent c56ab94 commit f5fed84

File tree

1 file changed

+150
-142
lines changed

1 file changed

+150
-142
lines changed

Diff for: mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java

+150-142
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport;
1717
import io.modelcontextprotocol.server.McpServer;
1818
import io.modelcontextprotocol.server.McpServerFeatures;
19-
import io.modelcontextprotocol.server.transport.WebFluxSseServerTransport;
2019
import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider;
2120
import io.modelcontextprotocol.spec.McpError;
2221
import io.modelcontextprotocol.spec.McpSchema;
@@ -31,9 +30,9 @@
3130
import io.modelcontextprotocol.spec.McpSchema.Tool;
3231
import org.junit.jupiter.api.AfterEach;
3332
import org.junit.jupiter.api.BeforeEach;
34-
import org.junit.jupiter.api.Test;
3533
import org.junit.jupiter.params.ParameterizedTest;
3634
import org.junit.jupiter.params.provider.ValueSource;
35+
import reactor.core.publisher.Mono;
3736
import reactor.netty.DisposableServer;
3837
import reactor.netty.http.server.HttpServer;
3938
import reactor.test.StepVerifier;
@@ -45,8 +44,8 @@
4544
import org.springframework.web.reactive.function.server.RouterFunctions;
4645

4746
import static org.assertj.core.api.Assertions.assertThat;
48-
import static org.assertj.core.api.Assertions.assertThatThrownBy;
4947
import static org.awaitility.Awaitility.await;
48+
import static org.mockito.Mockito.mock;
5049

5150
public class WebFluxSseIntegrationTests {
5251

@@ -85,109 +84,100 @@ public void after() {
8584
// ---------------------------------------
8685
// Sampling Tests
8786
// ---------------------------------------
88-
// TODO implement within a tool execution
89-
// @Test
90-
// void testCreateMessageWithoutInitialization() {
91-
// var mcpAsyncServer =
92-
// McpServer.async(mcpServerTransportProvider).serverInfo("test-server",
93-
// "1.0.0").build();
94-
//
95-
// var messages = List
96-
// .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new
97-
// McpSchema.TextContent("Test message")));
98-
// var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0);
99-
//
100-
// var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null,
101-
// McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(),
102-
// Map.of());
103-
//
104-
// StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error
105-
// -> {
106-
// assertThat(error).isInstanceOf(McpError.class)
107-
// .hasMessage("Client must be initialized. Call the initialize method first!");
108-
// });
109-
// }
110-
//
111-
// @ParameterizedTest(name = "{0} : {displayName} ")
112-
// @ValueSource(strings = { "httpclient", "webflux" })
113-
// void testCreateMessageWithoutSamplingCapabilities(String clientType) {
114-
//
115-
// var mcpAsyncServer =
116-
// McpServer.async(mcpServerTransportProvider).serverInfo("test-server",
117-
// "1.0.0").build();
118-
//
119-
// var clientBuilder = clientBulders.get(clientType);
120-
//
121-
// var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client",
122-
// "0.0.0")).build();
123-
//
124-
// InitializeResult initResult = client.initialize();
125-
// assertThat(initResult).isNotNull();
126-
//
127-
// var messages = List
128-
// .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new
129-
// McpSchema.TextContent("Test message")));
130-
// var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0);
131-
//
132-
// var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null,
133-
// McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(),
134-
// Map.of());
135-
//
136-
// StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error
137-
// -> {
138-
// assertThat(error).isInstanceOf(McpError.class)
139-
// .hasMessage("Client must be configured with sampling capabilities");
140-
// });
141-
// }
142-
//
143-
// @ParameterizedTest(name = "{0} : {displayName} ")
144-
// @ValueSource(strings = { "httpclient", "webflux" })
145-
// void testCreateMessageSuccess(String clientType) throws InterruptedException {
146-
//
147-
// var clientBuilder = clientBulders.get(clientType);
148-
//
149-
// var mcpAsyncServer =
150-
// McpServer.async(mcpServerTransportProvider).serverInfo("test-server",
151-
// "1.0.0").build();
152-
//
153-
// Function<CreateMessageRequest, CreateMessageResult> samplingHandler = request -> {
154-
// assertThat(request.messages()).hasSize(1);
155-
// assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class);
156-
//
157-
// return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test
158-
// message"), "MockModelName",
159-
// CreateMessageResult.StopReason.STOP_SEQUENCE);
160-
// };
161-
//
162-
// var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client",
163-
// "0.0.0"))
164-
// .capabilities(ClientCapabilities.builder().sampling().build())
165-
// .sampling(samplingHandler)
166-
// .build();
167-
//
168-
// InitializeResult initResult = client.initialize();
169-
// assertThat(initResult).isNotNull();
170-
//
171-
// var messages = List
172-
// .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new
173-
// McpSchema.TextContent("Test message")));
174-
// var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0);
175-
//
176-
// var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null,
177-
// McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(),
178-
// Map.of());
179-
//
180-
// StepVerifier.create(mcpAsyncServer.createMessage(request)).consumeNextWith(result
181-
// -> {
182-
// assertThat(result).isNotNull();
183-
// assertThat(result.role()).isEqualTo(Role.USER);
184-
// assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class);
185-
// assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test
186-
// message");
187-
// assertThat(result.model()).isEqualTo("MockModelName");
188-
// assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE);
189-
// }).verifyComplete();
190-
// }
87+
@ParameterizedTest(name = "{0} : {displayName} ")
88+
@ValueSource(strings = { "httpclient", "webflux" })
89+
void testCreateMessageWithoutSamplingCapabilities(String clientType) {
90+
91+
var clientBuilder = clientBulders.get(clientType);
92+
93+
McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification(
94+
new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> {
95+
96+
exchange.createMessage(mock(McpSchema.CreateMessageRequest.class)).block();
97+
98+
return Mono.just(mock(CallToolResult.class));
99+
});
100+
101+
McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build();
102+
103+
// Create client without sampling capabilities
104+
var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")).build();
105+
106+
assertThat(client.initialize()).isNotNull();
107+
108+
try {
109+
client.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));
110+
}
111+
catch (McpError e) {
112+
assertThat(e).isInstanceOf(McpError.class)
113+
.hasMessage("Client must be configured with sampling capabilities");
114+
}
115+
}
116+
117+
@ParameterizedTest(name = "{0} : {displayName} ")
118+
@ValueSource(strings = { "httpclient", "webflux" })
119+
void testCreateMessageSuccess(String clientType) throws InterruptedException {
120+
121+
// Client
122+
var clientBuilder = clientBulders.get(clientType);
123+
124+
Function<CreateMessageRequest, CreateMessageResult> samplingHandler = request -> {
125+
assertThat(request.messages()).hasSize(1);
126+
assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class);
127+
128+
return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName",
129+
CreateMessageResult.StopReason.STOP_SEQUENCE);
130+
};
131+
132+
var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0"))
133+
.capabilities(ClientCapabilities.builder().sampling().build())
134+
.sampling(samplingHandler)
135+
.build();
136+
137+
// Server
138+
139+
CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")),
140+
null);
141+
142+
McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification(
143+
new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> {
144+
145+
var messages = List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER,
146+
new McpSchema.TextContent("Test message")));
147+
var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0);
148+
149+
var craeteMessageRequest = new McpSchema.CreateMessageRequest(messages, modelPrefs, null,
150+
McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(),
151+
Map.of());
152+
153+
StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> {
154+
assertThat(result).isNotNull();
155+
assertThat(result.role()).isEqualTo(Role.USER);
156+
assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class);
157+
assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message");
158+
assertThat(result.model()).isEqualTo("MockModelName");
159+
assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE);
160+
}).verifyComplete();
161+
162+
return Mono.just(callResponse);
163+
});
164+
165+
var mcpServer = McpServer.async(mcpServerTransportProvider)
166+
.serverInfo("test-server", "1.0.0")
167+
.tools(tool)
168+
.build();
169+
170+
InitializeResult initResult = mcpClient.initialize();
171+
assertThat(initResult).isNotNull();
172+
173+
CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));
174+
175+
assertThat(response).isNotNull();
176+
assertThat(response).isEqualTo(callResponse);
177+
178+
mcpClient.close();
179+
mcpServer.close();
180+
}
191181

192182
// ---------------------------------------
193183
// Roots Tests
@@ -238,43 +228,44 @@ void testRootsSuccess(String clientType) {
238228
mcpServer.close();
239229
}
240230

241-
// @ParameterizedTest(name = "{0} : {displayName} ")
242-
// @ValueSource(strings = { "httpclient", "webflux" })
243-
// void testRootsWithoutCapability(String clientType) {
244-
// var clientBuilder = clientBulders.get(clientType);
245-
// AtomicReference<Exception> errorRef = new AtomicReference<>();
246-
//
247-
// var mcpServer =
248-
// McpServer.sync(mcpServerTransportProvider)
249-
// // TODO: implement tool handling and try to list roots
250-
// .tool(tool, (exchange, args) -> {
251-
// try {
252-
// exchange.listRoots();
253-
// } catch (Exception e) {
254-
// errorRef.set(e);
255-
// }
256-
// }).build();
257-
//
258-
// // Create client without roots capability
259-
// var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()) //
260-
// No
261-
// // roots
262-
// // capability
263-
// .build();
264-
//
265-
// InitializeResult initResult = mcpClient.initialize();
266-
// assertThat(initResult).isNotNull();
267-
//
268-
// assertThat(errorRef.get()).isInstanceOf(McpError.class).hasMessage("Roots not
269-
// supported");
270-
//
271-
// mcpClient.close();
272-
// mcpServer.close();
273-
// }
231+
@ParameterizedTest(name = "{0} : {displayName} ")
232+
@ValueSource(strings = { "httpclient", "webflux" })
233+
void testRootsWithoutCapability(String clientType) {
234+
235+
var clientBuilder = clientBulders.get(clientType);
236+
237+
McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification(
238+
new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> {
239+
240+
exchange.listRoots(); // try to list roots
241+
242+
return mock(CallToolResult.class);
243+
});
244+
245+
var mcpServer = McpServer.sync(mcpServerTransportProvider).rootsChangeHandler((exchange, rootsUpdate) -> {
246+
}).tools(tool).build();
247+
248+
// Create client without roots capability
249+
// No roots capability
250+
var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build();
251+
252+
assertThat(mcpClient.initialize()).isNotNull();
253+
254+
// Attempt to list roots should fail
255+
try {
256+
mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));
257+
}
258+
catch (McpError e) {
259+
assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported");
260+
}
261+
262+
mcpClient.close();
263+
mcpServer.close();
264+
}
274265

275266
@ParameterizedTest(name = "{0} : {displayName} ")
276267
@ValueSource(strings = { "httpclient", "webflux" })
277-
void testRootsWithEmptyRootsList(String clientType) {
268+
void testRootsNotifciationWithEmptyRootsList(String clientType) {
278269
var clientBuilder = clientBulders.get(clientType);
279270

280271
AtomicReference<List<Root>> rootsRef = new AtomicReference<>();
@@ -474,8 +465,8 @@ void testToolListChangeHandlingSuccess(String clientType) {
474465
});
475466

476467
// Add a new tool
477-
McpServerFeatures.SyncToolRegistration tool2 = new McpServerFeatures.SyncToolRegistration(
478-
new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), request -> callResponse);
468+
McpServerFeatures.SyncToolSpecification tool2 = new McpServerFeatures.SyncToolSpecification(
469+
new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), (exchange, request) -> callResponse);
479470

480471
mcpServer.addTool(tool2);
481472

@@ -487,4 +478,21 @@ void testToolListChangeHandlingSuccess(String clientType) {
487478
mcpServer.close();
488479
}
489480

481+
@ParameterizedTest(name = "{0} : {displayName} ")
482+
@ValueSource(strings = { "httpclient", "webflux" })
483+
void testInitialize(String clientType) {
484+
485+
var clientBuilder = clientBulders.get(clientType);
486+
487+
var mcpServer = McpServer.sync(mcpServerTransportProvider).build();
488+
489+
var mcpClient = clientBuilder.build();
490+
491+
InitializeResult initResult = mcpClient.initialize();
492+
assertThat(initResult).isNotNull();
493+
494+
mcpClient.close();
495+
mcpServer.close();
496+
}
497+
490498
}

0 commit comments

Comments
 (0)