Skip to content

Optimize nested streams #33

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from

This file was deleted.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@

import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.client.McpClient;
import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport;
import io.modelcontextprotocol.client.transport.HttpClientSseClientTransportProvider;
import io.modelcontextprotocol.client.transport.WebFluxSseClientTransportProvider;
import io.modelcontextprotocol.server.McpServer;
import io.modelcontextprotocol.server.McpServerFeatures;
import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider;
Expand Down Expand Up @@ -78,14 +78,14 @@ public void before() {
this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow();

clientBuilders.put("httpclient",
McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT)
McpClient.sync(HttpClientSseClientTransportProvider.builder("http://localhost:" + PORT)
.sseEndpoint(CUSTOM_SSE_ENDPOINT)
.build()));
clientBuilders.put("webflux",
McpClient
.sync(WebFluxSseClientTransport.builder(WebClient.builder().baseUrl("http://localhost:" + PORT))
.sseEndpoint(CUSTOM_SSE_ENDPOINT)
.build()));
McpClient.sync(WebFluxSseClientTransportProvider
.builder(WebClient.builder().baseUrl("http://localhost:" + PORT))
.sseEndpoint(CUSTOM_SSE_ENDPOINT)
.build()));

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,15 @@

import java.time.Duration;

import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport;
import io.modelcontextprotocol.spec.McpClientTransport;
import io.modelcontextprotocol.client.transport.WebFluxSseClientTransportProvider;
import io.modelcontextprotocol.spec.McpClientTransportProvider;
import org.junit.jupiter.api.Timeout;
import org.testcontainers.containers.GenericContainer;
import org.testcontainers.containers.wait.strategy.Wait;

import org.springframework.web.reactive.function.client.WebClient;

/**
* Tests for the {@link McpAsyncClient} with {@link WebFluxSseClientTransport}.
*
* @author Christian Tzolov
*/
@Timeout(15) // Giving extra time beyond the client timeout
Expand All @@ -31,9 +29,8 @@ class WebFluxSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests {
.withExposedPorts(3001)
.waitingFor(Wait.forHttp("/").forStatusCode(404));

@Override
protected McpClientTransport createMcpTransport() {
return WebFluxSseClientTransport.builder(WebClient.builder().baseUrl(host)).build();
protected McpClientTransportProvider createMcpClientTransportProvider() {
return new WebFluxSseClientTransportProvider(WebClient.builder().baseUrl(host));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,15 @@

import java.time.Duration;

import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport;
import io.modelcontextprotocol.spec.McpClientTransport;
import io.modelcontextprotocol.client.transport.WebFluxSseClientTransportProvider;
import io.modelcontextprotocol.spec.McpClientTransportProvider;
import org.junit.jupiter.api.Timeout;
import org.testcontainers.containers.GenericContainer;
import org.testcontainers.containers.wait.strategy.Wait;

import org.springframework.web.reactive.function.client.WebClient;

/**
* Tests for the {@link McpSyncClient} with {@link WebFluxSseClientTransport}.
*
* @author Christian Tzolov
*/
@Timeout(15) // Giving extra time beyond the client timeout
Expand All @@ -32,8 +30,8 @@ class WebFluxSseMcpSyncClientTests extends AbstractMcpSyncClientTests {
.waitingFor(Wait.forHttp("/").forStatusCode(404));

@Override
protected McpClientTransport createMcpTransport() {
return WebFluxSseClientTransport.builder(WebClient.builder().baseUrl(host)).build();
protected McpClientTransportProvider createMcpClientTransportProvider() {
return new WebFluxSseClientTransportProvider(WebClient.builder().baseUrl(host));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.spec.McpClientSession;
import io.modelcontextprotocol.spec.McpClientTransport;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest;
import org.junit.jupiter.api.AfterEach;
Expand All @@ -31,8 +34,6 @@
import static org.assertj.core.api.Assertions.assertThatThrownBy;

/**
* Tests for the {@link WebFluxSseClientTransport} class.
*
* @author Christian Tzolov
*/
@Timeout(15)
Expand All @@ -46,20 +47,22 @@ class WebFluxSseClientTransportTests {
.withExposedPorts(3001)
.waitingFor(Wait.forHttp("/").forStatusCode(404));

private TestSseClientTransport transport;
private TestSseClientTransportProvider transportProvider;

private McpClientTransport transport;

private WebClient.Builder webClientBuilder;

private ObjectMapper objectMapper;

// Test class to access protected methods
static class TestSseClientTransport extends WebFluxSseClientTransport {
static class TestSseClientTransportProvider extends WebFluxSseClientTransportProvider {

private final AtomicInteger inboundMessageCount = new AtomicInteger(0);

private Sinks.Many<ServerSentEvent<String>> events = Sinks.many().unicast().onBackpressureBuffer();

public TestSseClientTransport(WebClient.Builder webClientBuilder, ObjectMapper objectMapper) {
public TestSseClientTransportProvider(WebClient.Builder webClientBuilder, ObjectMapper objectMapper) {
super(webClientBuilder, objectMapper);
}

Expand All @@ -69,7 +72,7 @@ protected Flux<ServerSentEvent<String>> eventStream() {
}

public String getLastEndpoint() {
return messageEndpointSink.asMono().block();
return ((WebFluxSseClientTransport) getSession().getTransport()).messageEndpointSink.asMono().block();
}

public int getInboundMessageCount() {
Expand Down Expand Up @@ -99,7 +102,10 @@ void setUp() {
startContainer();
webClientBuilder = WebClient.builder().baseUrl(host);
objectMapper = new ObjectMapper();
transport = new TestSseClientTransport(webClientBuilder, objectMapper);
transportProvider = new TestSseClientTransportProvider(webClientBuilder, objectMapper);
transportProvider.setSessionFactory(
(transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of()));
transport = transportProvider.getSession().getTransport();
transport.connect(Function.identity()).block();
}

Expand All @@ -117,44 +123,62 @@ void cleanup() {

@Test
void testEndpointEventHandling() {
assertThat(transport.getLastEndpoint()).startsWith("/message?");
assertThat(transportProvider.getLastEndpoint()).startsWith("/message?");
}

@Test
void constructorValidation() {
assertThatThrownBy(() -> new WebFluxSseClientTransport(null)).isInstanceOf(IllegalArgumentException.class)
assertThatThrownBy(() -> new WebFluxSseClientTransportProvider(null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("WebClient.Builder must not be null");

assertThatThrownBy(() -> new WebFluxSseClientTransport(webClientBuilder, null))
assertThatThrownBy(() -> new WebFluxSseClientTransportProvider(webClientBuilder, null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("ObjectMapper must not be null");
}

@Test
void testBuilderPattern() {
// Test default builder
WebFluxSseClientTransport transport1 = WebFluxSseClientTransport.builder(webClientBuilder).build();
assertThatCode(() -> transport1.closeGracefully().block()).doesNotThrowAnyException();
WebFluxSseClientTransportProvider transportProvider1 = WebFluxSseClientTransportProvider
.builder(webClientBuilder)
.build();
transportProvider1.setSessionFactory(
(transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of()));
transportProvider1.getSession();
assertThatCode(() -> transportProvider1.closeGracefully().block()).doesNotThrowAnyException();

// Test builder with custom ObjectMapper
ObjectMapper customMapper = new ObjectMapper();
WebFluxSseClientTransport transport2 = WebFluxSseClientTransport.builder(webClientBuilder)
WebFluxSseClientTransportProvider transportProvider2 = WebFluxSseClientTransportProvider
.builder(webClientBuilder)
.objectMapper(customMapper)
.build();
assertThatCode(() -> transport2.closeGracefully().block()).doesNotThrowAnyException();
transportProvider2.setSessionFactory(
(transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of()));
transportProvider2.getSession();
assertThatCode(() -> transportProvider2.closeGracefully().block()).doesNotThrowAnyException();

// Test builder with custom SSE endpoint
WebFluxSseClientTransport transport3 = WebFluxSseClientTransport.builder(webClientBuilder)
WebFluxSseClientTransportProvider transportProvider3 = WebFluxSseClientTransportProvider
.builder(webClientBuilder)
.sseEndpoint("/custom-sse")
.build();
assertThatCode(() -> transport3.closeGracefully().block()).doesNotThrowAnyException();
transportProvider3.setSessionFactory(
(transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of()));
transportProvider3.getSession();
assertThatCode(() -> transportProvider3.closeGracefully().block()).doesNotThrowAnyException();

// Test builder with all custom parameters
WebFluxSseClientTransport transport4 = WebFluxSseClientTransport.builder(webClientBuilder)
WebFluxSseClientTransportProvider transportProvider4 = WebFluxSseClientTransportProvider
.builder(webClientBuilder)
.objectMapper(customMapper)
.sseEndpoint("/custom-sse")
.build();
assertThatCode(() -> transport4.closeGracefully().block()).doesNotThrowAnyException();
transportProvider4.setSessionFactory(
(transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of()));
transportProvider4.getSession();
assertThatCode(() -> transportProvider4.closeGracefully().block()).doesNotThrowAnyException();
}

@Test
Expand All @@ -164,7 +188,7 @@ void testMessageProcessing() {
Map.of("key", "value"));

// Simulate receiving the message
transport.simulateMessageEvent("""
transportProvider.simulateMessageEvent("""
{
"jsonrpc": "2.0",
"method": "test-method",
Expand All @@ -176,13 +200,13 @@ void testMessageProcessing() {
// Subscribe to messages and verify
StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete();

assertThat(transport.getInboundMessageCount()).isEqualTo(1);
assertThat(transportProvider.getInboundMessageCount()).isEqualTo(1);
}

@Test
void testResponseMessageProcessing() {
// Simulate receiving a response message
transport.simulateMessageEvent("""
transportProvider.simulateMessageEvent("""
{
"jsonrpc": "2.0",
"id": "test-id",
Expand All @@ -197,13 +221,13 @@ void testResponseMessageProcessing() {
// Verify message handling
StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete();

assertThat(transport.getInboundMessageCount()).isEqualTo(1);
assertThat(transportProvider.getInboundMessageCount()).isEqualTo(1);
}

@Test
void testErrorMessageProcessing() {
// Simulate receiving an error message
transport.simulateMessageEvent("""
transportProvider.simulateMessageEvent("""
{
"jsonrpc": "2.0",
"id": "test-id",
Expand All @@ -221,13 +245,13 @@ void testErrorMessageProcessing() {
// Verify message handling
StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete();

assertThat(transport.getInboundMessageCount()).isEqualTo(1);
assertThat(transportProvider.getInboundMessageCount()).isEqualTo(1);
}

@Test
void testNotificationMessageProcessing() {
// Simulate receiving a notification message (no id)
transport.simulateMessageEvent("""
transportProvider.simulateMessageEvent("""
{
"jsonrpc": "2.0",
"method": "update",
Expand All @@ -236,7 +260,7 @@ void testNotificationMessageProcessing() {
""");

// Verify the notification was processed
assertThat(transport.getInboundMessageCount()).isEqualTo(1);
assertThat(transportProvider.getInboundMessageCount()).isEqualTo(1);
}

@Test
Expand All @@ -252,27 +276,31 @@ void testGracefulShutdown() {
StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete();

// Message count should remain 0 after shutdown
assertThat(transport.getInboundMessageCount()).isEqualTo(0);
assertThat(transportProvider.getInboundMessageCount()).isEqualTo(0);
}

@Test
void testRetryBehavior() {
// Create a WebClient that simulates connection failures
WebClient.Builder failingWebClientBuilder = WebClient.builder().baseUrl("http://non-existent-host");

WebFluxSseClientTransport failingTransport = WebFluxSseClientTransport.builder(failingWebClientBuilder).build();
WebFluxSseClientTransportProvider failingTransportProvider = WebFluxSseClientTransportProvider
.builder(failingWebClientBuilder)
.build();
failingTransportProvider.setSessionFactory(
(transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of()));

// Verify that the transport attempts to reconnect
StepVerifier.create(Mono.delay(Duration.ofSeconds(2))).expectNextCount(1).verifyComplete();

// Clean up
failingTransport.closeGracefully().block();
failingTransportProvider.getSession().getTransport().closeGracefully().block();
}

@Test
void testMultipleMessageProcessing() {
// Simulate receiving multiple messages in sequence
transport.simulateMessageEvent("""
transportProvider.simulateMessageEvent("""
{
"jsonrpc": "2.0",
"method": "method1",
Expand All @@ -281,7 +309,7 @@ void testMultipleMessageProcessing() {
}
""");

transport.simulateMessageEvent("""
transportProvider.simulateMessageEvent("""
{
"jsonrpc": "2.0",
"method": "method2",
Expand All @@ -301,13 +329,13 @@ void testMultipleMessageProcessing() {
StepVerifier.create(transport.sendMessage(message1).then(transport.sendMessage(message2))).verifyComplete();

// Verify message count
assertThat(transport.getInboundMessageCount()).isEqualTo(2);
assertThat(transportProvider.getInboundMessageCount()).isEqualTo(2);
}

@Test
void testMessageOrderPreservation() {
// Simulate receiving messages in a specific order
transport.simulateMessageEvent("""
transportProvider.simulateMessageEvent("""
{
"jsonrpc": "2.0",
"method": "first",
Expand All @@ -316,7 +344,7 @@ void testMessageOrderPreservation() {
}
""");

transport.simulateMessageEvent("""
transportProvider.simulateMessageEvent("""
{
"jsonrpc": "2.0",
"method": "second",
Expand All @@ -325,7 +353,7 @@ void testMessageOrderPreservation() {
}
""");

transport.simulateMessageEvent("""
transportProvider.simulateMessageEvent("""
{
"jsonrpc": "2.0",
"method": "third",
Expand All @@ -335,7 +363,7 @@ void testMessageOrderPreservation() {
""");

// Verify message count and order
assertThat(transport.getInboundMessageCount()).isEqualTo(3);
assertThat(transportProvider.getInboundMessageCount()).isEqualTo(3);
}

}
Loading