Skip to content

refactor: adds Client/Reqeust customizer for HttpClientSseClientTransport #117

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
import static org.awaitility.Awaitility.await;
import static org.mockito.Mockito.mock;

public class WebMvcSseIntegrationTests {
class WebMvcSseIntegrationTests {

private static final int PORT = 8183;

Expand Down Expand Up @@ -79,13 +79,13 @@ public void before() {

try {
tomcatServer.tomcat().start();
assertThat(tomcatServer.tomcat().getServer().getState() == LifecycleState.STARTED);
assertThat(tomcatServer.tomcat().getServer().getState()).isEqualTo(LifecycleState.STARTED);
}
catch (Exception e) {
throw new RuntimeException("Failed to start Tomcat", e);
}

clientBuilder = McpClient.sync(new HttpClientSseClientTransport("http://localhost:" + PORT));
clientBuilder = McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT).build());

// Get the transport from Spring context
mcpServerTransportProvider = tomcatServer.appContext().getBean(WebMvcSseServerTransportProvider.class);
Expand Down Expand Up @@ -200,8 +200,7 @@ void testCreateMessageSuccess() throws InterruptedException {

CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));

assertThat(response).isNotNull();
assertThat(response).isEqualTo(callResponse);
assertThat(response).isNotNull().isEqualTo(callResponse);

mcpClient.close();
mcpServer.close();
Expand Down Expand Up @@ -410,8 +409,7 @@ void testToolCallSuccess() {

CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));

assertThat(response).isNotNull();
assertThat(response).isEqualTo(callResponse);
assertThat(response).isNotNull().isEqualTo(callResponse);

mcpClient.close();
mcpServer.close();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.Function;

import com.fasterxml.jackson.core.type.TypeReference;
Expand Down Expand Up @@ -100,59 +101,27 @@ public class HttpClientSseClientTransport implements McpClientTransport {
/** Holds the SSE connection future */
private final AtomicReference<CompletableFuture<Void>> connectionFuture = new AtomicReference<>();

/**
* Creates a new transport instance with default HTTP client and object mapper.
* @param baseUri the base URI of the MCP server
*/
public HttpClientSseClientTransport(String baseUri) {
this(HttpClient.newBuilder(), baseUri, new ObjectMapper());
}

/**
* Creates a new transport instance with custom HTTP client builder and object mapper.
* @param clientBuilder the HTTP client builder to use
* @param baseUri the base URI of the MCP server
* @param objectMapper the object mapper for JSON serialization/deserialization
* @throws IllegalArgumentException if objectMapper or clientBuilder is null
*/
public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, String baseUri, ObjectMapper objectMapper) {
this(clientBuilder, baseUri, DEFAULT_SSE_ENDPOINT, objectMapper);
}

/**
* Creates a new transport instance with custom HTTP client builder and object mapper.
* @param clientBuilder the HTTP client builder to use
* @param baseUri the base URI of the MCP server
* @param sseEndpoint the SSE endpoint path
* @param objectMapper the object mapper for JSON serialization/deserialization
* @throws IllegalArgumentException if objectMapper or clientBuilder is null
*/
public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, String baseUri, String sseEndpoint,
ObjectMapper objectMapper) {
this(clientBuilder, HttpRequest.newBuilder(), baseUri, sseEndpoint, objectMapper);
}

/**
* Creates a new transport instance with custom HTTP client builder, object mapper,
* and headers.
* @param clientBuilder the HTTP client builder to use
* @param httpClient the HTTP client to use
* @param requestBuilder the HTTP request builder to use
* @param baseUri the base URI of the MCP server
* @param sseEndpoint the SSE endpoint path
* @param objectMapper the object mapper for JSON serialization/deserialization
* @throws IllegalArgumentException if objectMapper, clientBuilder, or headers is null
*/
public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, HttpRequest.Builder requestBuilder,
String baseUri, String sseEndpoint, ObjectMapper objectMapper) {
HttpClientSseClientTransport(HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to deprecate it first in order to avoid breaking the users. It is currently public.

Copy link
Contributor Author

@Aliaksie Aliaksie Apr 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're absolutely right — to avoid breaking existing users, I've marked all currently public constructors as @Deprecated and added clear Javadoc indicating that they will be removed in a future release.
In addition, I introduced an alternative package-private constructor, which will be used in the new builder-based approach.

String sseEndpoint, ObjectMapper objectMapper) {
Assert.notNull(objectMapper, "ObjectMapper must not be null");
Assert.hasText(baseUri, "baseUri must not be empty");
Assert.hasText(sseEndpoint, "sseEndpoint must not be empty");
Assert.notNull(clientBuilder, "clientBuilder must not be null");
Assert.notNull(httpClient, "clientBuilder must not be null");
Assert.notNull(requestBuilder, "requestBuilder must not be null");
this.baseUri = baseUri;
this.sseEndpoint = sseEndpoint;
this.objectMapper = objectMapper;
this.httpClient = clientBuilder.connectTimeout(Duration.ofSeconds(10)).build();
this.httpClient = httpClient;
this.requestBuilder = requestBuilder;

this.sseClient = new FlowSseClient(this.httpClient, requestBuilder);
Expand All @@ -176,17 +145,20 @@ public static class Builder {

private String sseEndpoint = DEFAULT_SSE_ENDPOINT;

private HttpClient.Builder clientBuilder = HttpClient.newBuilder();
private HttpClient.Builder clientBuilder = HttpClient.newBuilder()
.version(HttpClient.Version.HTTP_1_1)
.connectTimeout(Duration.ofSeconds(10));

private ObjectMapper objectMapper = new ObjectMapper();

private HttpRequest.Builder requestBuilder = HttpRequest.newBuilder();
private HttpRequest.Builder requestBuilder = HttpRequest.newBuilder()
.header("Content-Type", "application/json");

/**
* Creates a new builder with the specified base URI.
* @param baseUri the base URI of the MCP server
*/
public Builder(String baseUri) {
Builder(String baseUri) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to deprecate it first in order to avoid breaking the users. It is currently public.

Copy link
Contributor Author

@Aliaksie Aliaksie Apr 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're absolutely right — to avoid breaking existing users, I've marked all currently public constructors as @Deprecated and added clear Javadoc indicating that they will be removed in a future release.
In addition, I introduced an alternative package-private constructor, which will be used in the new builder-based approach.

Assert.hasText(baseUri, "baseUri must not be empty");
this.baseUri = baseUri;
}
Expand All @@ -213,6 +185,17 @@ public Builder clientBuilder(HttpClient.Builder clientBuilder) {
return this;
}

/**
* Customizes the HTTP client builder.
* @param clientCustomizer the consumer to customize the HTTP client builder
* @return this builder
*/
public Builder customizeClient(final Consumer<HttpClient.Builder> clientCustomizer) {
Assert.notNull(clientCustomizer, "clientCustomizer must not be null");
clientCustomizer.accept(clientBuilder);
return this;
}

/**
* Sets the HTTP request builder.
* @param requestBuilder the HTTP request builder
Expand All @@ -224,6 +207,17 @@ public Builder requestBuilder(HttpRequest.Builder requestBuilder) {
return this;
}

/**
* Customizes the HTTP client builder.
* @param requestCustomizer the consumer to customize the HTTP request builder
* @return this builder
*/
public Builder customizeRequest(final Consumer<HttpRequest.Builder> requestCustomizer) {
Assert.notNull(requestCustomizer, "requestCustomizer must not be null");
requestCustomizer.accept(requestBuilder);
return this;
}

/**
* Sets the object mapper for JSON serialization/deserialization.
* @param objectMapper the object mapper
Expand All @@ -240,7 +234,8 @@ public Builder objectMapper(ObjectMapper objectMapper) {
* @return a new transport instance
*/
public HttpClientSseClientTransport build() {
return new HttpClientSseClientTransport(clientBuilder, requestBuilder, baseUri, sseEndpoint, objectMapper);
return new HttpClientSseClientTransport(clientBuilder.build(), requestBuilder, baseUri, sseEndpoint,
objectMapper);
}

}
Expand Down Expand Up @@ -336,7 +331,6 @@ public Mono<Void> sendMessage(JSONRPCMessage message) {
try {
String jsonText = this.objectMapper.writeValueAsString(message);
HttpRequest request = this.requestBuilder.uri(URI.create(this.baseUri + endpoint))
.header("Content-Type", "application/json")
.POST(HttpRequest.BodyPublishers.ofString(jsonText))
.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class HttpSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests {

@Override
protected McpClientTransport createMcpTransport() {
return new HttpClientSseClientTransport(host);
return HttpClientSseClientTransport.builder(host).build();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class HttpSseMcpSyncClientTests extends AbstractMcpSyncClientTests {

@Override
protected McpClientTransport createMcpTransport() {
return new HttpClientSseClientTransport(host);
return HttpClientSseClientTransport.builder(host).build();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

package io.modelcontextprotocol.client.transport;

import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.time.Duration;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
Expand All @@ -26,6 +28,8 @@
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;

import com.fasterxml.jackson.databind.ObjectMapper;

/**
* Tests for the {@link HttpClientSseClientTransport} class.
*
Expand All @@ -51,8 +55,8 @@ static class TestHttpClientSseClientTransport extends HttpClientSseClientTranspo

private Sinks.Many<ServerSentEvent<String>> events = Sinks.many().unicast().onBackpressureBuffer();

public TestHttpClientSseClientTransport(String baseUri) {
super(baseUri);
public TestHttpClientSseClientTransport(final String baseUri) {
super(HttpClient.newHttpClient(), HttpRequest.newBuilder(), baseUri, "/sse", new ObjectMapper());
}

public int getInboundMessageCount() {
Expand Down Expand Up @@ -191,13 +195,14 @@ void testGracefulShutdown() {
StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete();

// Message count should remain 0 after shutdown
assertThat(transport.getInboundMessageCount()).isEqualTo(0);
assertThat(transport.getInboundMessageCount()).isZero();
}

@Test
void testRetryBehavior() {
// Create a client that simulates connection failures
HttpClientSseClientTransport failingTransport = new HttpClientSseClientTransport("http://non-existent-host");
HttpClientSseClientTransport failingTransport = HttpClientSseClientTransport.builder("http://non-existent-host")
.build();

// Verify that the transport attempts to reconnect
StepVerifier.create(Mono.delay(Duration.ofSeconds(2))).expectNextCount(1).verifyComplete();
Expand Down