diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java index a64b4a35..e52fc88b 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java @@ -80,9 +80,14 @@ public class HttpServletSseServerTransportProvider extends HttpServlet implement /** Event type for endpoint information */ public static final String ENDPOINT_EVENT_TYPE = "endpoint"; + public static final String DEFAULT_BASE_URL = ""; + /** JSON object mapper for serialization/deserialization */ private final ObjectMapper objectMapper; + /** Base URL for the server transport */ + private final String baseUrl; + /** The endpoint path for handling client messages */ private final String messageEndpoint; @@ -108,7 +113,22 @@ public class HttpServletSseServerTransportProvider extends HttpServlet implement */ public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { + this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint); + } + + /** + * Creates a new HttpServletSseServerTransportProvider instance with a custom SSE + * endpoint. + * @param objectMapper The JSON object mapper to use for message + * serialization/deserialization + * @param baseUrl The base URL for the server transport + * @param messageEndpoint The endpoint path where clients will send their messages + * @param sseEndpoint The endpoint path where clients will establish SSE connections + */ + public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, + String sseEndpoint) { this.objectMapper = objectMapper; + this.baseUrl = baseUrl; this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; } @@ -203,7 +223,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) this.sessions.put(sessionId, session); // Send initial endpoint event - this.sendEvent(writer, ENDPOINT_EVENT_TYPE, messageEndpoint + "?sessionId=" + sessionId); + this.sendEvent(writer, ENDPOINT_EVENT_TYPE, this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId); } /** @@ -449,6 +469,8 @@ public static class Builder { private ObjectMapper objectMapper = new ObjectMapper(); + private String baseUrl = DEFAULT_BASE_URL; + private String messageEndpoint; private String sseEndpoint = DEFAULT_SSE_ENDPOINT; @@ -464,6 +486,17 @@ public Builder objectMapper(ObjectMapper objectMapper) { return this; } + /** + * Sets the base URL for the server transport. + * @param baseUrl The base URL to use + * @return This builder instance for method chaining + */ + public Builder baseUrl(String baseUrl) { + Assert.notNull(baseUrl, "Base URL must not be null"); + this.baseUrl = baseUrl; + return this; + } + /** * Sets the endpoint path where clients will send their messages. * @param messageEndpoint The message endpoint path @@ -502,7 +535,7 @@ public HttpServletSseServerTransportProvider build() { if (messageEndpoint == null) { throw new IllegalStateException("MessageEndpoint must be set"); } - return new HttpServletSseServerTransportProvider(objectMapper, messageEndpoint, sseEndpoint); + return new HttpServletSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java new file mode 100644 index 00000000..1254e2ad --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java @@ -0,0 +1,86 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ +package io.modelcontextprotocol.server.transport; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.spec.McpSchema; +import org.apache.catalina.Context; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +public class HttpServletSseServerCustomContextPathTests { + + private static final int PORT = 8195; + + private static final String CUSTOM_CONTEXT_PATH = "/api/v1"; + + private static final String CUSTOM_SSE_ENDPOINT = "/somePath/sse"; + + private static final String CUSTOM_MESSAGE_ENDPOINT = "/otherPath/mcp/message"; + + private HttpServletSseServerTransportProvider mcpServerTransportProvider; + + McpClient.SyncSpec clientBuilder; + + private Tomcat tomcat; + + @BeforeEach + public void before() { + + // Create and configure the transport provider + mcpServerTransportProvider = HttpServletSseServerTransportProvider.builder() + .objectMapper(new ObjectMapper()) + .baseUrl(CUSTOM_CONTEXT_PATH) + .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) + .sseEndpoint(CUSTOM_SSE_ENDPOINT) + .build(); + + tomcat = TomcatTestUtil.createTomcatServer(CUSTOM_CONTEXT_PATH, PORT, mcpServerTransportProvider); + + try { + tomcat.start(); + assertThat(tomcat.getServer().getState() == LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + this.clientBuilder = McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT) + .sseEndpoint(CUSTOM_CONTEXT_PATH + CUSTOM_SSE_ENDPOINT) + .build()); + } + + @AfterEach + public void after() { + if (mcpServerTransportProvider != null) { + mcpServerTransportProvider.closeGracefully().block(); + } + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + @Test + void testCustomContextPath() { + McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").build(); + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")).build(); + assertThat(client.initialize()).isNotNull(); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java index 1cd395e7..b04940c7 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java @@ -26,7 +26,6 @@ import io.modelcontextprotocol.spec.McpSchema.Root; import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; import io.modelcontextprotocol.spec.McpSchema.Tool; -import org.apache.catalina.Context; import org.apache.catalina.LifecycleException; import org.apache.catalina.LifecycleState; import org.apache.catalina.startup.Tomcat; @@ -59,14 +58,6 @@ public class HttpServletSseServerTransportProviderIntegrationTests { @BeforeEach public void before() { - tomcat = new Tomcat(); - tomcat.setPort(PORT); - - String baseDir = System.getProperty("java.io.tmpdir"); - tomcat.setBaseDir(baseDir); - - Context context = tomcat.addContext("", baseDir); - // Create and configure the transport provider mcpServerTransportProvider = HttpServletSseServerTransportProvider.builder() .objectMapper(new ObjectMapper()) @@ -74,18 +65,8 @@ public void before() { .sseEndpoint(CUSTOM_SSE_ENDPOINT) .build(); - // Add transport servlet to Tomcat - org.apache.catalina.Wrapper wrapper = context.createWrapper(); - wrapper.setName("mcpServlet"); - wrapper.setServlet(mcpServerTransportProvider); - wrapper.setLoadOnStartup(1); - wrapper.setAsyncSupported(true); - context.addChild(wrapper); - context.addServletMappingDecoded("/*", "mcpServlet"); - + tomcat = TomcatTestUtil.createTomcatServer("", PORT, mcpServerTransportProvider); try { - var connector = tomcat.getConnector(); - connector.setAsyncTimeout(3000); tomcat.start(); assertThat(tomcat.getServer().getState() == LifecycleState.STARTED); } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java new file mode 100644 index 00000000..6f922dfa --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java @@ -0,0 +1,45 @@ +/* +* Copyright 2025 - 2025 the original author or authors. +*/ +package io.modelcontextprotocol.server.transport; + +import com.fasterxml.jackson.databind.ObjectMapper; +import jakarta.servlet.Servlet; +import org.apache.catalina.Context; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; + +import static org.junit.Assert.assertThat; + +/** + * @author Christian Tzolov + */ +public class TomcatTestUtil { + + public static Tomcat createTomcatServer(String contextPath, int port, Servlet servlet) { + + var tomcat = new Tomcat(); + tomcat.setPort(port); + + String baseDir = System.getProperty("java.io.tmpdir"); + tomcat.setBaseDir(baseDir); + + // Context context = tomcat.addContext("", baseDir); + Context context = tomcat.addContext(contextPath, baseDir); + + // Add transport servlet to Tomcat + org.apache.catalina.Wrapper wrapper = context.createWrapper(); + wrapper.setName("mcpServlet"); + wrapper.setServlet(servlet); + wrapper.setLoadOnStartup(1); + wrapper.setAsyncSupported(true); + context.addChild(wrapper); + context.addServletMappingDecoded("/*", "mcpServlet"); + + var connector = tomcat.getConnector(); + connector.setAsyncTimeout(3000); + + return tomcat; + } + +}