diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java index f9190fd7..47bf7a2c 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java @@ -44,7 +44,7 @@ import static org.awaitility.Awaitility.await; import static org.mockito.Mockito.mock; -public class WebMvcSseIntegrationTests { +class WebMvcSseIntegrationTests { private static final int PORT = 8183; @@ -79,13 +79,13 @@ public void before() { try { tomcatServer.tomcat().start(); - assertThat(tomcatServer.tomcat().getServer().getState() == LifecycleState.STARTED); + assertThat(tomcatServer.tomcat().getServer().getState()).isEqualTo(LifecycleState.STARTED); } catch (Exception e) { throw new RuntimeException("Failed to start Tomcat", e); } - clientBuilder = McpClient.sync(new HttpClientSseClientTransport("http://localhost:" + PORT)); + clientBuilder = McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT).build()); // Get the transport from Spring context mcpServerTransportProvider = tomcatServer.appContext().getBean(WebMvcSseServerTransportProvider.class); @@ -200,8 +200,7 @@ void testCreateMessageSuccess() throws InterruptedException { CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); + assertThat(response).isNotNull().isEqualTo(callResponse); mcpClient.close(); mcpServer.close(); @@ -410,8 +409,7 @@ void testToolCallSuccess() { CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); + assertThat(response).isNotNull().isEqualTo(callResponse); mcpClient.close(); mcpServer.close(); diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java index a5bdd43e..632d3844 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -13,6 +13,7 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; import java.util.function.Function; import com.fasterxml.jackson.core.type.TypeReference; @@ -103,7 +104,10 @@ public class HttpClientSseClientTransport implements McpClientTransport { /** * Creates a new transport instance with default HTTP client and object mapper. * @param baseUri the base URI of the MCP server + * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. This + * constructor will be removed in future versions. */ + @Deprecated(forRemoval = true) public HttpClientSseClientTransport(String baseUri) { this(HttpClient.newBuilder(), baseUri, new ObjectMapper()); } @@ -114,7 +118,10 @@ public HttpClientSseClientTransport(String baseUri) { * @param baseUri the base URI of the MCP server * @param objectMapper the object mapper for JSON serialization/deserialization * @throws IllegalArgumentException if objectMapper or clientBuilder is null + * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. This + * constructor will be removed in future versions. */ + @Deprecated(forRemoval = true) public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, String baseUri, ObjectMapper objectMapper) { this(clientBuilder, baseUri, DEFAULT_SSE_ENDPOINT, objectMapper); } @@ -126,7 +133,10 @@ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, String bas * @param sseEndpoint the SSE endpoint path * @param objectMapper the object mapper for JSON serialization/deserialization * @throws IllegalArgumentException if objectMapper or clientBuilder is null + * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. This + * constructor will be removed in future versions. */ + @Deprecated(forRemoval = true) public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, String baseUri, String sseEndpoint, ObjectMapper objectMapper) { this(clientBuilder, HttpRequest.newBuilder(), baseUri, sseEndpoint, objectMapper); @@ -141,18 +151,37 @@ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, String bas * @param sseEndpoint the SSE endpoint path * @param objectMapper the object mapper for JSON serialization/deserialization * @throws IllegalArgumentException if objectMapper, clientBuilder, or headers is null + * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. This + * constructor will be removed in future versions. */ + @Deprecated(forRemoval = true) public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, HttpRequest.Builder requestBuilder, String baseUri, String sseEndpoint, ObjectMapper objectMapper) { + this(clientBuilder.connectTimeout(Duration.ofSeconds(10)).build(), requestBuilder, baseUri, sseEndpoint, + objectMapper); + } + + /** + * Creates a new transport instance with custom HTTP client builder, object mapper, + * and headers. + * @param httpClient the HTTP client to use + * @param requestBuilder the HTTP request builder to use + * @param baseUri the base URI of the MCP server + * @param sseEndpoint the SSE endpoint path + * @param objectMapper the object mapper for JSON serialization/deserialization + * @throws IllegalArgumentException if objectMapper, clientBuilder, or headers is null + */ + HttpClientSseClientTransport(HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri, + String sseEndpoint, ObjectMapper objectMapper) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); Assert.hasText(baseUri, "baseUri must not be empty"); Assert.hasText(sseEndpoint, "sseEndpoint must not be empty"); - Assert.notNull(clientBuilder, "clientBuilder must not be null"); + Assert.notNull(httpClient, "httpClient must not be null"); Assert.notNull(requestBuilder, "requestBuilder must not be null"); this.baseUri = baseUri; this.sseEndpoint = sseEndpoint; this.objectMapper = objectMapper; - this.httpClient = clientBuilder.connectTimeout(Duration.ofSeconds(10)).build(); + this.httpClient = httpClient; this.requestBuilder = requestBuilder; this.sseClient = new FlowSseClient(this.httpClient, requestBuilder); @@ -164,7 +193,7 @@ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, HttpReques * @return a new builder instance */ public static Builder builder(String baseUri) { - return new Builder(baseUri); + return new Builder().baseUri(baseUri); } /** @@ -172,25 +201,50 @@ public static Builder builder(String baseUri) { */ public static class Builder { - private final String baseUri; + private String baseUri; private String sseEndpoint = DEFAULT_SSE_ENDPOINT; - private HttpClient.Builder clientBuilder = HttpClient.newBuilder(); + private HttpClient.Builder clientBuilder = HttpClient.newBuilder() + .version(HttpClient.Version.HTTP_1_1) + .connectTimeout(Duration.ofSeconds(10)); private ObjectMapper objectMapper = new ObjectMapper(); - private HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(); + private HttpRequest.Builder requestBuilder = HttpRequest.newBuilder() + .header("Content-Type", "application/json"); + + /** + * Creates a new builder instance. + */ + Builder() { + // Default constructor + } /** * Creates a new builder with the specified base URI. * @param baseUri the base URI of the MCP server + * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. + * This constructor is deprecated and will be removed or made {@code protected} or + * {@code private} in a future release. */ + @Deprecated(forRemoval = true) public Builder(String baseUri) { Assert.hasText(baseUri, "baseUri must not be empty"); this.baseUri = baseUri; } + /** + * Sets the base URI. + * @param baseUri the base URI + * @return this builder + */ + Builder baseUri(String baseUri) { + Assert.hasText(baseUri, "baseUri must not be empty"); + this.baseUri = baseUri; + return this; + } + /** * Sets the SSE endpoint path. * @param sseEndpoint the SSE endpoint path @@ -213,6 +267,17 @@ public Builder clientBuilder(HttpClient.Builder clientBuilder) { return this; } + /** + * Customizes the HTTP client builder. + * @param clientCustomizer the consumer to customize the HTTP client builder + * @return this builder + */ + public Builder customizeClient(final Consumer clientCustomizer) { + Assert.notNull(clientCustomizer, "clientCustomizer must not be null"); + clientCustomizer.accept(clientBuilder); + return this; + } + /** * Sets the HTTP request builder. * @param requestBuilder the HTTP request builder @@ -224,6 +289,17 @@ public Builder requestBuilder(HttpRequest.Builder requestBuilder) { return this; } + /** + * Customizes the HTTP client builder. + * @param requestCustomizer the consumer to customize the HTTP request builder + * @return this builder + */ + public Builder customizeRequest(final Consumer requestCustomizer) { + Assert.notNull(requestCustomizer, "requestCustomizer must not be null"); + requestCustomizer.accept(requestBuilder); + return this; + } + /** * Sets the object mapper for JSON serialization/deserialization. * @param objectMapper the object mapper @@ -240,7 +316,8 @@ public Builder objectMapper(ObjectMapper objectMapper) { * @return a new transport instance */ public HttpClientSseClientTransport build() { - return new HttpClientSseClientTransport(clientBuilder, requestBuilder, baseUri, sseEndpoint, objectMapper); + return new HttpClientSseClientTransport(clientBuilder.build(), requestBuilder, baseUri, sseEndpoint, + objectMapper); } } @@ -336,7 +413,6 @@ public Mono sendMessage(JSONRPCMessage message) { try { String jsonText = this.objectMapper.writeValueAsString(message); HttpRequest request = this.requestBuilder.uri(URI.create(this.baseUri + endpoint)) - .header("Content-Type", "application/json") .POST(HttpRequest.BodyPublishers.ofString(jsonText)) .build(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java index 15749d4f..3b7275cc 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java @@ -29,7 +29,7 @@ class HttpSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests { @Override protected McpClientTransport createMcpTransport() { - return new HttpClientSseClientTransport(host); + return HttpClientSseClientTransport.builder(host).build(); } @Override diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java index 067f9295..204cf298 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java @@ -29,7 +29,7 @@ class HttpSseMcpSyncClientTests extends AbstractMcpSyncClientTests { @Override protected McpClientTransport createMcpTransport() { - return new HttpClientSseClientTransport(host); + return HttpClientSseClientTransport.builder(host).build(); } @Override diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java index 294056fb..b9648bc4 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java @@ -4,6 +4,8 @@ package io.modelcontextprotocol.client.transport; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; import java.time.Duration; import java.util.Map; import java.util.concurrent.atomic.AtomicInteger; @@ -26,6 +28,8 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; +import com.fasterxml.jackson.databind.ObjectMapper; + /** * Tests for the {@link HttpClientSseClientTransport} class. * @@ -51,8 +55,8 @@ static class TestHttpClientSseClientTransport extends HttpClientSseClientTranspo private Sinks.Many> events = Sinks.many().unicast().onBackpressureBuffer(); - public TestHttpClientSseClientTransport(String baseUri) { - super(baseUri); + public TestHttpClientSseClientTransport(final String baseUri) { + super(HttpClient.newHttpClient(), HttpRequest.newBuilder(), baseUri, "/sse", new ObjectMapper()); } public int getInboundMessageCount() { @@ -191,13 +195,14 @@ void testGracefulShutdown() { StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); // Message count should remain 0 after shutdown - assertThat(transport.getInboundMessageCount()).isEqualTo(0); + assertThat(transport.getInboundMessageCount()).isZero(); } @Test void testRetryBehavior() { // Create a client that simulates connection failures - HttpClientSseClientTransport failingTransport = new HttpClientSseClientTransport("http://non-existent-host"); + HttpClientSseClientTransport failingTransport = HttpClientSseClientTransport.builder("http://non-existent-host") + .build(); // Verify that the transport attempts to reconnect StepVerifier.create(Mono.delay(Duration.ofSeconds(2))).expectNextCount(1).verifyComplete();