Skip to content

Pass request builder to constructor to support custom headers #86

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
public class FlowSseClient {

private final HttpClient httpClient;
private final HttpRequest.Builder requestBuilder;

/**
* Pattern to extract the data content from SSE data field lines. Matches lines
Expand Down Expand Up @@ -92,7 +93,17 @@ public interface SseEventHandler {
* @param httpClient the {@link HttpClient} instance to use for SSE connections
*/
public FlowSseClient(HttpClient httpClient) {
this(httpClient, HttpRequest.newBuilder());
}

/**
* Creates a new FlowSseClient with the specified HTTP client and request builder.
* @param httpClient the {@link HttpClient} instance to use for SSE connections
* @param requestBuilder the {@link HttpRequest.Builder} to use for SSE requests
*/
public FlowSseClient(HttpClient httpClient, HttpRequest.Builder requestBuilder) {
this.httpClient = httpClient;
this.requestBuilder = requestBuilder;
}

/**
Expand All @@ -109,7 +120,7 @@ public FlowSseClient(HttpClient httpClient) {
* @throws RuntimeException if the connection fails with a non-200 status code
*/
public void subscribe(String url, SseEventHandler eventHandler) {
HttpRequest request = HttpRequest.newBuilder()
HttpRequest request = requestBuilder
.uri(URI.create(url))
.header("Accept", "text/event-stream")
.header("Cache-Control", "no-cache")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ public class HttpClientSseClientTransport implements McpClientTransport {
*/
private final HttpClient httpClient;

/** HTTP request builder for building requests to send messages to the server */
private final HttpRequest.Builder requestBuilder;

/** JSON object mapper for message serialization/deserialization */
protected ObjectMapper objectMapper;

Expand Down Expand Up @@ -126,15 +129,32 @@ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, String bas
*/
public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, String baseUri, String sseEndpoint,
ObjectMapper objectMapper) {
this(clientBuilder, HttpRequest.newBuilder(), baseUri, sseEndpoint, objectMapper);
}

/**
* Creates a new transport instance with custom HTTP client builder, object mapper, and headers.
* @param clientBuilder the HTTP client builder to use
* @param requestBuilder the HTTP request builder to use
* @param baseUri the base URI of the MCP server
* @param sseEndpoint the SSE endpoint path
* @param objectMapper the object mapper for JSON serialization/deserialization
* @throws IllegalArgumentException if objectMapper, clientBuilder, or headers is null
*/
public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, HttpRequest.Builder requestBuilder,
String baseUri, String sseEndpoint, ObjectMapper objectMapper) {
Assert.notNull(objectMapper, "ObjectMapper must not be null");
Assert.hasText(baseUri, "baseUri must not be empty");
Assert.hasText(sseEndpoint, "sseEndpoint must not be empty");
Assert.notNull(clientBuilder, "clientBuilder must not be null");
Assert.notNull(requestBuilder, "requestBuilder must not be null");
this.baseUri = baseUri;
this.sseEndpoint = sseEndpoint;
this.objectMapper = objectMapper;
this.httpClient = clientBuilder.connectTimeout(Duration.ofSeconds(10)).build();
this.sseClient = new FlowSseClient(this.httpClient);
this.requestBuilder = requestBuilder;

this.sseClient = new FlowSseClient(this.httpClient, requestBuilder);
}

/**
Expand All @@ -159,6 +179,8 @@ public static class Builder {

private ObjectMapper objectMapper = new ObjectMapper();

private HttpRequest.Builder requestBuilder = HttpRequest.newBuilder();

/**
* Creates a new builder with the specified base URI.
* @param baseUri the base URI of the MCP server
Expand Down Expand Up @@ -190,6 +212,17 @@ public Builder clientBuilder(HttpClient.Builder clientBuilder) {
return this;
}

/**
* Sets the HTTP request builder.
* @param requestBuilder the HTTP request builder
* @return this builder
*/
public Builder requestBuilder(HttpRequest.Builder requestBuilder) {
Assert.notNull(requestBuilder, "requestBuilder must not be null");
this.requestBuilder = requestBuilder;
return this;
}

/**
* Sets the object mapper for JSON serialization/deserialization.
* @param objectMapper the object mapper
Expand All @@ -206,7 +239,7 @@ public Builder objectMapper(ObjectMapper objectMapper) {
* @return a new transport instance
*/
public HttpClientSseClientTransport build() {
return new HttpClientSseClientTransport(clientBuilder, baseUri, sseEndpoint, objectMapper);
return new HttpClientSseClientTransport(clientBuilder, requestBuilder, baseUri, sseEndpoint, objectMapper);
}

}
Expand Down Expand Up @@ -301,7 +334,7 @@ public Mono<Void> sendMessage(JSONRPCMessage message) {

try {
String jsonText = this.objectMapper.writeValueAsString(message);
HttpRequest request = HttpRequest.newBuilder()
HttpRequest request = this.requestBuilder
.uri(URI.create(this.baseUri + endpoint))
.header("Content-Type", "application/json")
.POST(HttpRequest.BodyPublishers.ofString(jsonText))
Expand Down