Skip to content

feat(mcp): Custom context paths in HTTP Servlet SSE server transport #112

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,14 @@ public class HttpServletSseServerTransportProvider extends HttpServlet implement
/** Event type for endpoint information */
public static final String ENDPOINT_EVENT_TYPE = "endpoint";

public static final String DEFAULT_BASE_URL = "";

/** JSON object mapper for serialization/deserialization */
private final ObjectMapper objectMapper;

/** Base URL for the server transport */
private final String baseUrl;

/** The endpoint path for handling client messages */
private final String messageEndpoint;

Expand All @@ -108,7 +113,22 @@ public class HttpServletSseServerTransportProvider extends HttpServlet implement
*/
public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint,
String sseEndpoint) {
this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint);
}

/**
* Creates a new HttpServletSseServerTransportProvider instance with a custom SSE
* endpoint.
* @param objectMapper The JSON object mapper to use for message
* serialization/deserialization
* @param baseUrl The base URL for the server transport
* @param messageEndpoint The endpoint path where clients will send their messages
* @param sseEndpoint The endpoint path where clients will establish SSE connections
*/
public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
String sseEndpoint) {
this.objectMapper = objectMapper;
this.baseUrl = baseUrl;
this.messageEndpoint = messageEndpoint;
this.sseEndpoint = sseEndpoint;
}
Expand Down Expand Up @@ -203,7 +223,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
this.sessions.put(sessionId, session);

// Send initial endpoint event
this.sendEvent(writer, ENDPOINT_EVENT_TYPE, messageEndpoint + "?sessionId=" + sessionId);
this.sendEvent(writer, ENDPOINT_EVENT_TYPE, this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId);
}

/**
Expand Down Expand Up @@ -449,6 +469,8 @@ public static class Builder {

private ObjectMapper objectMapper = new ObjectMapper();

private String baseUrl = DEFAULT_BASE_URL;

private String messageEndpoint;

private String sseEndpoint = DEFAULT_SSE_ENDPOINT;
Expand All @@ -464,6 +486,17 @@ public Builder objectMapper(ObjectMapper objectMapper) {
return this;
}

/**
* Sets the base URL for the server transport.
* @param baseUrl The base URL to use
* @return This builder instance for method chaining
*/
public Builder baseUrl(String baseUrl) {
Assert.notNull(baseUrl, "Base URL must not be null");
this.baseUrl = baseUrl;
return this;
}

/**
* Sets the endpoint path where clients will send their messages.
* @param messageEndpoint The message endpoint path
Expand Down Expand Up @@ -502,7 +535,7 @@ public HttpServletSseServerTransportProvider build() {
if (messageEndpoint == null) {
throw new IllegalStateException("MessageEndpoint must be set");
}
return new HttpServletSseServerTransportProvider(objectMapper, messageEndpoint, sseEndpoint);
return new HttpServletSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint);
}

}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Copyright 2024 - 2024 the original author or authors.
*/
package io.modelcontextprotocol.server.transport;

import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.client.McpClient;
import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
import io.modelcontextprotocol.server.McpServer;
import io.modelcontextprotocol.spec.McpSchema;
import org.apache.catalina.Context;
import org.apache.catalina.LifecycleException;
import org.apache.catalina.LifecycleState;
import org.apache.catalina.startup.Tomcat;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import static org.assertj.core.api.Assertions.assertThat;

public class HttpServletSseServerCustomContextPathTests {

private static final int PORT = 8195;

private static final String CUSTOM_CONTEXT_PATH = "/api/v1";

private static final String CUSTOM_SSE_ENDPOINT = "/somePath/sse";

private static final String CUSTOM_MESSAGE_ENDPOINT = "/otherPath/mcp/message";

private HttpServletSseServerTransportProvider mcpServerTransportProvider;

McpClient.SyncSpec clientBuilder;

private Tomcat tomcat;

@BeforeEach
public void before() {

// Create and configure the transport provider
mcpServerTransportProvider = HttpServletSseServerTransportProvider.builder()
.objectMapper(new ObjectMapper())
.baseUrl(CUSTOM_CONTEXT_PATH)
.messageEndpoint(CUSTOM_MESSAGE_ENDPOINT)
.sseEndpoint(CUSTOM_SSE_ENDPOINT)
.build();

tomcat = TomcatTestUtil.createTomcatServer(CUSTOM_CONTEXT_PATH, PORT, mcpServerTransportProvider);

try {
tomcat.start();
assertThat(tomcat.getServer().getState() == LifecycleState.STARTED);
}
catch (Exception e) {
throw new RuntimeException("Failed to start Tomcat", e);
}

this.clientBuilder = McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT)
.sseEndpoint(CUSTOM_CONTEXT_PATH + CUSTOM_SSE_ENDPOINT)
.build());
}

@AfterEach
public void after() {
if (mcpServerTransportProvider != null) {
mcpServerTransportProvider.closeGracefully().block();
}
if (tomcat != null) {
try {
tomcat.stop();
tomcat.destroy();
}
catch (LifecycleException e) {
throw new RuntimeException("Failed to stop Tomcat", e);
}
}
}

@Test
void testCustomContextPath() {
McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").build();
var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")).build();
assertThat(client.initialize()).isNotNull();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import io.modelcontextprotocol.spec.McpSchema.Root;
import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities;
import io.modelcontextprotocol.spec.McpSchema.Tool;
import org.apache.catalina.Context;
import org.apache.catalina.LifecycleException;
import org.apache.catalina.LifecycleState;
import org.apache.catalina.startup.Tomcat;
Expand Down Expand Up @@ -59,33 +58,15 @@ public class HttpServletSseServerTransportProviderIntegrationTests {

@BeforeEach
public void before() {
tomcat = new Tomcat();
tomcat.setPort(PORT);

String baseDir = System.getProperty("java.io.tmpdir");
tomcat.setBaseDir(baseDir);

Context context = tomcat.addContext("", baseDir);

// Create and configure the transport provider
mcpServerTransportProvider = HttpServletSseServerTransportProvider.builder()
.objectMapper(new ObjectMapper())
.messageEndpoint(CUSTOM_MESSAGE_ENDPOINT)
.sseEndpoint(CUSTOM_SSE_ENDPOINT)
.build();

// Add transport servlet to Tomcat
org.apache.catalina.Wrapper wrapper = context.createWrapper();
wrapper.setName("mcpServlet");
wrapper.setServlet(mcpServerTransportProvider);
wrapper.setLoadOnStartup(1);
wrapper.setAsyncSupported(true);
context.addChild(wrapper);
context.addServletMappingDecoded("/*", "mcpServlet");

tomcat = TomcatTestUtil.createTomcatServer("", PORT, mcpServerTransportProvider);
try {
var connector = tomcat.getConnector();
connector.setAsyncTimeout(3000);
tomcat.start();
assertThat(tomcat.getServer().getState() == LifecycleState.STARTED);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Copyright 2025 - 2025 the original author or authors.
*/
package io.modelcontextprotocol.server.transport;

import com.fasterxml.jackson.databind.ObjectMapper;
import jakarta.servlet.Servlet;
import org.apache.catalina.Context;
import org.apache.catalina.LifecycleState;
import org.apache.catalina.startup.Tomcat;

import static org.junit.Assert.assertThat;

/**
* @author Christian Tzolov
*/
public class TomcatTestUtil {

public static Tomcat createTomcatServer(String contextPath, int port, Servlet servlet) {

var tomcat = new Tomcat();
tomcat.setPort(port);

String baseDir = System.getProperty("java.io.tmpdir");
tomcat.setBaseDir(baseDir);

// Context context = tomcat.addContext("", baseDir);
Context context = tomcat.addContext(contextPath, baseDir);

// Add transport servlet to Tomcat
org.apache.catalina.Wrapper wrapper = context.createWrapper();
wrapper.setName("mcpServlet");
wrapper.setServlet(servlet);
wrapper.setLoadOnStartup(1);
wrapper.setAsyncSupported(true);
context.addChild(wrapper);
context.addServletMappingDecoded("/*", "mcpServlet");

var connector = tomcat.getConnector();
connector.setAsyncTimeout(3000);

return tomcat;
}

}