From 8ceb495252553d907edd314c2066544645602be8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Wed, 5 Mar 2025 12:03:58 +0100 Subject: [PATCH 01/20] Refactor server side to handle multiple clients MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit WARN: This is still work in progress and does not compile. Breaking changes: * McpAsyncServer * getClientCapabilities deprecated + throws * getClientInfo deprecated + throws * listRoots deprecated + throws * createMessage deprecated + throws * McpTransport * connect deprecated - should only belong to McpClientTransport * ServerMcpTransport * connect default implementation that throws The major change is the introduction of ServerMcpSession for per-client communication. The user should be exposed to a limited abstraction that hides the session called ServerMcpExchange which currently exposes sampling and roots. Signed-off-by: Dariusz Jędrzejczyk --- .../transport/WebFluxSseServerTransport.java | 211 +++++++--------- .../server/McpAsyncServer.java | 128 +++++----- .../spec/ClientMcpTransport.java | 6 + .../spec/DefaultMcpSession.java | 1 + .../spec/McpTransport.java | 5 +- .../spec/ServerMcpExchange.java | 75 ++++++ .../spec/ServerMcpSession.java | 225 ++++++++++++++++++ .../spec/ServerMcpTransport.java | 14 ++ 8 files changed, 472 insertions(+), 193 deletions(-) create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpExchange.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpSession.java diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransport.java index bed7293e..09db3ba7 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransport.java @@ -1,21 +1,22 @@ package io.modelcontextprotocol.server.transport; import java.io.IOException; -import java.time.Duration; -import java.util.List; +import java.util.Map; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; -import java.util.function.Function; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.ServerMcpSession; import io.modelcontextprotocol.spec.ServerMcpTransport; import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.Exceptions; import reactor.core.publisher.Flux; +import reactor.core.publisher.FluxSink; import reactor.core.publisher.Mono; import reactor.core.publisher.Sinks; @@ -88,18 +89,22 @@ public class WebFluxSseServerTransport implements ServerMcpTransport { private final RouterFunction routerFunction; + private ServerMcpSession.InitHandler initHandler; + + private Map> requestHandlers; + + private Map notificationHandlers; + /** * Map of active client sessions, keyed by session ID. */ - private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); /** * Flag indicating if the transport is shutting down. */ private volatile boolean isClosing = false; - private Function, Mono> connectHandler; - /** * Constructs a new WebFlux SSE server transport instance. * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization @@ -137,21 +142,13 @@ public WebFluxSseServerTransport(ObjectMapper objectMapper, String messageEndpoi this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); } - /** - * Configures the message handler for this transport. In the WebFlux SSE - * implementation, this method stores the handler for processing incoming messages but - * doesn't establish any connections since the server accepts connections rather than - * initiating them. - * @param handler A function that processes incoming JSON-RPC messages and returns - * responses. This handler will be called for each message received through the - * message endpoint. - * @return An empty Mono since the server doesn't initiate connections - */ @Override - public Mono connect(Function, Mono> handler) { - this.connectHandler = handler; - // Server-side transport doesn't initiate connections - return Mono.empty().then(); + public void registerHandlers(ServerMcpSession.InitHandler initHandler, + Map> requestHandlers, + Map notificationHandlers) { + this.initHandler = initHandler; + this.requestHandlers = requestHandlers; + this.notificationHandlers = notificationHandlers; } /** @@ -178,36 +175,14 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message) { return Mono.empty(); } - return Mono.create(sink -> { - try {// @formatter:off - String jsonText = objectMapper.writeValueAsString(message); - ServerSentEvent event = ServerSentEvent.builder() - .event(MESSAGE_EVENT_TYPE) - .data(jsonText) - .build(); - - logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); - - List failedSessions = sessions.values().stream() - .filter(session -> session.messageSink.tryEmitNext(event).isFailure()) - .map(session -> session.id) - .toList(); + logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); - if (failedSessions.isEmpty()) { - logger.debug("Successfully broadcast message to all sessions"); - sink.success(); - } - else { - String error = "Failed to broadcast message to sessions: " + String.join(", ", failedSessions); - logger.error(error); - sink.error(new RuntimeException(error)); - } // @formatter:on - } - catch (IOException e) { - logger.error("Failed to serialize message: {}", e.getMessage()); - sink.error(e); - } - }); + return Flux.fromStream(sessions.values().stream()) + .flatMap(session -> session.sendMessage(message) + .doOnError(e -> logger.error("Failed to " + "send message to session {}: {}", session.sessionId, + e.getMessage())) + .onErrorComplete()) + .then(); } /** @@ -241,18 +216,10 @@ public T unmarshalFrom(Object data, TypeReference typeRef) { */ @Override public Mono closeGracefully() { - return Mono.fromRunnable(() -> { - isClosing = true; - logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size()); - }).then(Mono.when(sessions.values().stream().map(session -> { - String sessionId = session.id; - return Mono.fromRunnable(() -> session.close()) - .then(Mono.delay(Duration.ofMillis(100))) - .then(Mono.fromRunnable(() -> sessions.remove(sessionId))); - }).toList())) - .timeout(Duration.ofSeconds(5)) - .doOnSuccess(v -> logger.debug("Graceful shutdown completed")) - .doOnError(e -> logger.error("Error during graceful shutdown: {}", e.getMessage())); + return Flux.fromIterable(sessions.values()) + .doFirst(() -> logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size())) + .doOnNext(WebFluxMcpSession::close) + .then(); } /** @@ -291,38 +258,22 @@ private Mono handleSseConnection(ServerRequest request) { if (isClosing) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); } - String sessionId = UUID.randomUUID().toString(); - logger.debug("Creating new SSE connection for session: {}", sessionId); - ClientSession session = new ClientSession(sessionId); - this.sessions.put(sessionId, session); return ServerResponse.ok() .contentType(MediaType.TEXT_EVENT_STREAM) .body(Flux.>create(sink -> { + String sessionId = UUID.randomUUID().toString(); + logger.debug("Creating new SSE connection for session: {}", sessionId); + WebFluxMcpSession session = new WebFluxMcpSession(sessionId, sink, initHandler, requestHandlers, + notificationHandlers); + sessions.put(sessionId, session); + // Send initial endpoint event logger.debug("Sending initial endpoint event to session: {}", sessionId); - sink.next(ServerSentEvent.builder().event(ENDPOINT_EVENT_TYPE).data(messageEndpoint).build()); - - // Subscribe to session messages - session.messageSink.asFlux() - .doOnSubscribe(s -> logger.debug("Session {} subscribed to message sink", sessionId)) - .doOnComplete(() -> { - logger.debug("Session {} completed", sessionId); - sessions.remove(sessionId); - }) - .doOnError(error -> { - logger.error("Error in session {}: {}", sessionId, error.getMessage()); - sessions.remove(sessionId); - }) - .doOnCancel(() -> { - logger.debug("Session {} cancelled", sessionId); - sessions.remove(sessionId); - }) - .subscribe(event -> { - logger.debug("Forwarding event to session {}: {}", sessionId, event); - sink.next(event); - }, sink::error, sink::complete); - + sink.next(ServerSentEvent.builder() + .event(ENDPOINT_EVENT_TYPE) + .data(messageEndpoint + "?sessionId=" + sessionId) + .build()); sink.onCancel(() -> { logger.debug("Session {} cancelled", sessionId); sessions.remove(sessionId); @@ -350,17 +301,20 @@ private Mono handleMessage(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); } + if (request.queryParam("sessionId").isEmpty()) { + return ServerResponse.badRequest().bodyValue(new McpError("Session ID missing in message endpoint")); + } + + ServerMcpSession session = sessions.get(request.queryParam("sessionId").get()); + return request.bodyToMono(String.class).flatMap(body -> { try { McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); - return Mono.just(message) - .transform(this.connectHandler) - .flatMap(response -> ServerResponse.ok().build()) - .onErrorResume(error -> { - logger.error("Error processing message: {}", error.getMessage()); - return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR) - .bodyValue(new McpError(error.getMessage())); - }); + return session.handle(message).flatMap(response -> ServerResponse.ok().build()).onErrorResume(error -> { + logger.error("Error processing message: {}", error.getMessage()); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR) + .bodyValue(new McpError(error.getMessage())); + }); } catch (IllegalArgumentException | IOException e) { logger.error("Failed to deserialize message: {}", e.getMessage()); @@ -369,40 +323,49 @@ private Mono handleMessage(ServerRequest request) { }); } - /** - * Represents an active client SSE connection session. Manages the message sink for - * sending events to the client and handles session lifecycle. - * - *

- * Each session: - *

    - *
  • Has a unique identifier
  • - *
  • Maintains its own message sink for event broadcasting
  • - *
  • Supports clean shutdown through the close method
  • - *
- */ - private static class ClientSession { + private class WebFluxMcpSession extends ServerMcpSession { - private final String id; + final String sessionId; - private final Sinks.Many> messageSink; + private final FluxSink> sink; - ClientSession(String id) { - this.id = id; - logger.debug("Creating new session: {}", id); - this.messageSink = Sinks.many().replay().latest(); - logger.debug("Session {} initialized with replay sink", id); + public WebFluxMcpSession(String sessionId, FluxSink> sink, InitHandler initHandler, + Map> requestHandlers, Map notificationHandlers) { + super(WebFluxSseServerTransport.this, initHandler, requestHandlers, notificationHandlers); + this.sessionId = sessionId; + this.sink = sink; } - void close() { - logger.debug("Closing session: {}", id); - Sinks.EmitResult result = messageSink.tryEmitComplete(); - if (result.isFailure()) { - logger.warn("Failed to complete message sink for session {}: {}", id, result); - } - else { - logger.debug("Successfully completed message sink for session {}", id); - } + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + return Mono.fromSupplier(() -> { + try { + return objectMapper.writeValueAsString(message); + } + catch (IOException e) { + throw Exceptions.propagate(e); + } + }).doOnNext(jsonText -> { + ServerSentEvent event = ServerSentEvent.builder() + .event(MESSAGE_EVENT_TYPE) + .data(jsonText) + .build(); + sink.next(event); + }).doOnError(e -> { + // TODO log with sessionid + Throwable exception = Exceptions.unwrap(e); + sink.error(exception); + }).then(); + } + + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(sink::complete); + } + + @Override + public void close() { + sink.complete(); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index 7b691678..63b19156 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -4,7 +4,6 @@ package io.modelcontextprotocol.server; -import java.time.Duration; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -17,8 +16,8 @@ import io.modelcontextprotocol.spec.DefaultMcpSession; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.ServerMcpSession; import io.modelcontextprotocol.spec.ServerMcpTransport; -import io.modelcontextprotocol.spec.DefaultMcpSession.NotificationHandler; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; @@ -75,22 +74,12 @@ public class McpAsyncServer { private static final Logger logger = LoggerFactory.getLogger(McpAsyncServer.class); - /** - * The MCP session implementation that manages bidirectional JSON-RPC communication - * between clients and servers. - */ - private final DefaultMcpSession mcpSession; - private final ServerMcpTransport transport; private final McpSchema.ServerCapabilities serverCapabilities; private final McpSchema.Implementation serverInfo; - private McpSchema.ClientCapabilities clientCapabilities; - - private McpSchema.Implementation clientInfo; - /** * Thread-safe list of tool handlers that can be modified at runtime. */ @@ -115,7 +104,6 @@ public class McpAsyncServer { * @param features The MCP server supported features. */ McpAsyncServer(ServerMcpTransport mcpTransport, McpServerFeatures.Async features) { - this.serverInfo = features.serverInfo(); this.serverCapabilities = features.serverCapabilities(); this.tools.addAll(features.tools()); @@ -123,13 +111,12 @@ public class McpAsyncServer { this.resourceTemplates.addAll(features.resourceTemplates()); this.prompts.putAll(features.prompts()); - Map> requestHandlers = new HashMap<>(); + Map> requestHandlers = new HashMap<>(); // Initialize request handlers for standard MCP methods - requestHandlers.put(McpSchema.METHOD_INITIALIZE, asyncInitializeRequestHandler()); // Ping MUST respond with an empty data, but not NULL response. - requestHandlers.put(McpSchema.METHOD_PING, (params) -> Mono.just("")); + requestHandlers.put(McpSchema.METHOD_PING, (exchange, params) -> Mono.just("")); // Add tools API handlers if the tool capability is enabled if (this.serverCapabilities.tools() != null) { @@ -155,9 +142,9 @@ public class McpAsyncServer { requestHandlers.put(McpSchema.METHOD_LOGGING_SET_LEVEL, setLoggerRequestHandler()); } - Map notificationHandlers = new HashMap<>(); + Map notificationHandlers = new HashMap<>(); - notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (params) -> Mono.empty()); + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (exchange, params) -> Mono.empty()); List, Mono>> rootsChangeConsumers = features.rootsChangeConsumers(); @@ -170,20 +157,21 @@ public class McpAsyncServer { asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); this.transport = mcpTransport; - this.mcpSession = new DefaultMcpSession(Duration.ofSeconds(10), mcpTransport, requestHandlers, - notificationHandlers); + mcpTransport.registerHandlers(this::asyncInitializeRequestHandler, requestHandlers, notificationHandlers); } // --------------------------------------- // Lifecycle Management // --------------------------------------- - private DefaultMcpSession.RequestHandler asyncInitializeRequestHandler() { - return params -> { + private Mono asyncInitializeRequestHandler( + ServerMcpSession.ClientInitConsumer initConsumer, Object params) { + return Mono.defer(() -> { McpSchema.InitializeRequest initializeRequest = transport.unmarshalFrom(params, new TypeReference() { }); - this.clientCapabilities = initializeRequest.capabilities(); - this.clientInfo = initializeRequest.clientInfo(); + + initConsumer.init(initializeRequest.capabilities(), initializeRequest.clientInfo()); + logger.info("Client initialize request - Protocol: {}, Capabilities: {}, Info: {}", initializeRequest.protocolVersion(), initializeRequest.capabilities(), initializeRequest.clientInfo()); @@ -205,7 +193,7 @@ private DefaultMcpSession.RequestHandler asyncInitia return Mono.just(new McpSchema.InitializeResult(serverProtocolVersion, this.serverCapabilities, this.serverInfo, null)); - }; + }); } /** @@ -228,16 +216,18 @@ public McpSchema.Implementation getServerInfo() { * Get the client capabilities that define the supported features and functionality. * @return The client capabilities */ + @Deprecated public ClientCapabilities getClientCapabilities() { - return this.clientCapabilities; + throw new IllegalStateException("This method is deprecated and should not be called"); } /** * Get the client implementation information. * @return The client implementation details */ + @Deprecated public McpSchema.Implementation getClientInfo() { - return this.clientInfo; + throw new IllegalStateException("This method is deprecated and should not be called"); } /** @@ -245,14 +235,14 @@ public McpSchema.Implementation getClientInfo() { * @return A Mono that completes when the server has been closed */ public Mono closeGracefully() { - return this.mcpSession.closeGracefully(); + return this.transport.closeGracefully(); } /** * Close the server immediately. */ public void close() { - this.mcpSession.close(); + this.transport.close(); } private static final TypeReference LIST_ROOTS_RESULT_TYPE_REF = new TypeReference<>() { @@ -271,20 +261,21 @@ public Mono listRoots() { * @param cursor Optional pagination cursor from a previous list request * @return A Mono that emits the list of roots result containing */ + @Deprecated public Mono listRoots(String cursor) { - return this.mcpSession.sendRequest(McpSchema.METHOD_ROOTS_LIST, new McpSchema.PaginatedRequest(cursor), - LIST_ROOTS_RESULT_TYPE_REF); + return Mono.error(new RuntimeException("Not implemented")); } - private NotificationHandler asyncRootsListChangedNotificationHandler( + private ServerMcpSession.NotificationHandler asyncRootsListChangedNotificationHandler( List, Mono>> rootsChangeConsumers) { - return params -> listRoots().flatMap(listRootsResult -> Flux.fromIterable(rootsChangeConsumers) - .flatMap(consumer -> consumer.apply(listRootsResult.roots())) - .onErrorResume(error -> { - logger.error("Error handling roots list change notification", error); - return Mono.empty(); - }) - .then()); + return (exchange, + params) -> listRoots().flatMap(listRootsResult -> Flux.fromIterable(rootsChangeConsumers) + .flatMap(consumer -> consumer.apply(listRootsResult.roots())) + .onErrorResume(error -> { + logger.error("Error handling roots list change notification", error); + return Mono.empty(); + }) + .then()); } // --------------------------------------- @@ -358,19 +349,21 @@ public Mono removeTool(String toolName) { * @return A Mono that completes when all clients have been notified */ public Mono notifyToolsListChanged() { - return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_TOOLS_LIST_CHANGED, null); + McpSchema.JSONRPCNotification jsonrpcNotification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, + McpSchema.METHOD_NOTIFICATION_TOOLS_LIST_CHANGED, null); + return this.transport.sendMessage(jsonrpcNotification); } - private DefaultMcpSession.RequestHandler toolsListRequestHandler() { - return params -> { + private ServerMcpSession.RequestHandler toolsListRequestHandler() { + return (exchange, params) -> { List tools = this.tools.stream().map(McpServerFeatures.AsyncToolRegistration::tool).toList(); return Mono.just(new McpSchema.ListToolsResult(tools, null)); }; } - private DefaultMcpSession.RequestHandler toolsCallRequestHandler() { - return params -> { + private ServerMcpSession.RequestHandler toolsCallRequestHandler() { + return (exchange, params) -> { McpSchema.CallToolRequest callToolRequest = transport.unmarshalFrom(params, new TypeReference() { }); @@ -450,11 +443,13 @@ public Mono removeResource(String resourceUri) { * @return A Mono that completes when all clients have been notified */ public Mono notifyResourcesListChanged() { - return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED, null); + McpSchema.JSONRPCNotification jsonrpcNotification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, + McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED, null); + return this.transport.sendMessage(jsonrpcNotification); } - private DefaultMcpSession.RequestHandler resourcesListRequestHandler() { - return params -> { + private ServerMcpSession.RequestHandler resourcesListRequestHandler() { + return (exchange, params) -> { var resourceList = this.resources.values() .stream() .map(McpServerFeatures.AsyncResourceRegistration::resource) @@ -463,13 +458,13 @@ private DefaultMcpSession.RequestHandler resource }; } - private DefaultMcpSession.RequestHandler resourceTemplateListRequestHandler() { - return params -> Mono.just(new McpSchema.ListResourceTemplatesResult(this.resourceTemplates, null)); + private ServerMcpSession.RequestHandler resourceTemplateListRequestHandler() { + return (exchange, params) -> Mono.just(new McpSchema.ListResourceTemplatesResult(this.resourceTemplates, null)); } - private DefaultMcpSession.RequestHandler resourcesReadRequestHandler() { - return params -> { + private ServerMcpSession.RequestHandler resourcesReadRequestHandler() { + return (exchange, params) -> { McpSchema.ReadResourceRequest resourceRequest = transport.unmarshalFrom(params, new TypeReference() { }); @@ -553,11 +548,13 @@ public Mono removePrompt(String promptName) { * @return A Mono that completes when all clients have been notified */ public Mono notifyPromptsListChanged() { - return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED, null); + McpSchema.JSONRPCNotification jsonrpcNotification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, + McpSchema.METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED, null); + return this.transport.sendMessage(jsonrpcNotification); } - private DefaultMcpSession.RequestHandler promptsListRequestHandler() { - return params -> { + private ServerMcpSession.RequestHandler promptsListRequestHandler() { + return (exchange, params) -> { // TODO: Implement pagination // McpSchema.PaginatedRequest request = transport.unmarshalFrom(params, // new TypeReference() { @@ -572,8 +569,8 @@ private DefaultMcpSession.RequestHandler promptsLis }; } - private DefaultMcpSession.RequestHandler promptsGetRequestHandler() { - return params -> { + private ServerMcpSession.RequestHandler promptsGetRequestHandler() { + return (exchange, params) -> { McpSchema.GetPromptRequest promptRequest = transport.unmarshalFrom(params, new TypeReference() { }); @@ -612,7 +609,9 @@ public Mono loggingNotification(LoggingMessageNotification loggingMessageN return Mono.empty(); } - return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_MESSAGE, params); + McpSchema.JSONRPCNotification jsonrpcNotification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, + McpSchema.METHOD_NOTIFICATION_MESSAGE, params); + return this.transport.sendMessage(jsonrpcNotification); } /** @@ -620,8 +619,8 @@ public Mono loggingNotification(LoggingMessageNotification loggingMessageN * not be sent. * @return A handler that processes logging level change requests */ - private DefaultMcpSession.RequestHandler setLoggerRequestHandler() { - return params -> { + private ServerMcpSession.RequestHandler setLoggerRequestHandler() { + return (exchange, params) -> { this.minLoggingLevel = transport.unmarshalFrom(params, new TypeReference() { }); @@ -654,16 +653,9 @@ private DefaultMcpSession.RequestHandler setLoggerRequestHandler() { * "https://spec.modelcontextprotocol.io/specification/client/sampling/">Sampling * Specification */ + @Deprecated public Mono createMessage(McpSchema.CreateMessageRequest createMessageRequest) { - - if (this.clientCapabilities == null) { - return Mono.error(new McpError("Client must be initialized. Call the initialize method first!")); - } - if (this.clientCapabilities.sampling() == null) { - return Mono.error(new McpError("Client must be configured with sampling capabilities")); - } - return this.mcpSession.sendRequest(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, createMessageRequest, - CREATE_MESSAGE_RESULT_TYPE_REF); + return Mono.error(new RuntimeException("Not implemented")); } /** diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/ClientMcpTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/ClientMcpTransport.java index 8a9b4ce0..24767f1f 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/ClientMcpTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/ClientMcpTransport.java @@ -3,6 +3,10 @@ */ package io.modelcontextprotocol.spec; +import java.util.function.Function; + +import reactor.core.publisher.Mono; + /** * Marker interface for the client-side MCP transport. * @@ -10,4 +14,6 @@ */ public interface ClientMcpTransport extends McpTransport { + Mono connect(Function, Mono> handler); + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpSession.java index 46aefafc..add33d7a 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpSession.java @@ -35,6 +35,7 @@ * @author Christian Tzolov * @author Dariusz Jędrzejczyk */ +// TODO: DefaultMcpSession is only relevant to the client-side. public class DefaultMcpSession implements McpSession { /** Logger for this class */ diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransport.java index 344a50bf..886f4be0 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransport.java @@ -46,7 +46,10 @@ public interface McpTransport { * This method should be called before any message exchange can occur. It sets up the * necessary resources and establishes the connection to the server. *

+ * @deprecated This is only relevant for client-side transports and will be removed + * from this interface. */ + @Deprecated Mono connect(Function, Mono> handler); /** @@ -69,7 +72,7 @@ default void close() { Mono closeGracefully(); /** - * Sends a message to the server asynchronously. + * Sends a message to the peer asynchronously. * *

* This method handles the transmission of messages to the server in an asynchronous diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpExchange.java b/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpExchange.java new file mode 100644 index 00000000..4facdb1c --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpExchange.java @@ -0,0 +1,75 @@ +package io.modelcontextprotocol.spec; + +import com.fasterxml.jackson.core.type.TypeReference; +import reactor.core.publisher.Mono; + +public class ServerMcpExchange { + + private final ServerMcpSession session; + + private final McpSchema.ClientCapabilities clientCapabilities; + + private final McpSchema.Implementation clientInfo; + + public ServerMcpExchange(ServerMcpSession session, McpSchema.ClientCapabilities clientCapabilities, + McpSchema.Implementation clientInfo) { + this.session = session; + this.clientCapabilities = clientCapabilities; + this.clientInfo = clientInfo; + } + + private static final TypeReference CREATE_MESSAGE_RESULT_TYPE_REF = new TypeReference<>() { + }; + + /** + * Create a new message using the sampling capabilities of the client. The Model + * Context Protocol (MCP) provides a standardized way for servers to request LLM + * sampling (“completions” or “generations”) from language models via clients. This + * flow allows clients to maintain control over model access, selection, and + * permissions while enabling servers to leverage AI capabilities—with no server API + * keys necessary. Servers can request text or image-based interactions and optionally + * include context from MCP servers in their prompts. + * @param createMessageRequest The request to create a new message + * @return A Mono that completes when the message has been created + * @throws McpError if the client has not been initialized or does not support + * sampling capabilities + * @throws McpError if the client does not support the createMessage method + * @see McpSchema.CreateMessageRequest + * @see McpSchema.CreateMessageResult + * @see Sampling + * Specification + */ + public Mono createMessage(McpSchema.CreateMessageRequest createMessageRequest) { + if (this.clientCapabilities == null) { + return Mono.error(new McpError("Client must be initialized. Call the initialize method first!")); + } + if (this.clientCapabilities.sampling() == null) { + return Mono.error(new McpError("Client must be configured with sampling capabilities")); + } + return this.session.sendRequest(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, createMessageRequest, + CREATE_MESSAGE_RESULT_TYPE_REF); + } + + private static final TypeReference LIST_ROOTS_RESULT_TYPE_REF = new TypeReference<>() { + }; + + /** + * Retrieves the list of all roots provided by the client. + * @return A Mono that emits the list of roots result. + */ + public Mono listRoots() { + return this.listRoots(null); + } + + /** + * Retrieves a paginated list of roots provided by the server. + * @param cursor Optional pagination cursor from a previous list request + * @return A Mono that emits the list of roots result containing + */ + public Mono listRoots(String cursor) { + return this.session.sendRequest(McpSchema.METHOD_ROOTS_LIST, new McpSchema.PaginatedRequest(cursor), + LIST_ROOTS_RESULT_TYPE_REF); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpSession.java new file mode 100644 index 00000000..1e910d2b --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpSession.java @@ -0,0 +1,225 @@ +package io.modelcontextprotocol.spec; + +import java.time.Duration; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; + +import com.fasterxml.jackson.core.type.TypeReference; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoSink; +import reactor.core.publisher.Sinks; + +public abstract class ServerMcpSession implements McpSession { + + private static final Logger logger = LoggerFactory.getLogger(ServerMcpSession.class); + + private final ConcurrentHashMap> pendingResponses = new ConcurrentHashMap<>(); + + private final String sessionPrefix = UUID.randomUUID().toString().substring(0, 8); + + private final AtomicLong requestCounter = new AtomicLong(0); + + private final InitHandler initHandler; + + private final Map> requestHandlers; + + private final Map notificationHandlers; + + // TODO: used only to unmarshall - could be extracted to another interface + private final McpTransport transport; + + private final Sinks.One exchangeSink = Sinks.one(); + + volatile boolean isInitialized = false; + + public ServerMcpSession(McpTransport transport, InitHandler initHandler, + Map> requestHandlers, Map notificationHandlers) { + this.transport = transport; + this.initHandler = initHandler; + this.requestHandlers = requestHandlers; + this.notificationHandlers = notificationHandlers; + } + + public void init(McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo) { + exchangeSink.tryEmitValue(new ServerMcpExchange(this, clientCapabilities, clientInfo)); + } + + public Mono exchange() { + return exchangeSink.asMono(); + } + + protected abstract Mono sendMessage(McpSchema.JSONRPCMessage message); + + private String generateRequestId() { + return this.sessionPrefix + "-" + this.requestCounter.getAndIncrement(); + } + + public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { + String requestId = this.generateRequestId(); + + return Mono.create(sink -> { + this.pendingResponses.put(requestId, sink); + McpSchema.JSONRPCRequest jsonrpcRequest = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, method, + requestId, requestParams); + this.sendMessage(jsonrpcRequest).subscribe(v -> { + }, error -> { + this.pendingResponses.remove(requestId); + sink.error(error); + }); + }).timeout(Duration.ofSeconds(10)).handle((jsonRpcResponse, sink) -> { + if (jsonRpcResponse.error() != null) { + sink.error(new McpError(jsonRpcResponse.error())); + } + else { + if (typeRef.getType().equals(Void.class)) { + sink.complete(); + } + else { + sink.next(this.transport.unmarshalFrom(jsonRpcResponse.result(), typeRef)); + } + } + }); + } + + @Override + public Mono sendNotification(String method, Map params) { + McpSchema.JSONRPCNotification jsonrpcNotification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, + method, params); + return this.sendMessage(jsonrpcNotification); + } + + public Mono handle(McpSchema.JSONRPCMessage message) { + return Mono.defer(() -> { + // TODO handle errors for communication to without initialization happening + // first + if (message instanceof McpSchema.JSONRPCResponse response) { + logger.debug("Received Response: {}", response); + var sink = pendingResponses.remove(response.id()); + if (sink == null) { + logger.warn("Unexpected response for unknown id {}", response.id()); + } + else { + sink.success(response); + } + return Mono.empty(); + } + else if (message instanceof McpSchema.JSONRPCRequest request) { + logger.debug("Received request: {}", request); + return handleIncomingRequest(request).onErrorResume(error -> { + var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, + new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + error.getMessage(), null)); + // TODO: Should the error go to SSE or back as POST return? + return this.sendMessage(errorResponse).then(Mono.empty()); + }).flatMap(this::sendMessage); + } + else if (message instanceof McpSchema.JSONRPCNotification notification) { + // TODO handle errors for communication to without initialization + // happening first + logger.debug("Received notification: {}", notification); + // TODO: in case of error, should the POST request be signalled? + return handleIncomingNotification(notification) + .doOnError(error -> logger.error("Error handling notification: {}", error.getMessage())); + } + else { + logger.warn("Received unknown message type: {}", message); + return Mono.empty(); + } + }); + } + + /** + * Handles an incoming JSON-RPC request by routing it to the appropriate handler. + * @param request The incoming JSON-RPC request + * @return A Mono containing the JSON-RPC response + */ + private Mono handleIncomingRequest(McpSchema.JSONRPCRequest request) { + return Mono.defer(() -> { + Mono resultMono; + if (McpSchema.METHOD_INITIALIZE.equals(request.method())) { + // TODO handle situation where already initialized! + resultMono = this.initHandler.handle(new ClientInitConsumer(), request.params()) + .doOnNext(initResult -> this.isInitialized = true); + } + else { + // TODO handle errors for communication to without initialization + // happening first + var handler = this.requestHandlers.get(request.method()); + if (handler == null) { + MethodNotFoundError error = getMethodNotFoundError(request.method()); + return Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, + new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.METHOD_NOT_FOUND, + error.message(), error.data()))); + } + + resultMono = this.exchangeSink.asMono().flatMap(exchange -> handler.handle(exchange, request.params())); + } + return resultMono + .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null)) + .onErrorResume(error -> Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), + null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + error.getMessage(), null)))); // TODO: add error message + // through the data field + }); + } + + /** + * Handles an incoming JSON-RPC notification by routing it to the appropriate handler. + * @param notification The incoming JSON-RPC notification + * @return A Mono that completes when the notification is processed + */ + private Mono handleIncomingNotification(McpSchema.JSONRPCNotification notification) { + return Mono.defer(() -> { + var handler = notificationHandlers.get(notification.method()); + if (handler == null) { + logger.error("No handler registered for notification method: {}", notification.method()); + return Mono.empty(); + } + return handler.handle(this, notification.params()); + }); + } + + record MethodNotFoundError(String method, String message, Object data) { + } + + static MethodNotFoundError getMethodNotFoundError(String method) { + switch (method) { + case McpSchema.METHOD_ROOTS_LIST: + return new MethodNotFoundError(method, "Roots not supported", + Map.of("reason", "Client does not have roots capability")); + default: + return new MethodNotFoundError(method, "Method not found: " + method, null); + } + } + + public class ClientInitConsumer { + + public void init(McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo) { + ServerMcpSession.this.init(clientCapabilities, clientInfo); + } + + } + + public interface InitHandler { + + Mono handle(ClientInitConsumer clientInitConsumer, Object params); + + } + + public interface NotificationHandler { + + Mono handle(ServerMcpSession connection, Object params); + + } + + public interface RequestHandler { + + Mono handle(ServerMcpExchange exchange, Object params); + + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpTransport.java index 13591432..0c2069f3 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpTransport.java @@ -3,6 +3,11 @@ */ package io.modelcontextprotocol.spec; +import java.util.Map; +import java.util.function.Function; + +import reactor.core.publisher.Mono; + /** * Marker interface for the server-side MCP transport. * @@ -10,4 +15,13 @@ */ public interface ServerMcpTransport extends McpTransport { + @Override + default Mono connect(Function, Mono> handler) { + throw new IllegalStateException("Server transport does not support connect method"); + } + + void registerHandlers(ServerMcpSession.InitHandler initHandler, + Map> requestHandlers, + Map notificationHandlers); + } From 8316ad8b2755f452957f6e05b08d9dc85d098455 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Thu, 6 Mar 2025 14:13:25 +0100 Subject: [PATCH 02/20] Introduce Child server transport and Session Factory --- .../transport/WebFluxSseServerTransport.java | 52 +++++++++++-------- .../server/McpAsyncServer.java | 9 ++-- .../spec/ServerMcpSession.java | 40 +++++++++----- .../spec/ServerMcpTransport.java | 11 ++-- 4 files changed, 68 insertions(+), 44 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransport.java index 09db3ba7..65b02bba 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransport.java @@ -1,7 +1,6 @@ package io.modelcontextprotocol.server.transport; import java.io.IOException; -import java.util.Map; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; @@ -89,16 +88,24 @@ public class WebFluxSseServerTransport implements ServerMcpTransport { private final RouterFunction routerFunction; - private ServerMcpSession.InitHandler initHandler; - - private Map> requestHandlers; - - private Map notificationHandlers; + private ServerMcpSession.Factory sessionFactory; /** * Map of active client sessions, keyed by session ID. */ - private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + + // FIXME: This is a bit clumsy. The McpAsyncServer handles global notifications + // using the transport and we need access to child transports for each session to + // use the sendMessage method. Ideally, the particular transport would be an + // abstraction of a specialized session that can handle only notifications and we + // could delegate to all child sessions without directly going through the transport. + // The conversion from a notification to message happens both in McpAsyncServer + // and in ServerMcpSession and it would be beneficial to have a unified interface + // for both. An MCP server implementation can use both McpServerExchange and + // Mcp(Sync|Async)Server to send notifications so the capability needs to lie in + // both places. + private final ConcurrentHashMap sessionTransports = new ConcurrentHashMap<>(); /** * Flag indicating if the transport is shutting down. @@ -143,12 +150,8 @@ public WebFluxSseServerTransport(ObjectMapper objectMapper, String messageEndpoi } @Override - public void registerHandlers(ServerMcpSession.InitHandler initHandler, - Map> requestHandlers, - Map notificationHandlers) { - this.initHandler = initHandler; - this.requestHandlers = requestHandlers; - this.notificationHandlers = notificationHandlers; + public void setSessionFactory(ServerMcpSession.Factory sessionFactory) { + this.sessionFactory = sessionFactory; } /** @@ -177,7 +180,7 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message) { logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); - return Flux.fromStream(sessions.values().stream()) + return Flux.fromStream(sessionTransports.values().stream()) .flatMap(session -> session.sendMessage(message) .doOnError(e -> logger.error("Failed to " + "send message to session {}: {}", session.sessionId, e.getMessage())) @@ -218,7 +221,7 @@ public T unmarshalFrom(Object data, TypeReference typeRef) { public Mono closeGracefully() { return Flux.fromIterable(sessions.values()) .doFirst(() -> logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size())) - .doOnNext(WebFluxMcpSession::close) + .flatMap(ServerMcpSession::closeGracefully) .then(); } @@ -264,9 +267,11 @@ private Mono handleSseConnection(ServerRequest request) { .body(Flux.>create(sink -> { String sessionId = UUID.randomUUID().toString(); logger.debug("Creating new SSE connection for session: {}", sessionId); - WebFluxMcpSession session = new WebFluxMcpSession(sessionId, sink, initHandler, requestHandlers, - notificationHandlers); - sessions.put(sessionId, session); + WebFluxMcpSessionTransport + sessionTransport = new WebFluxMcpSessionTransport(sessionId, sink); + + sessions.put(sessionId, sessionFactory.create(sessionTransport)); + sessionTransports.put(sessionId, sessionTransport); // Send initial endpoint event logger.debug("Sending initial endpoint event to session: {}", sessionId); @@ -323,15 +328,13 @@ private Mono handleMessage(ServerRequest request) { }); } - private class WebFluxMcpSession extends ServerMcpSession { + private class WebFluxMcpSessionTransport implements ServerMcpTransport.Child { final String sessionId; private final FluxSink> sink; - public WebFluxMcpSession(String sessionId, FluxSink> sink, InitHandler initHandler, - Map> requestHandlers, Map notificationHandlers) { - super(WebFluxSseServerTransport.this, initHandler, requestHandlers, notificationHandlers); + public WebFluxMcpSessionTransport(String sessionId, FluxSink> sink) { this.sessionId = sessionId; this.sink = sink; } @@ -358,6 +361,11 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message) { }).then(); } + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return WebFluxSseServerTransport.this.unmarshalFrom(data, typeRef); + } + @Override public Mono closeGracefully() { return Mono.fromRunnable(sink::complete); diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index 63b19156..4072e504 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -157,18 +157,17 @@ public class McpAsyncServer { asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); this.transport = mcpTransport; - mcpTransport.registerHandlers(this::asyncInitializeRequestHandler, requestHandlers, notificationHandlers); + mcpTransport.setSessionFactory(transport -> new ServerMcpSession(transport, + this::asyncInitializeRequestHandler, requestHandlers, notificationHandlers)); } // --------------------------------------- // Lifecycle Management // --------------------------------------- private Mono asyncInitializeRequestHandler( - ServerMcpSession.ClientInitConsumer initConsumer, Object params) { + ServerMcpSession.ClientInitConsumer initConsumer, McpSchema.InitializeRequest initializeRequest) { return Mono.defer(() -> { - McpSchema.InitializeRequest initializeRequest = transport.unmarshalFrom(params, - new TypeReference() { - }); + initConsumer.init(initializeRequest.capabilities(), initializeRequest.clientInfo()); diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpSession.java index 1e910d2b..8265343a 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpSession.java @@ -13,7 +13,7 @@ import reactor.core.publisher.MonoSink; import reactor.core.publisher.Sinks; -public abstract class ServerMcpSession implements McpSession { +public class ServerMcpSession implements McpSession { private static final Logger logger = LoggerFactory.getLogger(ServerMcpSession.class); @@ -48,12 +48,6 @@ public void init(McpSchema.ClientCapabilities clientCapabilities, McpSchema.Impl exchangeSink.tryEmitValue(new ServerMcpExchange(this, clientCapabilities, clientInfo)); } - public Mono exchange() { - return exchangeSink.asMono(); - } - - protected abstract Mono sendMessage(McpSchema.JSONRPCMessage message); - private String generateRequestId() { return this.sessionPrefix + "-" + this.requestCounter.getAndIncrement(); } @@ -65,7 +59,7 @@ public Mono sendRequest(String method, Object requestParams, TypeReferenc this.pendingResponses.put(requestId, sink); McpSchema.JSONRPCRequest jsonrpcRequest = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, method, requestId, requestParams); - this.sendMessage(jsonrpcRequest).subscribe(v -> { + this.transport.sendMessage(jsonrpcRequest).subscribe(v -> { }, error -> { this.pendingResponses.remove(requestId); sink.error(error); @@ -89,7 +83,7 @@ public Mono sendRequest(String method, Object requestParams, TypeReferenc public Mono sendNotification(String method, Map params) { McpSchema.JSONRPCNotification jsonrpcNotification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, method, params); - return this.sendMessage(jsonrpcNotification); + return this.transport.sendMessage(jsonrpcNotification); } public Mono handle(McpSchema.JSONRPCMessage message) { @@ -114,8 +108,8 @@ else if (message instanceof McpSchema.JSONRPCRequest request) { new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, error.getMessage(), null)); // TODO: Should the error go to SSE or back as POST return? - return this.sendMessage(errorResponse).then(Mono.empty()); - }).flatMap(this::sendMessage); + return this.transport.sendMessage(errorResponse).then(Mono.empty()); + }).flatMap(this.transport::sendMessage); } else if (message instanceof McpSchema.JSONRPCNotification notification) { // TODO handle errors for communication to without initialization @@ -142,7 +136,11 @@ private Mono handleIncomingRequest(McpSchema.JSONRPCR Mono resultMono; if (McpSchema.METHOD_INITIALIZE.equals(request.method())) { // TODO handle situation where already initialized! - resultMono = this.initHandler.handle(new ClientInitConsumer(), request.params()) + McpSchema.InitializeRequest initializeRequest = + transport.unmarshalFrom(request.params(), + new TypeReference() { + }); + resultMono = this.initHandler.handle(new ClientInitConsumer(), initializeRequest) .doOnNext(initResult -> this.isInitialized = true); } else { @@ -196,6 +194,16 @@ static MethodNotFoundError getMethodNotFoundError(String method) { } } + @Override + public Mono closeGracefully() { + return this.transport.closeGracefully(); + } + + @Override + public void close() { + this.transport.close(); + } + public class ClientInitConsumer { public void init(McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo) { @@ -206,7 +214,8 @@ public void init(McpSchema.ClientCapabilities clientCapabilities, McpSchema.Impl public interface InitHandler { - Mono handle(ClientInitConsumer clientInitConsumer, Object params); + Mono handle(ClientInitConsumer clientInitConsumer, + McpSchema.InitializeRequest initializeRequest); } @@ -222,4 +231,9 @@ public interface RequestHandler { } + @FunctionalInterface + public interface Factory { + ServerMcpSession create(ServerMcpTransport.Child sessionTransport); + } + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpTransport.java index 0c2069f3..6c3442d1 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpTransport.java @@ -3,7 +3,6 @@ */ package io.modelcontextprotocol.spec; -import java.util.Map; import java.util.function.Function; import reactor.core.publisher.Mono; @@ -20,8 +19,12 @@ default Mono connect(Function, Mono> requestHandlers, - Map notificationHandlers); + void setSessionFactory(ServerMcpSession.Factory sessionFactory); + interface Child extends McpTransport { + @Override + default Mono connect(Function, Mono> handler) { + throw new IllegalStateException("Server transport does not support connect method"); + } + } } From 0823591456e6c7210e4c3fc23045499a009fb9b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Wed, 12 Mar 2025 16:45:05 +0100 Subject: [PATCH 03/20] Aim for smoother transition to new APIs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Dariusz Jędrzejczyk --- ...=> WebFluxSseServerTransportProvider.java} | 107 +- .../WebFluxSseIntegrationTests.java | 6 +- .../server/WebFluxSseMcpAsyncServerTests.java | 6 +- .../server/WebFluxSseMcpSyncServerTests.java | 8 +- .../server/McpAsyncServer.java | 1629 +++++++++++++---- .../server/McpServer.java | 45 +- .../server/McpSyncServer.java | 5 + .../spec/ClientMcpTransport.java | 4 +- .../spec/McpClientTransport.java | 12 + .../spec/McpServerTransport.java | 5 + .../spec/McpServerTransportProvider.java | 32 + .../spec/McpTransport.java | 6 +- .../spec/ServerMcpExchange.java | 4 + .../spec/ServerMcpSession.java | 88 +- .../spec/ServerMcpTransport.java | 19 +- 15 files changed, 1463 insertions(+), 513 deletions(-) rename mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/{WebFluxSseServerTransport.java => WebFluxSseServerTransportProvider.java} (80%) create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransport.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java similarity index 80% rename from mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransport.java rename to mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index 65b02bba..732616dc 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -1,13 +1,15 @@ package io.modelcontextprotocol.server.transport; import java.io.IOException; -import java.util.UUID; +import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerTransport; +import io.modelcontextprotocol.spec.McpServerTransportProvider; import io.modelcontextprotocol.spec.ServerMcpSession; import io.modelcontextprotocol.spec.ServerMcpTransport; import io.modelcontextprotocol.util.Assert; @@ -61,9 +63,10 @@ * @see ServerMcpTransport * @see ServerSentEvent */ -public class WebFluxSseServerTransport implements ServerMcpTransport { +public class WebFluxSseServerTransportProvider implements McpServerTransportProvider { - private static final Logger logger = LoggerFactory.getLogger(WebFluxSseServerTransport.class); + private static final Logger logger = LoggerFactory.getLogger( + WebFluxSseServerTransportProvider.class); /** * Event type for JSON-RPC messages sent through the SSE connection. @@ -95,18 +98,6 @@ public class WebFluxSseServerTransport implements ServerMcpTransport { */ private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); - // FIXME: This is a bit clumsy. The McpAsyncServer handles global notifications - // using the transport and we need access to child transports for each session to - // use the sendMessage method. Ideally, the particular transport would be an - // abstraction of a specialized session that can handle only notifications and we - // could delegate to all child sessions without directly going through the transport. - // The conversion from a notification to message happens both in McpAsyncServer - // and in ServerMcpSession and it would be beneficial to have a unified interface - // for both. An MCP server implementation can use both McpServerExchange and - // Mcp(Sync|Async)Server to send notifications so the capability needs to lie in - // both places. - private final ConcurrentHashMap sessionTransports = new ConcurrentHashMap<>(); - /** * Flag indicating if the transport is shutting down. */ @@ -121,7 +112,7 @@ public class WebFluxSseServerTransport implements ServerMcpTransport { * setup. Must not be null. * @throws IllegalArgumentException if either parameter is null */ - public WebFluxSseServerTransport(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { + public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); Assert.notNull(messageEndpoint, "Message endpoint must not be null"); Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); @@ -145,7 +136,7 @@ public WebFluxSseServerTransport(ObjectMapper objectMapper, String messageEndpoi * setup. Must not be null. * @throws IllegalArgumentException if either parameter is null */ - public WebFluxSseServerTransport(ObjectMapper objectMapper, String messageEndpoint) { + public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) { this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); } @@ -167,12 +158,13 @@ public void setSessionFactory(ServerMcpSession.Factory sessionFactory) { *

  • Attempts to send the event to all active sessions
  • *
  • Tracks and reports any delivery failures
  • * - * @param message The JSON-RPC message to broadcast + * @param method The JSON-RPC method to send to clients + * @param params The method parameters to send to clients * @return A Mono that completes when the message has been sent to all sessions, or * errors if any session fails to receive the message */ @Override - public Mono sendMessage(McpSchema.JSONRPCMessage message) { + public Mono notifyClients(String method, Map params) { if (sessions.isEmpty()) { logger.debug("No active sessions to broadcast message to"); return Mono.empty(); @@ -180,29 +172,15 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message) { logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); - return Flux.fromStream(sessionTransports.values().stream()) - .flatMap(session -> session.sendMessage(message) - .doOnError(e -> logger.error("Failed to " + "send message to session {}: {}", session.sessionId, + return Flux.fromStream(sessions.values().stream()) + .flatMap(session -> session.sendNotification(method, params) + .doOnError(e -> logger.error("Failed to " + "send message to session " + + "{}: {}", session.getId(), e.getMessage())) .onErrorComplete()) .then(); } - /** - * Converts data from one type to another using the configured ObjectMapper. This - * method is primarily used for converting between different representations of - * JSON-RPC message data. - * @param The target type to convert to - * @param data The source data to convert - * @param typeRef Type reference describing the target type - * @return The converted data - * @throws IllegalArgumentException if the conversion fails - */ - @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return this.objectMapper.convertValue(data, typeRef); - } - /** * Initiates a graceful shutdown of the transport. This method ensures all active * sessions are properly closed and cleaned up. @@ -265,13 +243,13 @@ private Mono handleSseConnection(ServerRequest request) { return ServerResponse.ok() .contentType(MediaType.TEXT_EVENT_STREAM) .body(Flux.>create(sink -> { - String sessionId = UUID.randomUUID().toString(); - logger.debug("Creating new SSE connection for session: {}", sessionId); - WebFluxMcpSessionTransport - sessionTransport = new WebFluxMcpSessionTransport(sessionId, sink); + WebFluxMcpSessionTransport sessionTransport = new WebFluxMcpSessionTransport(sink); + + ServerMcpSession session = sessionFactory.create(sessionTransport); + String sessionId = session.getId(); - sessions.put(sessionId, sessionFactory.create(sessionTransport)); - sessionTransports.put(sessionId, sessionTransport); + logger.debug("Created new SSE connection for session: {}", sessionId); + sessions.put(sessionId, session); // Send initial endpoint event logger.debug("Sending initial endpoint event to session: {}", sessionId); @@ -328,14 +306,47 @@ private Mono handleMessage(ServerRequest request) { }); } - private class WebFluxMcpSessionTransport implements ServerMcpTransport.Child { + /* + Current: + + framework layer: + var transport = new WebFluxSseServerTransport(objectMapper, "/mcp", "/sse"); + McpServer.async(ServerMcpTransport transport) + + client connects -> + WebFluxSseServerTransport creates a: + - var sessionTransport = WebFluxMcpSessionTransport + - ServerMcpSession(sessionId, sessionTransport) + + WebFluxSseServerTransport IS_A ServerMcpTransport IS_A McpTransport + WebFluxMcpSessionTransport IS_A ServerMcpSessionTransport IS_A McpTransport + + McpTransport contains connect() which should be removed + ClientMcpTransport should have connect() + ServerMcpTransport should have setSessionFactory() + + Possible Future: + var transportProvider = new WebFluxSseServerTransport(objectMapper, "/mcp", "/sse"); + WebFluxSseServerTransport IS_A ServerMcpTransportProvider ? + ServerMcpTransportProvider creates ServerMcpTransport + + // disadvantage - too much breaks, e.g. + McpServer.async(ServerMcpTransportProvider transportProvider) + + // advantage + + ClientMcpTransport and ServerMcpTransport BOTH represent 1:1 relationship + + + + + */ - final String sessionId; + private class WebFluxMcpSessionTransport implements McpServerTransport { private final FluxSink> sink; - public WebFluxMcpSessionTransport(String sessionId, FluxSink> sink) { - this.sessionId = sessionId; + public WebFluxMcpSessionTransport(FluxSink> sink) { this.sink = sink; } @@ -363,7 +374,7 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message) { @Override public T unmarshalFrom(Object data, TypeReference typeRef) { - return WebFluxSseServerTransport.this.unmarshalFrom(data, typeRef); + return objectMapper.convertValue(data, typeRef); } @Override diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 4cd24c62..3df80db8 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -16,7 +16,7 @@ import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; import io.modelcontextprotocol.server.McpServer; import io.modelcontextprotocol.server.McpServerFeatures; -import io.modelcontextprotocol.server.transport.WebFluxSseServerTransport; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; @@ -55,14 +55,14 @@ public class WebFluxSseIntegrationTests { private DisposableServer httpServer; - private WebFluxSseServerTransport mcpServerTransport; + private WebFluxSseServerTransportProvider mcpServerTransport; ConcurrentHashMap clientBulders = new ConcurrentHashMap<>(); @BeforeEach public void before() { - this.mcpServerTransport = new WebFluxSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); + this.mcpServerTransport = new WebFluxSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT); HttpHandler httpHandler = RouterFunctions.toHttpHandler(mcpServerTransport.getRouterFunction()); ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java index 1ed0d99b..34f4b689 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java @@ -5,7 +5,7 @@ package io.modelcontextprotocol.server; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.server.transport.WebFluxSseServerTransport; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; import io.modelcontextprotocol.spec.ServerMcpTransport; import org.junit.jupiter.api.Timeout; import reactor.netty.DisposableServer; @@ -16,7 +16,7 @@ import org.springframework.web.reactive.function.server.RouterFunctions; /** - * Tests for {@link McpAsyncServer} using {@link WebFluxSseServerTransport}. + * Tests for {@link McpAsyncServer} using {@link WebFluxSseServerTransportProvider}. * * @author Christian Tzolov */ @@ -31,7 +31,7 @@ class WebFluxSseMcpAsyncServerTests extends AbstractMcpAsyncServerTests { @Override protected ServerMcpTransport createMcpTransport() { - var transport = new WebFluxSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); + var transport = new WebFluxSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT); HttpHandler httpHandler = RouterFunctions.toHttpHandler(transport.getRouterFunction()); ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java index 4db00dd4..2cf1087d 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java @@ -5,7 +5,7 @@ package io.modelcontextprotocol.server; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.server.transport.WebFluxSseServerTransport; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; import io.modelcontextprotocol.spec.ServerMcpTransport; import org.junit.jupiter.api.Timeout; import reactor.netty.DisposableServer; @@ -16,7 +16,7 @@ import org.springframework.web.reactive.function.server.RouterFunctions; /** - * Tests for {@link McpSyncServer} using {@link WebFluxSseServerTransport}. + * Tests for {@link McpSyncServer} using {@link WebFluxSseServerTransportProvider}. * * @author Christian Tzolov */ @@ -29,11 +29,11 @@ class WebFluxSseMcpSyncServerTests extends AbstractMcpSyncServerTests { private DisposableServer httpServer; - private WebFluxSseServerTransport transport; + private WebFluxSseServerTransportProvider transport; @Override protected ServerMcpTransport createMcpTransport() { - transport = new WebFluxSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); + transport = new WebFluxSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT); return transport; } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index 4072e504..d565cb9e 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -4,25 +4,29 @@ package io.modelcontextprotocol.server; +import java.time.Duration; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CopyOnWriteArrayList; import java.util.function.Function; import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.spec.DefaultMcpSession; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerTransportProvider; import io.modelcontextprotocol.spec.ServerMcpSession; -import io.modelcontextprotocol.spec.ServerMcpTransport; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.ServerMcpTransport; import io.modelcontextprotocol.util.Utils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -74,125 +78,31 @@ public class McpAsyncServer { private static final Logger logger = LoggerFactory.getLogger(McpAsyncServer.class); - private final ServerMcpTransport transport; - - private final McpSchema.ServerCapabilities serverCapabilities; - - private final McpSchema.Implementation serverInfo; - - /** - * Thread-safe list of tool handlers that can be modified at runtime. - */ - private final CopyOnWriteArrayList tools = new CopyOnWriteArrayList<>(); - - private final CopyOnWriteArrayList resourceTemplates = new CopyOnWriteArrayList<>(); + private final McpAsyncServer delegate; - private final ConcurrentHashMap resources = new ConcurrentHashMap<>(); - - private final ConcurrentHashMap prompts = new ConcurrentHashMap<>(); - - private LoggingLevel minLoggingLevel = LoggingLevel.DEBUG; - - /** - * Supported protocol versions. - */ - private List protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION); + McpAsyncServer() { + this.delegate = null; + } /** * Create a new McpAsyncServer with the given transport and capabilities. * @param mcpTransport The transport layer implementation for MCP communication. * @param features The MCP server supported features. */ + @Deprecated McpAsyncServer(ServerMcpTransport mcpTransport, McpServerFeatures.Async features) { - this.serverInfo = features.serverInfo(); - this.serverCapabilities = features.serverCapabilities(); - this.tools.addAll(features.tools()); - this.resources.putAll(features.resources()); - this.resourceTemplates.addAll(features.resourceTemplates()); - this.prompts.putAll(features.prompts()); - - Map> requestHandlers = new HashMap<>(); - - // Initialize request handlers for standard MCP methods - - // Ping MUST respond with an empty data, but not NULL response. - requestHandlers.put(McpSchema.METHOD_PING, (exchange, params) -> Mono.just("")); - - // Add tools API handlers if the tool capability is enabled - if (this.serverCapabilities.tools() != null) { - requestHandlers.put(McpSchema.METHOD_TOOLS_LIST, toolsListRequestHandler()); - requestHandlers.put(McpSchema.METHOD_TOOLS_CALL, toolsCallRequestHandler()); - } - - // Add resources API handlers if provided - if (this.serverCapabilities.resources() != null) { - requestHandlers.put(McpSchema.METHOD_RESOURCES_LIST, resourcesListRequestHandler()); - requestHandlers.put(McpSchema.METHOD_RESOURCES_READ, resourcesReadRequestHandler()); - requestHandlers.put(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, resourceTemplateListRequestHandler()); - } - - // Add prompts API handlers if provider exists - if (this.serverCapabilities.prompts() != null) { - requestHandlers.put(McpSchema.METHOD_PROMPT_LIST, promptsListRequestHandler()); - requestHandlers.put(McpSchema.METHOD_PROMPT_GET, promptsGetRequestHandler()); - } - - // Add logging API handlers if the logging capability is enabled - if (this.serverCapabilities.logging() != null) { - requestHandlers.put(McpSchema.METHOD_LOGGING_SET_LEVEL, setLoggerRequestHandler()); - } - - Map notificationHandlers = new HashMap<>(); - - notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (exchange, params) -> Mono.empty()); - - List, Mono>> rootsChangeConsumers = features.rootsChangeConsumers(); - - if (Utils.isEmpty(rootsChangeConsumers)) { - rootsChangeConsumers = List.of((roots) -> Mono.fromRunnable(() -> logger - .warn("Roots list changed notification, but no consumers provided. Roots list changed: {}", roots))); - } - - notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, - asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); - - this.transport = mcpTransport; - mcpTransport.setSessionFactory(transport -> new ServerMcpSession(transport, - this::asyncInitializeRequestHandler, requestHandlers, notificationHandlers)); + this.delegate = new LegacyAsyncServer(mcpTransport, features); } - // --------------------------------------- - // Lifecycle Management - // --------------------------------------- - private Mono asyncInitializeRequestHandler( - ServerMcpSession.ClientInitConsumer initConsumer, McpSchema.InitializeRequest initializeRequest) { - return Mono.defer(() -> { - - - initConsumer.init(initializeRequest.capabilities(), initializeRequest.clientInfo()); - - logger.info("Client initialize request - Protocol: {}, Capabilities: {}, Info: {}", - initializeRequest.protocolVersion(), initializeRequest.capabilities(), - initializeRequest.clientInfo()); - - // The server MUST respond with the highest protocol version it supports if - // it does not support the requested (e.g. Client) version. - String serverProtocolVersion = this.protocolVersions.get(this.protocolVersions.size() - 1); - - if (this.protocolVersions.contains(initializeRequest.protocolVersion())) { - // If the server supports the requested protocol version, it MUST respond - // with the same version. - serverProtocolVersion = initializeRequest.protocolVersion(); - } - else { - logger.warn( - "Client requested unsupported protocol version: {}, so the server will sugggest the {} version instead", - initializeRequest.protocolVersion(), serverProtocolVersion); - } - - return Mono.just(new McpSchema.InitializeResult(serverProtocolVersion, this.serverCapabilities, - this.serverInfo, null)); - }); + /** + * Create a new McpAsyncServer with the given transport and capabilities. + * @param mcpTransportProvider The transport layer implementation for MCP communication. + * @param features The MCP server supported features. + */ + McpAsyncServer(McpServerTransportProvider mcpTransportProvider, + ObjectMapper objectMapper, + McpServerFeatures.Async features) { + this.delegate = new AsyncServerImpl(mcpTransportProvider, objectMapper, features); } /** @@ -200,7 +110,7 @@ private Mono asyncInitializeRequestHandler( * @return The server capabilities */ public McpSchema.ServerCapabilities getServerCapabilities() { - return this.serverCapabilities; + return this.delegate.getServerCapabilities(); } /** @@ -208,25 +118,27 @@ public McpSchema.ServerCapabilities getServerCapabilities() { * @return The server implementation details */ public McpSchema.Implementation getServerInfo() { - return this.serverInfo; + return this.delegate.getServerInfo(); } /** * Get the client capabilities that define the supported features and functionality. * @return The client capabilities + * @deprecated This will be removed in 0.9.0 */ @Deprecated public ClientCapabilities getClientCapabilities() { - throw new IllegalStateException("This method is deprecated and should not be called"); + return this.delegate.getClientCapabilities(); } /** * Get the client implementation information. * @return The client implementation details + * @deprecated This will be removed in 0.9.0 */ @Deprecated public McpSchema.Implementation getClientInfo() { - throw new IllegalStateException("This method is deprecated and should not be called"); + return this.delegate.getClientInfo(); } /** @@ -234,47 +146,34 @@ public McpSchema.Implementation getClientInfo() { * @return A Mono that completes when the server has been closed */ public Mono closeGracefully() { - return this.transport.closeGracefully(); + return this.delegate.closeGracefully(); } /** * Close the server immediately. */ public void close() { - this.transport.close(); + this.delegate.close(); } - private static final TypeReference LIST_ROOTS_RESULT_TYPE_REF = new TypeReference<>() { - }; - /** * Retrieves the list of all roots provided by the client. * @return A Mono that emits the list of roots result. */ + @Deprecated public Mono listRoots() { - return this.listRoots(null); + return this.delegate.listRoots(null); } /** * Retrieves a paginated list of roots provided by the server. * @param cursor Optional pagination cursor from a previous list request * @return A Mono that emits the list of roots result containing + * @deprecated This will be removed in 0.9.0 */ @Deprecated public Mono listRoots(String cursor) { - return Mono.error(new RuntimeException("Not implemented")); - } - - private ServerMcpSession.NotificationHandler asyncRootsListChangedNotificationHandler( - List, Mono>> rootsChangeConsumers) { - return (exchange, - params) -> listRoots().flatMap(listRootsResult -> Flux.fromIterable(rootsChangeConsumers) - .flatMap(consumer -> consumer.apply(listRootsResult.roots())) - .onErrorResume(error -> { - logger.error("Error handling roots list change notification", error); - return Mono.empty(); - }) - .then()); + return this.delegate.listRoots(cursor); } // --------------------------------------- @@ -287,34 +186,7 @@ private ServerMcpSession.NotificationHandler asyncRootsListChangedNotificationHa * @return Mono that completes when clients have been notified of the change */ public Mono addTool(McpServerFeatures.AsyncToolRegistration toolRegistration) { - if (toolRegistration == null) { - return Mono.error(new McpError("Tool registration must not be null")); - } - if (toolRegistration.tool() == null) { - return Mono.error(new McpError("Tool must not be null")); - } - if (toolRegistration.call() == null) { - return Mono.error(new McpError("Tool call handler must not be null")); - } - if (this.serverCapabilities.tools() == null) { - return Mono.error(new McpError("Server must be configured with tool capabilities")); - } - - return Mono.defer(() -> { - // Check for duplicate tool names - if (this.tools.stream().anyMatch(th -> th.tool().name().equals(toolRegistration.tool().name()))) { - return Mono - .error(new McpError("Tool with name '" + toolRegistration.tool().name() + "' already exists")); - } - - this.tools.add(toolRegistration); - logger.debug("Added tool handler: {}", toolRegistration.tool().name()); - - if (this.serverCapabilities.tools().listChanged()) { - return notifyToolsListChanged(); - } - return Mono.empty(); - }); + return this.delegate.addTool(toolRegistration); } /** @@ -323,24 +195,7 @@ public Mono addTool(McpServerFeatures.AsyncToolRegistration toolRegistrati * @return Mono that completes when clients have been notified of the change */ public Mono removeTool(String toolName) { - if (toolName == null) { - return Mono.error(new McpError("Tool name must not be null")); - } - if (this.serverCapabilities.tools() == null) { - return Mono.error(new McpError("Server must be configured with tool capabilities")); - } - - return Mono.defer(() -> { - boolean removed = this.tools.removeIf(toolRegistration -> toolRegistration.tool().name().equals(toolName)); - if (removed) { - logger.debug("Removed tool handler: {}", toolName); - if (this.serverCapabilities.tools().listChanged()) { - return notifyToolsListChanged(); - } - return Mono.empty(); - } - return Mono.error(new McpError("Tool with name '" + toolName + "' not found")); - }); + return this.delegate.removeTool(toolName); } /** @@ -348,36 +203,7 @@ public Mono removeTool(String toolName) { * @return A Mono that completes when all clients have been notified */ public Mono notifyToolsListChanged() { - McpSchema.JSONRPCNotification jsonrpcNotification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, - McpSchema.METHOD_NOTIFICATION_TOOLS_LIST_CHANGED, null); - return this.transport.sendMessage(jsonrpcNotification); - } - - private ServerMcpSession.RequestHandler toolsListRequestHandler() { - return (exchange, params) -> { - List tools = this.tools.stream().map(McpServerFeatures.AsyncToolRegistration::tool).toList(); - - return Mono.just(new McpSchema.ListToolsResult(tools, null)); - }; - } - - private ServerMcpSession.RequestHandler toolsCallRequestHandler() { - return (exchange, params) -> { - McpSchema.CallToolRequest callToolRequest = transport.unmarshalFrom(params, - new TypeReference() { - }); - - Optional toolRegistration = this.tools.stream() - .filter(tr -> callToolRequest.name().equals(tr.tool().name())) - .findAny(); - - if (toolRegistration.isEmpty()) { - return Mono.error(new McpError("Tool not found: " + callToolRequest.name())); - } - - return toolRegistration.map(tool -> tool.call().apply(callToolRequest.arguments())) - .orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name()))); - }; + return this.delegate.notifyToolsListChanged(); } // --------------------------------------- @@ -390,25 +216,7 @@ private ServerMcpSession.RequestHandler toolsCallRequestHandler( * @return Mono that completes when clients have been notified of the change */ public Mono addResource(McpServerFeatures.AsyncResourceRegistration resourceHandler) { - if (resourceHandler == null || resourceHandler.resource() == null) { - return Mono.error(new McpError("Resource must not be null")); - } - - if (this.serverCapabilities.resources() == null) { - return Mono.error(new McpError("Server must be configured with resource capabilities")); - } - - return Mono.defer(() -> { - if (this.resources.putIfAbsent(resourceHandler.resource().uri(), resourceHandler) != null) { - return Mono - .error(new McpError("Resource with URI '" + resourceHandler.resource().uri() + "' already exists")); - } - logger.debug("Added resource handler: {}", resourceHandler.resource().uri()); - if (this.serverCapabilities.resources().listChanged()) { - return notifyResourcesListChanged(); - } - return Mono.empty(); - }); + return this.delegate.addResource(resourceHandler); } /** @@ -417,24 +225,7 @@ public Mono addResource(McpServerFeatures.AsyncResourceRegistration resour * @return Mono that completes when clients have been notified of the change */ public Mono removeResource(String resourceUri) { - if (resourceUri == null) { - return Mono.error(new McpError("Resource URI must not be null")); - } - if (this.serverCapabilities.resources() == null) { - return Mono.error(new McpError("Server must be configured with resource capabilities")); - } - - return Mono.defer(() -> { - McpServerFeatures.AsyncResourceRegistration removed = this.resources.remove(resourceUri); - if (removed != null) { - logger.debug("Removed resource handler: {}", resourceUri); - if (this.serverCapabilities.resources().listChanged()) { - return notifyResourcesListChanged(); - } - return Mono.empty(); - } - return Mono.error(new McpError("Resource with URI '" + resourceUri + "' not found")); - }); + return this.delegate.removeResource(resourceUri); } /** @@ -442,38 +233,7 @@ public Mono removeResource(String resourceUri) { * @return A Mono that completes when all clients have been notified */ public Mono notifyResourcesListChanged() { - McpSchema.JSONRPCNotification jsonrpcNotification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, - McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED, null); - return this.transport.sendMessage(jsonrpcNotification); - } - - private ServerMcpSession.RequestHandler resourcesListRequestHandler() { - return (exchange, params) -> { - var resourceList = this.resources.values() - .stream() - .map(McpServerFeatures.AsyncResourceRegistration::resource) - .toList(); - return Mono.just(new McpSchema.ListResourcesResult(resourceList, null)); - }; - } - - private ServerMcpSession.RequestHandler resourceTemplateListRequestHandler() { - return (exchange, params) -> Mono.just(new McpSchema.ListResourceTemplatesResult(this.resourceTemplates, null)); - - } - - private ServerMcpSession.RequestHandler resourcesReadRequestHandler() { - return (exchange, params) -> { - McpSchema.ReadResourceRequest resourceRequest = transport.unmarshalFrom(params, - new TypeReference() { - }); - var resourceUri = resourceRequest.uri(); - McpServerFeatures.AsyncResourceRegistration registration = this.resources.get(resourceUri); - if (registration != null) { - return registration.readHandler().apply(resourceRequest); - } - return Mono.error(new McpError("Resource not found: " + resourceUri)); - }; + return this.delegate.notifyResourcesListChanged(); } // --------------------------------------- @@ -486,31 +246,7 @@ private ServerMcpSession.RequestHandler resourcesR * @return Mono that completes when clients have been notified of the change */ public Mono addPrompt(McpServerFeatures.AsyncPromptRegistration promptRegistration) { - if (promptRegistration == null) { - return Mono.error(new McpError("Prompt registration must not be null")); - } - if (this.serverCapabilities.prompts() == null) { - return Mono.error(new McpError("Server must be configured with prompt capabilities")); - } - - return Mono.defer(() -> { - McpServerFeatures.AsyncPromptRegistration registration = this.prompts - .putIfAbsent(promptRegistration.prompt().name(), promptRegistration); - if (registration != null) { - return Mono.error( - new McpError("Prompt with name '" + promptRegistration.prompt().name() + "' already exists")); - } - - logger.debug("Added prompt handler: {}", promptRegistration.prompt().name()); - - // Servers that declared the listChanged capability SHOULD send a - // notification, - // when the list of available prompts changes - if (this.serverCapabilities.prompts().listChanged()) { - return notifyPromptsListChanged(); - } - return Mono.empty(); - }); + return this.delegate.addPrompt(promptRegistration); } /** @@ -519,27 +255,7 @@ public Mono addPrompt(McpServerFeatures.AsyncPromptRegistration promptRegi * @return Mono that completes when clients have been notified of the change */ public Mono removePrompt(String promptName) { - if (promptName == null) { - return Mono.error(new McpError("Prompt name must not be null")); - } - if (this.serverCapabilities.prompts() == null) { - return Mono.error(new McpError("Server must be configured with prompt capabilities")); - } - - return Mono.defer(() -> { - McpServerFeatures.AsyncPromptRegistration removed = this.prompts.remove(promptName); - - if (removed != null) { - logger.debug("Removed prompt handler: {}", promptName); - // Servers that declared the listChanged capability SHOULD send a - // notification, when the list of available prompts changes - if (this.serverCapabilities.prompts().listChanged()) { - return this.notifyPromptsListChanged(); - } - return Mono.empty(); - } - return Mono.error(new McpError("Prompt with name '" + promptName + "' not found")); - }); + return this.delegate.removePrompt(promptName); } /** @@ -547,41 +263,7 @@ public Mono removePrompt(String promptName) { * @return A Mono that completes when all clients have been notified */ public Mono notifyPromptsListChanged() { - McpSchema.JSONRPCNotification jsonrpcNotification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, - McpSchema.METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED, null); - return this.transport.sendMessage(jsonrpcNotification); - } - - private ServerMcpSession.RequestHandler promptsListRequestHandler() { - return (exchange, params) -> { - // TODO: Implement pagination - // McpSchema.PaginatedRequest request = transport.unmarshalFrom(params, - // new TypeReference() { - // }); - - var promptList = this.prompts.values() - .stream() - .map(McpServerFeatures.AsyncPromptRegistration::prompt) - .toList(); - - return Mono.just(new McpSchema.ListPromptsResult(promptList, null)); - }; - } - - private ServerMcpSession.RequestHandler promptsGetRequestHandler() { - return (exchange, params) -> { - McpSchema.GetPromptRequest promptRequest = transport.unmarshalFrom(params, - new TypeReference() { - }); - - // Implement prompt retrieval logic here - McpServerFeatures.AsyncPromptRegistration registration = this.prompts.get(promptRequest.name()); - if (registration == null) { - return Mono.error(new McpError("Prompt not found: " + promptRequest.name())); - } - - return registration.promptHandler().apply(promptRequest); - }; + return this.delegate.notifyPromptsListChanged(); } // --------------------------------------- @@ -595,43 +277,12 @@ private ServerMcpSession.RequestHandler promptsGetReq * @return A Mono that completes when the notification has been sent */ public Mono loggingNotification(LoggingMessageNotification loggingMessageNotification) { - - if (loggingMessageNotification == null) { - return Mono.error(new McpError("Logging message must not be null")); - } - - Map params = this.transport.unmarshalFrom(loggingMessageNotification, - new TypeReference>() { - }); - - if (loggingMessageNotification.level().level() < minLoggingLevel.level()) { - return Mono.empty(); - } - - McpSchema.JSONRPCNotification jsonrpcNotification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, - McpSchema.METHOD_NOTIFICATION_MESSAGE, params); - return this.transport.sendMessage(jsonrpcNotification); - } - - /** - * Handles requests to set the minimum logging level. Messages below this level will - * not be sent. - * @return A handler that processes logging level change requests - */ - private ServerMcpSession.RequestHandler setLoggerRequestHandler() { - return (exchange, params) -> { - this.minLoggingLevel = transport.unmarshalFrom(params, new TypeReference() { - }); - - return Mono.empty(); - }; + return this.delegate.loggingNotification(loggingMessageNotification); } // --------------------------------------- // Sampling // --------------------------------------- - private static final TypeReference CREATE_MESSAGE_RESULT_TYPE_REF = new TypeReference<>() { - }; /** * Create a new message using the sampling capabilities of the client. The Model @@ -651,10 +302,11 @@ private ServerMcpSession.RequestHandler setLoggerRequestHandler() { * @see Sampling * Specification + * @deprecated This will be removed in 0.9.0 */ @Deprecated public Mono createMessage(McpSchema.CreateMessageRequest createMessageRequest) { - return Mono.error(new RuntimeException("Not implemented")); + return this.delegate.createMessage(createMessageRequest); } /** @@ -663,7 +315,1198 @@ public Mono createMessage(McpSchema.CreateMessage * @param protocolVersions the Client supported protocol versions. */ void setProtocolVersions(List protocolVersions) { - this.protocolVersions = protocolVersions; + this.delegate.setProtocolVersions(protocolVersions); } + private static class AsyncServerImpl extends McpAsyncServer { + private final McpServerTransportProvider mcpTransportProvider; + + private final ObjectMapper objectMapper; + + private final McpSchema.ServerCapabilities serverCapabilities; + + private final McpSchema.Implementation serverInfo; + + /** + * Thread-safe list of tool handlers that can be modified at runtime. + */ + private final CopyOnWriteArrayList tools = new CopyOnWriteArrayList<>(); + + private final CopyOnWriteArrayList resourceTemplates = new CopyOnWriteArrayList<>(); + + private final ConcurrentHashMap resources = new ConcurrentHashMap<>(); + + private final ConcurrentHashMap prompts = new ConcurrentHashMap<>(); + + private LoggingLevel minLoggingLevel = LoggingLevel.DEBUG; + + /** + * Supported protocol versions. + */ + private List protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION); + + /** + * Create a new McpAsyncServer with the given transport and capabilities. + * @param mcpTransportProvider The transport layer implementation for MCP communication. + * @param features The MCP server supported features. + */ + AsyncServerImpl(McpServerTransportProvider mcpTransportProvider, + ObjectMapper objectMapper, + McpServerFeatures.Async features) { + this.mcpTransportProvider = mcpTransportProvider; + this.objectMapper = objectMapper; + this.serverInfo = features.serverInfo(); + this.serverCapabilities = features.serverCapabilities(); + this.tools.addAll(features.tools()); + this.resources.putAll(features.resources()); + this.resourceTemplates.addAll(features.resourceTemplates()); + this.prompts.putAll(features.prompts()); + + Map> requestHandlers = new HashMap<>(); + + // Initialize request handlers for standard MCP methods + + // Ping MUST respond with an empty data, but not NULL response. + requestHandlers.put(McpSchema.METHOD_PING, (exchange, params) -> Mono.just("")); + + // Add tools API handlers if the tool capability is enabled + if (this.serverCapabilities.tools() != null) { + requestHandlers.put(McpSchema.METHOD_TOOLS_LIST, toolsListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_TOOLS_CALL, toolsCallRequestHandler()); + } + + // Add resources API handlers if provided + if (this.serverCapabilities.resources() != null) { + requestHandlers.put(McpSchema.METHOD_RESOURCES_LIST, resourcesListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_RESOURCES_READ, resourcesReadRequestHandler()); + requestHandlers.put(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, resourceTemplateListRequestHandler()); + } + + // Add prompts API handlers if provider exists + if (this.serverCapabilities.prompts() != null) { + requestHandlers.put(McpSchema.METHOD_PROMPT_LIST, promptsListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_PROMPT_GET, promptsGetRequestHandler()); + } + + // Add logging API handlers if the logging capability is enabled + if (this.serverCapabilities.logging() != null) { + requestHandlers.put(McpSchema.METHOD_LOGGING_SET_LEVEL, setLoggerRequestHandler()); + } + + Map notificationHandlers = new HashMap<>(); + + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (exchange, params) -> Mono.empty()); + + List, Mono>> rootsChangeConsumers = features.rootsChangeConsumers(); + + if (Utils.isEmpty(rootsChangeConsumers)) { + rootsChangeConsumers = List.of((roots) -> Mono.fromRunnable(() -> logger + .warn("Roots list changed notification, but no consumers provided. Roots list changed: {}", roots))); + } + + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, + asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); + + mcpTransportProvider.setSessionFactory(transport -> new ServerMcpSession( + UUID.randomUUID().toString(), + transport, + this::asyncInitializeRequestHandler, + Mono::empty, + requestHandlers, + notificationHandlers)); + } + + // --------------------------------------- + // Lifecycle Management + // --------------------------------------- + private Mono asyncInitializeRequestHandler(McpSchema.InitializeRequest initializeRequest) { + return Mono.defer(() -> { + logger.info("Client initialize request - Protocol: {}, Capabilities: {}, Info: {}", + initializeRequest.protocolVersion(), initializeRequest.capabilities(), + initializeRequest.clientInfo()); + + // The server MUST respond with the highest protocol version it supports if + // it does not support the requested (e.g. Client) version. + String serverProtocolVersion = this.protocolVersions.get(this.protocolVersions.size() - 1); + + if (this.protocolVersions.contains(initializeRequest.protocolVersion())) { + // If the server supports the requested protocol version, it MUST respond + // with the same version. + serverProtocolVersion = initializeRequest.protocolVersion(); + } + else { + logger.warn( + "Client requested unsupported protocol version: {}, so the server will sugggest the {} version instead", + initializeRequest.protocolVersion(), serverProtocolVersion); + } + + return Mono.just(new McpSchema.InitializeResult(serverProtocolVersion, this.serverCapabilities, + this.serverInfo, null)); + }); + } + + /** + * Get the server capabilities that define the supported features and functionality. + * @return The server capabilities + */ + public McpSchema.ServerCapabilities getServerCapabilities() { + return this.serverCapabilities; + } + + /** + * Get the server implementation information. + * @return The server implementation details + */ + public McpSchema.Implementation getServerInfo() { + return this.serverInfo; + } + + /** + * Get the client capabilities that define the supported features and functionality. + * @return The client capabilities + */ + @Deprecated + public ClientCapabilities getClientCapabilities() { + throw new IllegalStateException("This method is deprecated and should not be called"); + } + + /** + * Get the client implementation information. + * @return The client implementation details + */ + @Deprecated + public McpSchema.Implementation getClientInfo() { + throw new IllegalStateException("This method is deprecated and should not be called"); + } + + /** + * Gracefully closes the server, allowing any in-progress operations to complete. + * @return A Mono that completes when the server has been closed + */ + public Mono closeGracefully() { + return this.mcpTransportProvider.closeGracefully(); + } + + /** + * Close the server immediately. + */ + public void close() { + this.mcpTransportProvider.close(); + } + + /** + * Retrieves the list of all roots provided by the client. + * @return A Mono that emits the list of roots result. + */ + @Deprecated + public Mono listRoots() { + return this.listRoots(null); + } + + /** + * Retrieves a paginated list of roots provided by the server. + * @param cursor Optional pagination cursor from a previous list request + * @return A Mono that emits the list of roots result containing + */ + @Deprecated + public Mono listRoots(String cursor) { + return Mono.error(new RuntimeException("Not implemented")); + } + + private ServerMcpSession.NotificationHandler asyncRootsListChangedNotificationHandler( + List, Mono>> rootsChangeConsumers) { + return (exchange, + params) -> listRoots().flatMap(listRootsResult -> Flux.fromIterable(rootsChangeConsumers) + .flatMap(consumer -> consumer.apply(listRootsResult.roots())) + .onErrorResume(error -> { + logger.error("Error handling roots list change notification", error); + return Mono.empty(); + }) + .then()); + } + + // --------------------------------------- + // Tool Management + // --------------------------------------- + + /** + * Add a new tool registration at runtime. + * @param toolRegistration The tool registration to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addTool(McpServerFeatures.AsyncToolRegistration toolRegistration) { + if (toolRegistration == null) { + return Mono.error(new McpError("Tool registration must not be null")); + } + if (toolRegistration.tool() == null) { + return Mono.error(new McpError("Tool must not be null")); + } + if (toolRegistration.call() == null) { + return Mono.error(new McpError("Tool call handler must not be null")); + } + if (this.serverCapabilities.tools() == null) { + return Mono.error(new McpError("Server must be configured with tool capabilities")); + } + + return Mono.defer(() -> { + // Check for duplicate tool names + if (this.tools.stream().anyMatch(th -> th.tool().name().equals(toolRegistration.tool().name()))) { + return Mono + .error(new McpError("Tool with name '" + toolRegistration.tool().name() + "' already exists")); + } + + this.tools.add(toolRegistration); + logger.debug("Added tool handler: {}", toolRegistration.tool().name()); + + if (this.serverCapabilities.tools().listChanged()) { + return notifyToolsListChanged(); + } + return Mono.empty(); + }); + } + + /** + * Remove a tool handler at runtime. + * @param toolName The name of the tool handler to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removeTool(String toolName) { + if (toolName == null) { + return Mono.error(new McpError("Tool name must not be null")); + } + if (this.serverCapabilities.tools() == null) { + return Mono.error(new McpError("Server must be configured with tool capabilities")); + } + + return Mono.defer(() -> { + boolean removed = this.tools.removeIf(toolRegistration -> toolRegistration.tool().name().equals(toolName)); + if (removed) { + logger.debug("Removed tool handler: {}", toolName); + if (this.serverCapabilities.tools().listChanged()) { + return notifyToolsListChanged(); + } + return Mono.empty(); + } + return Mono.error(new McpError("Tool with name '" + toolName + "' not found")); + }); + } + + /** + * Notifies clients that the list of available tools has changed. + * @return A Mono that completes when all clients have been notified + */ + public Mono notifyToolsListChanged() { + return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_TOOLS_LIST_CHANGED, null); + } + + private ServerMcpSession.RequestHandler toolsListRequestHandler() { + return (exchange, params) -> { + List tools = this.tools.stream().map(McpServerFeatures.AsyncToolRegistration::tool).toList(); + + return Mono.just(new McpSchema.ListToolsResult(tools, null)); + }; + } + + private ServerMcpSession.RequestHandler toolsCallRequestHandler() { + return (exchange, params) -> { + McpSchema.CallToolRequest callToolRequest = objectMapper.convertValue(params, + new TypeReference() { + }); + + Optional toolRegistration = this.tools.stream() + .filter(tr -> callToolRequest.name().equals(tr.tool().name())) + .findAny(); + + if (toolRegistration.isEmpty()) { + return Mono.error(new McpError("Tool not found: " + callToolRequest.name())); + } + + return toolRegistration.map(tool -> tool.call().apply(callToolRequest.arguments())) + .orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name()))); + }; + } + + // --------------------------------------- + // Resource Management + // --------------------------------------- + + /** + * Add a new resource handler at runtime. + * @param resourceHandler The resource handler to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addResource(McpServerFeatures.AsyncResourceRegistration resourceHandler) { + if (resourceHandler == null || resourceHandler.resource() == null) { + return Mono.error(new McpError("Resource must not be null")); + } + + if (this.serverCapabilities.resources() == null) { + return Mono.error(new McpError("Server must be configured with resource capabilities")); + } + + return Mono.defer(() -> { + if (this.resources.putIfAbsent(resourceHandler.resource().uri(), resourceHandler) != null) { + return Mono + .error(new McpError("Resource with URI '" + resourceHandler.resource().uri() + "' already exists")); + } + logger.debug("Added resource handler: {}", resourceHandler.resource().uri()); + if (this.serverCapabilities.resources().listChanged()) { + return notifyResourcesListChanged(); + } + return Mono.empty(); + }); + } + + /** + * Remove a resource handler at runtime. + * @param resourceUri The URI of the resource handler to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removeResource(String resourceUri) { + if (resourceUri == null) { + return Mono.error(new McpError("Resource URI must not be null")); + } + if (this.serverCapabilities.resources() == null) { + return Mono.error(new McpError("Server must be configured with resource capabilities")); + } + + return Mono.defer(() -> { + McpServerFeatures.AsyncResourceRegistration removed = this.resources.remove(resourceUri); + if (removed != null) { + logger.debug("Removed resource handler: {}", resourceUri); + if (this.serverCapabilities.resources().listChanged()) { + return notifyResourcesListChanged(); + } + return Mono.empty(); + } + return Mono.error(new McpError("Resource with URI '" + resourceUri + "' not found")); + }); + } + + /** + * Notifies clients that the list of available resources has changed. + * @return A Mono that completes when all clients have been notified + */ + public Mono notifyResourcesListChanged() { + McpSchema.JSONRPCNotification jsonrpcNotification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, + McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED, null); + return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED, null); + } + + private ServerMcpSession.RequestHandler resourcesListRequestHandler() { + return (exchange, params) -> { + var resourceList = this.resources.values() + .stream() + .map(McpServerFeatures.AsyncResourceRegistration::resource) + .toList(); + return Mono.just(new McpSchema.ListResourcesResult(resourceList, null)); + }; + } + + private ServerMcpSession.RequestHandler resourceTemplateListRequestHandler() { + return (exchange, params) -> Mono.just(new McpSchema.ListResourceTemplatesResult(this.resourceTemplates, null)); + + } + + private ServerMcpSession.RequestHandler resourcesReadRequestHandler() { + return (exchange, params) -> { + McpSchema.ReadResourceRequest resourceRequest = objectMapper.convertValue(params, + new TypeReference() { + }); + var resourceUri = resourceRequest.uri(); + McpServerFeatures.AsyncResourceRegistration registration = this.resources.get(resourceUri); + if (registration != null) { + return registration.readHandler().apply(resourceRequest); + } + return Mono.error(new McpError("Resource not found: " + resourceUri)); + }; + } + + // --------------------------------------- + // Prompt Management + // --------------------------------------- + + /** + * Add a new prompt handler at runtime. + * @param promptRegistration The prompt handler to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addPrompt(McpServerFeatures.AsyncPromptRegistration promptRegistration) { + if (promptRegistration == null) { + return Mono.error(new McpError("Prompt registration must not be null")); + } + if (this.serverCapabilities.prompts() == null) { + return Mono.error(new McpError("Server must be configured with prompt capabilities")); + } + + return Mono.defer(() -> { + McpServerFeatures.AsyncPromptRegistration registration = this.prompts + .putIfAbsent(promptRegistration.prompt().name(), promptRegistration); + if (registration != null) { + return Mono.error( + new McpError("Prompt with name '" + promptRegistration.prompt().name() + "' already exists")); + } + + logger.debug("Added prompt handler: {}", promptRegistration.prompt().name()); + + // Servers that declared the listChanged capability SHOULD send a + // notification, + // when the list of available prompts changes + if (this.serverCapabilities.prompts().listChanged()) { + return notifyPromptsListChanged(); + } + return Mono.empty(); + }); + } + + /** + * Remove a prompt handler at runtime. + * @param promptName The name of the prompt handler to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removePrompt(String promptName) { + if (promptName == null) { + return Mono.error(new McpError("Prompt name must not be null")); + } + if (this.serverCapabilities.prompts() == null) { + return Mono.error(new McpError("Server must be configured with prompt capabilities")); + } + + return Mono.defer(() -> { + McpServerFeatures.AsyncPromptRegistration removed = this.prompts.remove(promptName); + + if (removed != null) { + logger.debug("Removed prompt handler: {}", promptName); + // Servers that declared the listChanged capability SHOULD send a + // notification, when the list of available prompts changes + if (this.serverCapabilities.prompts().listChanged()) { + return this.notifyPromptsListChanged(); + } + return Mono.empty(); + } + return Mono.error(new McpError("Prompt with name '" + promptName + "' not found")); + }); + } + + /** + * Notifies clients that the list of available prompts has changed. + * @return A Mono that completes when all clients have been notified + */ + public Mono notifyPromptsListChanged() { + return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED, null); + } + + private ServerMcpSession.RequestHandler promptsListRequestHandler() { + return (exchange, params) -> { + // TODO: Implement pagination + // McpSchema.PaginatedRequest request = objectMapper.convertValue(params, + // new TypeReference() { + // }); + + var promptList = this.prompts.values() + .stream() + .map(McpServerFeatures.AsyncPromptRegistration::prompt) + .toList(); + + return Mono.just(new McpSchema.ListPromptsResult(promptList, null)); + }; + } + + private ServerMcpSession.RequestHandler promptsGetRequestHandler() { + return (exchange, params) -> { + McpSchema.GetPromptRequest promptRequest = objectMapper.convertValue(params, + new TypeReference() { + }); + + // Implement prompt retrieval logic here + McpServerFeatures.AsyncPromptRegistration registration = this.prompts.get(promptRequest.name()); + if (registration == null) { + return Mono.error(new McpError("Prompt not found: " + promptRequest.name())); + } + + return registration.promptHandler().apply(promptRequest); + }; + } + + // --------------------------------------- + // Logging Management + // --------------------------------------- + + /** + * Send a logging message notification to all connected clients. Messages below the + * current minimum logging level will be filtered out. + * @param loggingMessageNotification The logging message to send + * @return A Mono that completes when the notification has been sent + */ + public Mono loggingNotification(LoggingMessageNotification loggingMessageNotification) { + + if (loggingMessageNotification == null) { + return Mono.error(new McpError("Logging message must not be null")); + } + + Map params = this.objectMapper.convertValue(loggingMessageNotification, + new TypeReference>() { + }); + + if (loggingMessageNotification.level().level() < minLoggingLevel.level()) { + return Mono.empty(); + } + + return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_MESSAGE, params); + } + + /** + * Handles requests to set the minimum logging level. Messages below this level will + * not be sent. + * @return A handler that processes logging level change requests + */ + private ServerMcpSession.RequestHandler setLoggerRequestHandler() { + return (exchange, params) -> { + this.minLoggingLevel = objectMapper.convertValue(params, new TypeReference() { + }); + + return Mono.empty(); + }; + } + + // --------------------------------------- + // Sampling + // --------------------------------------- + + /** + * Create a new message using the sampling capabilities of the client. The Model + * Context Protocol (MCP) provides a standardized way for servers to request LLM + * sampling (“completions” or “generations”) from language models via clients. This + * flow allows clients to maintain control over model access, selection, and + * permissions while enabling servers to leverage AI capabilities—with no server API + * keys necessary. Servers can request text or image-based interactions and optionally + * include context from MCP servers in their prompts. + * @param createMessageRequest The request to create a new message + * @return A Mono that completes when the message has been created + * @throws McpError if the client has not been initialized or does not support + * sampling capabilities + * @throws McpError if the client does not support the createMessage method + * @see McpSchema.CreateMessageRequest + * @see McpSchema.CreateMessageResult + * @see Sampling + * Specification + */ + @Deprecated + public Mono createMessage(McpSchema.CreateMessageRequest createMessageRequest) { + return Mono.error(new RuntimeException("Not implemented")); + } + + /** + * This method is package-private and used for test only. Should not be called by user + * code. + * @param protocolVersions the Client supported protocol versions. + */ + void setProtocolVersions(List protocolVersions) { + this.protocolVersions = protocolVersions; + } + } + + private static final class LegacyAsyncServer extends McpAsyncServer { + /** + * The MCP session implementation that manages bidirectional JSON-RPC communication + * between clients and servers. + */ + private final DefaultMcpSession mcpSession; + + private final ServerMcpTransport transport; + + private final McpSchema.ServerCapabilities serverCapabilities; + + private final McpSchema.Implementation serverInfo; + + private McpSchema.ClientCapabilities clientCapabilities; + + private McpSchema.Implementation clientInfo; + + /** + * Thread-safe list of tool handlers that can be modified at runtime. + */ + private final CopyOnWriteArrayList tools = new CopyOnWriteArrayList<>(); + + private final CopyOnWriteArrayList resourceTemplates = new CopyOnWriteArrayList<>(); + + private final ConcurrentHashMap resources = new ConcurrentHashMap<>(); + + private final ConcurrentHashMap prompts = new ConcurrentHashMap<>(); + + private LoggingLevel minLoggingLevel = LoggingLevel.DEBUG; + + /** + * Supported protocol versions. + */ + private List protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION); + + /** + * Create a new McpAsyncServer with the given transport and capabilities. + * @param mcpTransport The transport layer implementation for MCP communication. + * @param features The MCP server supported features. + */ + LegacyAsyncServer(ServerMcpTransport mcpTransport, McpServerFeatures.Async features) { + + this.serverInfo = features.serverInfo(); + this.serverCapabilities = features.serverCapabilities(); + this.tools.addAll(features.tools()); + this.resources.putAll(features.resources()); + this.resourceTemplates.addAll(features.resourceTemplates()); + this.prompts.putAll(features.prompts()); + + Map> requestHandlers = new HashMap<>(); + + // Initialize request handlers for standard MCP methods + requestHandlers.put(McpSchema.METHOD_INITIALIZE, asyncInitializeRequestHandler()); + + // Ping MUST respond with an empty data, but not NULL response. + requestHandlers.put(McpSchema.METHOD_PING, (params) -> Mono.just("")); + + // Add tools API handlers if the tool capability is enabled + if (this.serverCapabilities.tools() != null) { + requestHandlers.put(McpSchema.METHOD_TOOLS_LIST, toolsListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_TOOLS_CALL, toolsCallRequestHandler()); + } + + // Add resources API handlers if provided + if (this.serverCapabilities.resources() != null) { + requestHandlers.put(McpSchema.METHOD_RESOURCES_LIST, resourcesListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_RESOURCES_READ, resourcesReadRequestHandler()); + requestHandlers.put(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, resourceTemplateListRequestHandler()); + } + + // Add prompts API handlers if provider exists + if (this.serverCapabilities.prompts() != null) { + requestHandlers.put(McpSchema.METHOD_PROMPT_LIST, promptsListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_PROMPT_GET, promptsGetRequestHandler()); + } + + // Add logging API handlers if the logging capability is enabled + if (this.serverCapabilities.logging() != null) { + requestHandlers.put(McpSchema.METHOD_LOGGING_SET_LEVEL, setLoggerRequestHandler()); + } + + Map notificationHandlers = new HashMap<>(); + + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (params) -> Mono.empty()); + + List, Mono>> rootsChangeConsumers = features.rootsChangeConsumers(); + + if (Utils.isEmpty(rootsChangeConsumers)) { + rootsChangeConsumers = List.of((roots) -> Mono.fromRunnable(() -> logger + .warn("Roots list changed notification, but no consumers provided. Roots list changed: {}", roots))); + } + + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, + asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); + + this.transport = mcpTransport; + this.mcpSession = new DefaultMcpSession(Duration.ofSeconds(10), mcpTransport, requestHandlers, + notificationHandlers); + } + + // --------------------------------------- + // Lifecycle Management + // --------------------------------------- + private DefaultMcpSession.RequestHandler asyncInitializeRequestHandler() { + return params -> { + McpSchema.InitializeRequest initializeRequest = transport.unmarshalFrom(params, + new TypeReference() { + }); + this.clientCapabilities = initializeRequest.capabilities(); + this.clientInfo = initializeRequest.clientInfo(); + logger.info("Client initialize request - Protocol: {}, Capabilities: {}, Info: {}", + initializeRequest.protocolVersion(), initializeRequest.capabilities(), + initializeRequest.clientInfo()); + + // The server MUST respond with the highest protocol version it supports if + // it does not support the requested (e.g. Client) version. + String serverProtocolVersion = this.protocolVersions.get(this.protocolVersions.size() - 1); + + if (this.protocolVersions.contains(initializeRequest.protocolVersion())) { + // If the server supports the requested protocol version, it MUST respond + // with the same version. + serverProtocolVersion = initializeRequest.protocolVersion(); + } + else { + logger.warn( + "Client requested unsupported protocol version: {}, so the server will sugggest the {} version instead", + initializeRequest.protocolVersion(), serverProtocolVersion); + } + + return Mono.just(new McpSchema.InitializeResult(serverProtocolVersion, this.serverCapabilities, + this.serverInfo, null)); + }; + } + + /** + * Get the server capabilities that define the supported features and functionality. + * @return The server capabilities + */ + public McpSchema.ServerCapabilities getServerCapabilities() { + return this.serverCapabilities; + } + + /** + * Get the server implementation information. + * @return The server implementation details + */ + public McpSchema.Implementation getServerInfo() { + return this.serverInfo; + } + + /** + * Get the client capabilities that define the supported features and functionality. + * @return The client capabilities + */ + public ClientCapabilities getClientCapabilities() { + return this.clientCapabilities; + } + + /** + * Get the client implementation information. + * @return The client implementation details + */ + public McpSchema.Implementation getClientInfo() { + return this.clientInfo; + } + + /** + * Gracefully closes the server, allowing any in-progress operations to complete. + * @return A Mono that completes when the server has been closed + */ + public Mono closeGracefully() { + return this.mcpSession.closeGracefully(); + } + + /** + * Close the server immediately. + */ + public void close() { + this.mcpSession.close(); + } + + private static final TypeReference LIST_ROOTS_RESULT_TYPE_REF = new TypeReference<>() { + }; + + /** + * Retrieves the list of all roots provided by the client. + * @return A Mono that emits the list of roots result. + */ + public Mono listRoots() { + return this.listRoots(null); + } + + /** + * Retrieves a paginated list of roots provided by the server. + * @param cursor Optional pagination cursor from a previous list request + * @return A Mono that emits the list of roots result containing + */ + public Mono listRoots(String cursor) { + return this.mcpSession.sendRequest(McpSchema.METHOD_ROOTS_LIST, new McpSchema.PaginatedRequest(cursor), + LIST_ROOTS_RESULT_TYPE_REF); + } + + private DefaultMcpSession.NotificationHandler asyncRootsListChangedNotificationHandler( + List, Mono>> rootsChangeConsumers) { + return params -> listRoots().flatMap(listRootsResult -> Flux.fromIterable(rootsChangeConsumers) + .flatMap(consumer -> consumer.apply(listRootsResult.roots())) + .onErrorResume(error -> { + logger.error("Error handling roots list change notification", error); + return Mono.empty(); + }) + .then()); + } + + // --------------------------------------- + // Tool Management + // --------------------------------------- + + /** + * Add a new tool registration at runtime. + * @param toolRegistration The tool registration to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addTool(McpServerFeatures.AsyncToolRegistration toolRegistration) { + if (toolRegistration == null) { + return Mono.error(new McpError("Tool registration must not be null")); + } + if (toolRegistration.tool() == null) { + return Mono.error(new McpError("Tool must not be null")); + } + if (toolRegistration.call() == null) { + return Mono.error(new McpError("Tool call handler must not be null")); + } + if (this.serverCapabilities.tools() == null) { + return Mono.error(new McpError("Server must be configured with tool capabilities")); + } + + return Mono.defer(() -> { + // Check for duplicate tool names + if (this.tools.stream().anyMatch(th -> th.tool().name().equals(toolRegistration.tool().name()))) { + return Mono + .error(new McpError("Tool with name '" + toolRegistration.tool().name() + "' already exists")); + } + + this.tools.add(toolRegistration); + logger.debug("Added tool handler: {}", toolRegistration.tool().name()); + + if (this.serverCapabilities.tools().listChanged()) { + return notifyToolsListChanged(); + } + return Mono.empty(); + }); + } + + /** + * Remove a tool handler at runtime. + * @param toolName The name of the tool handler to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removeTool(String toolName) { + if (toolName == null) { + return Mono.error(new McpError("Tool name must not be null")); + } + if (this.serverCapabilities.tools() == null) { + return Mono.error(new McpError("Server must be configured with tool capabilities")); + } + + return Mono.defer(() -> { + boolean removed = this.tools.removeIf(toolRegistration -> toolRegistration.tool().name().equals(toolName)); + if (removed) { + logger.debug("Removed tool handler: {}", toolName); + if (this.serverCapabilities.tools().listChanged()) { + return notifyToolsListChanged(); + } + return Mono.empty(); + } + return Mono.error(new McpError("Tool with name '" + toolName + "' not found")); + }); + } + + /** + * Notifies clients that the list of available tools has changed. + * @return A Mono that completes when all clients have been notified + */ + public Mono notifyToolsListChanged() { + return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_TOOLS_LIST_CHANGED, null); + } + + private DefaultMcpSession.RequestHandler toolsListRequestHandler() { + return params -> { + List tools = this.tools.stream().map(McpServerFeatures.AsyncToolRegistration::tool).toList(); + + return Mono.just(new McpSchema.ListToolsResult(tools, null)); + }; + } + + private DefaultMcpSession.RequestHandler toolsCallRequestHandler() { + return params -> { + McpSchema.CallToolRequest callToolRequest = transport.unmarshalFrom(params, + new TypeReference() { + }); + + Optional toolRegistration = this.tools.stream() + .filter(tr -> callToolRequest.name().equals(tr.tool().name())) + .findAny(); + + if (toolRegistration.isEmpty()) { + return Mono.error(new McpError("Tool not found: " + callToolRequest.name())); + } + + return toolRegistration.map(tool -> tool.call().apply(callToolRequest.arguments())) + .orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name()))); + }; + } + + // --------------------------------------- + // Resource Management + // --------------------------------------- + + /** + * Add a new resource handler at runtime. + * @param resourceHandler The resource handler to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addResource(McpServerFeatures.AsyncResourceRegistration resourceHandler) { + if (resourceHandler == null || resourceHandler.resource() == null) { + return Mono.error(new McpError("Resource must not be null")); + } + + if (this.serverCapabilities.resources() == null) { + return Mono.error(new McpError("Server must be configured with resource capabilities")); + } + + return Mono.defer(() -> { + if (this.resources.putIfAbsent(resourceHandler.resource().uri(), resourceHandler) != null) { + return Mono + .error(new McpError("Resource with URI '" + resourceHandler.resource().uri() + "' already exists")); + } + logger.debug("Added resource handler: {}", resourceHandler.resource().uri()); + if (this.serverCapabilities.resources().listChanged()) { + return notifyResourcesListChanged(); + } + return Mono.empty(); + }); + } + + /** + * Remove a resource handler at runtime. + * @param resourceUri The URI of the resource handler to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removeResource(String resourceUri) { + if (resourceUri == null) { + return Mono.error(new McpError("Resource URI must not be null")); + } + if (this.serverCapabilities.resources() == null) { + return Mono.error(new McpError("Server must be configured with resource capabilities")); + } + + return Mono.defer(() -> { + McpServerFeatures.AsyncResourceRegistration removed = this.resources.remove(resourceUri); + if (removed != null) { + logger.debug("Removed resource handler: {}", resourceUri); + if (this.serverCapabilities.resources().listChanged()) { + return notifyResourcesListChanged(); + } + return Mono.empty(); + } + return Mono.error(new McpError("Resource with URI '" + resourceUri + "' not found")); + }); + } + + /** + * Notifies clients that the list of available resources has changed. + * @return A Mono that completes when all clients have been notified + */ + public Mono notifyResourcesListChanged() { + return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED, null); + } + + private DefaultMcpSession.RequestHandler resourcesListRequestHandler() { + return params -> { + var resourceList = this.resources.values() + .stream() + .map(McpServerFeatures.AsyncResourceRegistration::resource) + .toList(); + return Mono.just(new McpSchema.ListResourcesResult(resourceList, null)); + }; + } + + private DefaultMcpSession.RequestHandler resourceTemplateListRequestHandler() { + return params -> Mono.just(new McpSchema.ListResourceTemplatesResult(this.resourceTemplates, null)); + + } + + private DefaultMcpSession.RequestHandler resourcesReadRequestHandler() { + return params -> { + McpSchema.ReadResourceRequest resourceRequest = transport.unmarshalFrom(params, + new TypeReference() { + }); + var resourceUri = resourceRequest.uri(); + McpServerFeatures.AsyncResourceRegistration registration = this.resources.get(resourceUri); + if (registration != null) { + return registration.readHandler().apply(resourceRequest); + } + return Mono.error(new McpError("Resource not found: " + resourceUri)); + }; + } + + // --------------------------------------- + // Prompt Management + // --------------------------------------- + + /** + * Add a new prompt handler at runtime. + * @param promptRegistration The prompt handler to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addPrompt(McpServerFeatures.AsyncPromptRegistration promptRegistration) { + if (promptRegistration == null) { + return Mono.error(new McpError("Prompt registration must not be null")); + } + if (this.serverCapabilities.prompts() == null) { + return Mono.error(new McpError("Server must be configured with prompt capabilities")); + } + + return Mono.defer(() -> { + McpServerFeatures.AsyncPromptRegistration registration = this.prompts + .putIfAbsent(promptRegistration.prompt().name(), promptRegistration); + if (registration != null) { + return Mono.error( + new McpError("Prompt with name '" + promptRegistration.prompt().name() + "' already exists")); + } + + logger.debug("Added prompt handler: {}", promptRegistration.prompt().name()); + + // Servers that declared the listChanged capability SHOULD send a + // notification, + // when the list of available prompts changes + if (this.serverCapabilities.prompts().listChanged()) { + return notifyPromptsListChanged(); + } + return Mono.empty(); + }); + } + + /** + * Remove a prompt handler at runtime. + * @param promptName The name of the prompt handler to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removePrompt(String promptName) { + if (promptName == null) { + return Mono.error(new McpError("Prompt name must not be null")); + } + if (this.serverCapabilities.prompts() == null) { + return Mono.error(new McpError("Server must be configured with prompt capabilities")); + } + + return Mono.defer(() -> { + McpServerFeatures.AsyncPromptRegistration removed = this.prompts.remove(promptName); + + if (removed != null) { + logger.debug("Removed prompt handler: {}", promptName); + // Servers that declared the listChanged capability SHOULD send a + // notification, when the list of available prompts changes + if (this.serverCapabilities.prompts().listChanged()) { + return this.notifyPromptsListChanged(); + } + return Mono.empty(); + } + return Mono.error(new McpError("Prompt with name '" + promptName + "' not found")); + }); + } + + /** + * Notifies clients that the list of available prompts has changed. + * @return A Mono that completes when all clients have been notified + */ + public Mono notifyPromptsListChanged() { + return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED, null); + } + + private DefaultMcpSession.RequestHandler promptsListRequestHandler() { + return params -> { + // TODO: Implement pagination + // McpSchema.PaginatedRequest request = transport.unmarshalFrom(params, + // new TypeReference() { + // }); + + var promptList = this.prompts.values() + .stream() + .map(McpServerFeatures.AsyncPromptRegistration::prompt) + .toList(); + + return Mono.just(new McpSchema.ListPromptsResult(promptList, null)); + }; + } + + private DefaultMcpSession.RequestHandler promptsGetRequestHandler() { + return params -> { + McpSchema.GetPromptRequest promptRequest = transport.unmarshalFrom(params, + new TypeReference() { + }); + + // Implement prompt retrieval logic here + McpServerFeatures.AsyncPromptRegistration registration = this.prompts.get(promptRequest.name()); + if (registration == null) { + return Mono.error(new McpError("Prompt not found: " + promptRequest.name())); + } + + return registration.promptHandler().apply(promptRequest); + }; + } + + // --------------------------------------- + // Logging Management + // --------------------------------------- + + /** + * Send a logging message notification to all connected clients. Messages below the + * current minimum logging level will be filtered out. + * @param loggingMessageNotification The logging message to send + * @return A Mono that completes when the notification has been sent + */ + public Mono loggingNotification(LoggingMessageNotification loggingMessageNotification) { + + if (loggingMessageNotification == null) { + return Mono.error(new McpError("Logging message must not be null")); + } + + Map params = this.transport.unmarshalFrom(loggingMessageNotification, + new TypeReference>() { + }); + + if (loggingMessageNotification.level().level() < minLoggingLevel.level()) { + return Mono.empty(); + } + + return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_MESSAGE, params); + } + + /** + * Handles requests to set the minimum logging level. Messages below this level will + * not be sent. + * @return A handler that processes logging level change requests + */ + private DefaultMcpSession.RequestHandler setLoggerRequestHandler() { + return params -> { + this.minLoggingLevel = transport.unmarshalFrom(params, new TypeReference() { + }); + + return Mono.empty(); + }; + } + + // --------------------------------------- + // Sampling + // --------------------------------------- + private static final TypeReference CREATE_MESSAGE_RESULT_TYPE_REF = new TypeReference<>() { + }; + + /** + * Create a new message using the sampling capabilities of the client. The Model + * Context Protocol (MCP) provides a standardized way for servers to request LLM + * sampling (“completions” or “generations”) from language models via clients. This + * flow allows clients to maintain control over model access, selection, and + * permissions while enabling servers to leverage AI capabilities—with no server API + * keys necessary. Servers can request text or image-based interactions and optionally + * include context from MCP servers in their prompts. + * @param createMessageRequest The request to create a new message + * @return A Mono that completes when the message has been created + * @throws McpError if the client has not been initialized or does not support + * sampling capabilities + * @throws McpError if the client does not support the createMessage method + * @see McpSchema.CreateMessageRequest + * @see McpSchema.CreateMessageResult + * @see Sampling + * Specification + */ + public Mono createMessage(McpSchema.CreateMessageRequest createMessageRequest) { + + if (this.clientCapabilities == null) { + return Mono.error(new McpError("Client must be initialized. Call the initialize method first!")); + } + if (this.clientCapabilities.sampling() == null) { + return Mono.error(new McpError("Client must be configured with sampling capabilities")); + } + return this.mcpSession.sendRequest(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, createMessageRequest, + CREATE_MESSAGE_RESULT_TYPE_REF); + } + + /** + * This method is package-private and used for test only. Should not be called by user + * code. + * @param protocolVersions the Client supported protocol versions. + */ + void setProtocolVersions(List protocolVersions) { + this.protocolVersions = protocolVersions; + } + + } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java index 54c7a28f..cff897dd 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java @@ -11,7 +11,9 @@ import java.util.function.Consumer; import java.util.function.Function; +import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerTransportProvider; import io.modelcontextprotocol.spec.McpTransport; import io.modelcontextprotocol.spec.ServerMcpTransport; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; @@ -132,10 +134,15 @@ static SyncSpec sync(ServerMcpTransport transport) { * @param transport The transport layer implementation for MCP communication * @return A new instance of {@link SyncSpec} for configuring the server. */ + @Deprecated static AsyncSpec async(ServerMcpTransport transport) { return new AsyncSpec(transport); } + static AsyncSpec async(McpServerTransportProvider transportProvider) { + return new AsyncSpec(transportProvider); + } + /** * Asynchronous server specification. */ @@ -145,6 +152,7 @@ class AsyncSpec { "1.0.0"); private final ServerMcpTransport transport; + private final McpServerTransportProvider transportProvider; private McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; @@ -181,9 +189,16 @@ class AsyncSpec { private final List, Mono>> rootsChangeConsumers = new ArrayList<>(); + private AsyncSpec(McpServerTransportProvider transportProvider) { + Assert.notNull(transportProvider, "Transport provider must not be null"); + this.transport = null; + this.transportProvider = transportProvider; + } + private AsyncSpec(ServerMcpTransport transport) { Assert.notNull(transport, "Transport must not be null"); this.transport = transport; + this.transportProvider = null; } /** @@ -507,9 +522,15 @@ public AsyncSpec rootsChangeConsumers( * settings */ public McpAsyncServer build() { - return new McpAsyncServer(this.transport, - new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, this.resources, - this.resourceTemplates, this.prompts, this.rootsChangeConsumers)); + var features = new McpServerFeatures.Async(this.serverInfo, + this.serverCapabilities, this.tools, this.resources, + this.resourceTemplates, this.prompts, this.rootsChangeConsumers); + if (this.transportProvider != null) { + // FIXME: provide ObjectMapper configuration + return new McpAsyncServer(this.transportProvider, new ObjectMapper(), features); + } else { + return new McpAsyncServer(this.transport, features); + } } } @@ -523,6 +544,7 @@ class SyncSpec { "1.0.0"); private final ServerMcpTransport transport; + private final McpServerTransportProvider transportProvider; private McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; @@ -559,9 +581,16 @@ class SyncSpec { private final List>> rootsChangeConsumers = new ArrayList<>(); + private SyncSpec(McpServerTransportProvider transportProvider) { + Assert.notNull(transportProvider, "Transport provider must not be null"); + this.transportProvider = transportProvider; + this.transport = null; + } + private SyncSpec(ServerMcpTransport transport) { Assert.notNull(transport, "Transport must not be null"); this.transport = transport; + this.transportProvider = null; } /** @@ -620,7 +649,7 @@ public SyncSpec capabilities(McpSchema.ServerCapabilities serverCapabilities) { /** * Adds a single tool with its implementation handler to the server. This is a * convenience method for registering individual tools without creating a - * {@link ToolRegistration} explicitly. + * {@link McpServerFeatures.SyncToolRegistration} explicitly. * *

    * Example usage:

    {@code
    @@ -888,8 +917,12 @@ public SyncSpec rootsChangeConsumers(Consumer>... consumers
     		public McpSyncServer build() {
     			McpServerFeatures.Sync syncFeatures = new McpServerFeatures.Sync(this.serverInfo, this.serverCapabilities,
     					this.tools, this.resources, this.resourceTemplates, this.prompts, this.rootsChangeConsumers);
    -			return new McpSyncServer(
    -					new McpAsyncServer(this.transport, McpServerFeatures.Async.fromSync(syncFeatures)));
    +			McpServerFeatures.Async asyncFeatures =
    +					McpServerFeatures.Async.fromSync(syncFeatures);
    +			var asyncServer = this.transportProvider != null ? new McpAsyncServer(this.transportProvider, new ObjectMapper(), asyncFeatures)
    +					: new McpAsyncServer(this.transport, asyncFeatures);
    +
    +			return new McpSyncServer(asyncServer);
     		}
     
     	}
    diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java
    index 1de0139b..b214848e 100644
    --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java
    +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java
    @@ -69,6 +69,7 @@ public McpSyncServer(McpAsyncServer asyncServer) {
     	 * Retrieves the list of all roots provided by the client.
     	 * @return The list of roots
     	 */
    +	@Deprecated
     	public McpSchema.ListRootsResult listRoots() {
     		return this.listRoots(null);
     	}
    @@ -78,6 +79,7 @@ public McpSchema.ListRootsResult listRoots() {
     	 * @param cursor Optional pagination cursor from a previous list request
     	 * @return The list of roots
     	 */
    +	@Deprecated
     	public McpSchema.ListRootsResult listRoots(String cursor) {
     		return this.asyncServer.listRoots(cursor).block();
     	}
    @@ -157,6 +159,7 @@ public McpSchema.Implementation getServerInfo() {
     	 * Get the client capabilities that define the supported features and functionality.
     	 * @return The client capabilities
     	 */
    +	@Deprecated
     	public ClientCapabilities getClientCapabilities() {
     		return this.asyncServer.getClientCapabilities();
     	}
    @@ -165,6 +168,7 @@ public ClientCapabilities getClientCapabilities() {
     	 * Get the client implementation information.
     	 * @return The client implementation details
     	 */
    +	@Deprecated
     	public McpSchema.Implementation getClientInfo() {
     		return this.asyncServer.getClientInfo();
     	}
    @@ -238,6 +242,7 @@ public McpAsyncServer getAsyncServer() {
     	 * "https://spec.modelcontextprotocol.io/specification/client/sampling/">Sampling
     	 * Specification
     	 */
    +	@Deprecated
     	public McpSchema.CreateMessageResult createMessage(McpSchema.CreateMessageRequest createMessageRequest) {
     		return this.asyncServer.createMessage(createMessageRequest).block();
     	}
    diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/ClientMcpTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/ClientMcpTransport.java
    index 24767f1f..702f01d6 100644
    --- a/mcp/src/main/java/io/modelcontextprotocol/spec/ClientMcpTransport.java
    +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/ClientMcpTransport.java
    @@ -11,9 +11,9 @@
      * Marker interface for the client-side MCP transport.
      *
      * @author Christian Tzolov
    + * @deprecated This class will be removed in 0.9.0. Use {@link McpClientTransport}.
      */
    +@Deprecated
     public interface ClientMcpTransport extends McpTransport {
     
    -	Mono connect(Function, Mono> handler);
    -
     }
    diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java
    new file mode 100644
    index 00000000..fa90e96f
    --- /dev/null
    +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java
    @@ -0,0 +1,12 @@
    +package io.modelcontextprotocol.spec;
    +
    +import java.util.function.Function;
    +
    +import reactor.core.publisher.Mono;
    +
    +public interface McpClientTransport extends McpTransport {
    +
    +	@Override
    +	Mono connect(Function, Mono> handler);
    +
    +}
    diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransport.java
    new file mode 100644
    index 00000000..ef5f5c6f
    --- /dev/null
    +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransport.java
    @@ -0,0 +1,5 @@
    +package io.modelcontextprotocol.spec;
    +
    +public interface McpServerTransport extends McpTransport {
    +
    +}
    diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java
    new file mode 100644
    index 00000000..f7208c4d
    --- /dev/null
    +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java
    @@ -0,0 +1,32 @@
    +package io.modelcontextprotocol.spec;
    +
    +import java.util.Map;
    +
    +import reactor.core.publisher.Mono;
    +
    +public interface McpServerTransportProvider {
    +
    +	// TODO: Consider adding a ProviderFactory that gets the Session Factory
    +	void setSessionFactory(ServerMcpSession.Factory sessionFactory);
    +
    +	Mono notifyClients(String method, Map params);
    +
    +	/**
    +	 * Closes the transport connection and releases any associated resources.
    +	 *
    +	 * 

    + * This method ensures proper cleanup of resources when the transport is no longer + * needed. It should handle the graceful shutdown of any active connections. + *

    + */ + default void close() { + this.closeGracefully().subscribe(); + } + + /** + * Closes the transport connection and releases any associated resources + * asynchronously. + * @return a {@link Mono} that completes when the connection has been closed. + */ + Mono closeGracefully(); +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransport.java index 886f4be0..f698d878 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransport.java @@ -47,10 +47,12 @@ public interface McpTransport { * necessary resources and establishes the connection to the server. *

    * @deprecated This is only relevant for client-side transports and will be removed - * from this interface. + * from this interface in 0.9.0. */ @Deprecated - Mono connect(Function, Mono> handler); + default Mono connect(Function, Mono> handler) { + return Mono.empty(); + } /** * Closes the transport connection and releases any associated resources. diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpExchange.java b/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpExchange.java index 4facdb1c..86d4175a 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpExchange.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpExchange.java @@ -5,6 +5,10 @@ public class ServerMcpExchange { + // map(roots) + // map(resource_subscription) + // initialization state + private final ServerMcpSession session; private final McpSchema.ClientCapabilities clientCapabilities; diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpSession.java index 8265343a..07d48693 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpSession.java @@ -2,9 +2,10 @@ import java.time.Duration; import java.util.Map; -import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; import com.fasterxml.jackson.core.type.TypeReference; import org.slf4j.Logger; @@ -19,37 +20,55 @@ public class ServerMcpSession implements McpSession { private final ConcurrentHashMap> pendingResponses = new ConcurrentHashMap<>(); - private final String sessionPrefix = UUID.randomUUID().toString().substring(0, 8); + private final String id; private final AtomicLong requestCounter = new AtomicLong(0); - private final InitHandler initHandler; + private final InitRequestHandler initRequestHandler; + + private final InitNotificationHandler initNotificationHandler; private final Map> requestHandlers; private final Map notificationHandlers; - // TODO: used only to unmarshall - could be extracted to another interface - private final McpTransport transport; + private final McpServerTransport transport; private final Sinks.One exchangeSink = Sinks.one(); - - volatile boolean isInitialized = false; - - public ServerMcpSession(McpTransport transport, InitHandler initHandler, - Map> requestHandlers, Map notificationHandlers) { + private final AtomicReference clientCapabilities = new AtomicReference<>(); + private final AtomicReference clientInfo = new AtomicReference<>(); + + // 0 = uninitialized, 1 = initializing, 2 = initialized + private static final int UNINITIALIZED = 0; + private static final int INITIALIZING = 1; + private static final int INITIALIZED = 2; + + private final AtomicInteger state = new AtomicInteger(UNINITIALIZED); + + public ServerMcpSession(String id, McpServerTransport transport, + InitRequestHandler initHandler, + InitNotificationHandler initNotificationHandler, + Map> requestHandlers, + Map notificationHandlers) { + this.id = id; this.transport = transport; - this.initHandler = initHandler; + this.initRequestHandler = initHandler; + this.initNotificationHandler = initNotificationHandler; this.requestHandlers = requestHandlers; this.notificationHandlers = notificationHandlers; } + public String getId() { + return this.id; + } + public void init(McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo) { - exchangeSink.tryEmitValue(new ServerMcpExchange(this, clientCapabilities, clientInfo)); + this.clientCapabilities.lazySet(clientCapabilities); + this.clientInfo.lazySet(clientInfo); } private String generateRequestId() { - return this.sessionPrefix + "-" + this.requestCounter.getAndIncrement(); + return this.id + "-" + this.requestCounter.getAndIncrement(); } public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { @@ -140,12 +159,14 @@ private Mono handleIncomingRequest(McpSchema.JSONRPCR transport.unmarshalFrom(request.params(), new TypeReference() { }); - resultMono = this.initHandler.handle(new ClientInitConsumer(), initializeRequest) - .doOnNext(initResult -> this.isInitialized = true); + + this.state.lazySet(INITIALIZING); + this.init(initializeRequest.capabilities(), initializeRequest.clientInfo()); + resultMono = this.initRequestHandler.handle(initializeRequest); } else { - // TODO handle errors for communication to without initialization - // happening first + // TODO handle errors for communication to this session without + // initialization happening first var handler = this.requestHandlers.get(request.method()); if (handler == null) { MethodNotFoundError error = getMethodNotFoundError(request.method()); @@ -172,12 +193,20 @@ private Mono handleIncomingRequest(McpSchema.JSONRPCR */ private Mono handleIncomingNotification(McpSchema.JSONRPCNotification notification) { return Mono.defer(() -> { + if (McpSchema.METHOD_NOTIFICATION_INITIALIZED.equals(notification.method())) { + this.state.lazySet(INITIALIZED); + exchangeSink.tryEmitValue(new ServerMcpExchange(this, clientCapabilities.get(), clientInfo.get())); + return this.initNotificationHandler.handle(); + } + var handler = notificationHandlers.get(notification.method()); if (handler == null) { logger.error("No handler registered for notification method: {}", notification.method()); return Mono.empty(); } - return handler.handle(this, notification.params()); + return this.exchangeSink.asMono() + .flatMap(exchange -> + handler.handle(exchange, notification.params())); }); } @@ -204,36 +233,25 @@ public void close() { this.transport.close(); } - public class ClientInitConsumer { - - public void init(McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo) { - ServerMcpSession.this.init(clientCapabilities, clientInfo); - } - + public interface InitRequestHandler { + Mono handle(McpSchema.InitializeRequest initializeRequest); } - public interface InitHandler { - - Mono handle(ClientInitConsumer clientInitConsumer, - McpSchema.InitializeRequest initializeRequest); - + public interface InitNotificationHandler { + Mono handle(); } public interface NotificationHandler { - - Mono handle(ServerMcpSession connection, Object params); - + Mono handle(ServerMcpExchange exchange, Object params); } public interface RequestHandler { - Mono handle(ServerMcpExchange exchange, Object params); - } @FunctionalInterface public interface Factory { - ServerMcpSession create(ServerMcpTransport.Child sessionTransport); + ServerMcpSession create(McpServerTransport sessionTransport); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpTransport.java index 6c3442d1..704daee0 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpTransport.java @@ -3,28 +3,13 @@ */ package io.modelcontextprotocol.spec; -import java.util.function.Function; - -import reactor.core.publisher.Mono; - /** * Marker interface for the server-side MCP transport. * * @author Christian Tzolov + * @deprecated This class will be removed in 0.9.0. Use {@link McpServerTransport}. */ +@Deprecated public interface ServerMcpTransport extends McpTransport { - @Override - default Mono connect(Function, Mono> handler) { - throw new IllegalStateException("Server transport does not support connect method"); - } - - void setSessionFactory(ServerMcpSession.Factory sessionFactory); - - interface Child extends McpTransport { - @Override - default Mono connect(Function, Mono> handler) { - throw new IllegalStateException("Server transport does not support connect method"); - } - } } From 4c784559a77ed96e4012547f42903862a320a7a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Thu, 13 Mar 2025 09:05:31 +0100 Subject: [PATCH 04/20] Rename server session and exchange to follow consistent pattern MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Dariusz Jędrzejczyk --- .../WebFluxSseServerTransportProvider.java | 14 +++++----- .../server/McpAsyncServer.java | 26 +++++++++---------- ...cpExchange.java => McpServerExchange.java} | 6 ++--- ...rMcpSession.java => McpServerSession.java} | 16 ++++++------ .../spec/McpServerTransportProvider.java | 2 +- 5 files changed, 32 insertions(+), 32 deletions(-) rename mcp/src/main/java/io/modelcontextprotocol/spec/{ServerMcpExchange.java => McpServerExchange.java} (95%) rename mcp/src/main/java/io/modelcontextprotocol/spec/{ServerMcpSession.java => McpServerSession.java} (94%) diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index 732616dc..e17135ac 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -10,7 +10,7 @@ import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpServerTransport; import io.modelcontextprotocol.spec.McpServerTransportProvider; -import io.modelcontextprotocol.spec.ServerMcpSession; +import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.ServerMcpTransport; import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; @@ -91,12 +91,12 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv private final RouterFunction routerFunction; - private ServerMcpSession.Factory sessionFactory; + private McpServerSession.Factory sessionFactory; /** * Map of active client sessions, keyed by session ID. */ - private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); /** * Flag indicating if the transport is shutting down. @@ -141,7 +141,7 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messa } @Override - public void setSessionFactory(ServerMcpSession.Factory sessionFactory) { + public void setSessionFactory(McpServerSession.Factory sessionFactory) { this.sessionFactory = sessionFactory; } @@ -199,7 +199,7 @@ public Mono notifyClients(String method, Map params) { public Mono closeGracefully() { return Flux.fromIterable(sessions.values()) .doFirst(() -> logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size())) - .flatMap(ServerMcpSession::closeGracefully) + .flatMap(McpServerSession::closeGracefully) .then(); } @@ -245,7 +245,7 @@ private Mono handleSseConnection(ServerRequest request) { .body(Flux.>create(sink -> { WebFluxMcpSessionTransport sessionTransport = new WebFluxMcpSessionTransport(sink); - ServerMcpSession session = sessionFactory.create(sessionTransport); + McpServerSession session = sessionFactory.create(sessionTransport); String sessionId = session.getId(); logger.debug("Created new SSE connection for session: {}", sessionId); @@ -288,7 +288,7 @@ private Mono handleMessage(ServerRequest request) { return ServerResponse.badRequest().bodyValue(new McpError("Session ID missing in message endpoint")); } - ServerMcpSession session = sessions.get(request.queryParam("sessionId").get()); + McpServerSession session = sessions.get(request.queryParam("sessionId").get()); return request.bodyToMono(String.class).flatMap(body -> { try { diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index d565cb9e..91862821 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -20,7 +20,7 @@ import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpServerTransportProvider; -import io.modelcontextprotocol.spec.ServerMcpSession; +import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; @@ -362,7 +362,7 @@ private static class AsyncServerImpl extends McpAsyncServer { this.resourceTemplates.addAll(features.resourceTemplates()); this.prompts.putAll(features.prompts()); - Map> requestHandlers = new HashMap<>(); + Map> requestHandlers = new HashMap<>(); // Initialize request handlers for standard MCP methods @@ -393,7 +393,7 @@ private static class AsyncServerImpl extends McpAsyncServer { requestHandlers.put(McpSchema.METHOD_LOGGING_SET_LEVEL, setLoggerRequestHandler()); } - Map notificationHandlers = new HashMap<>(); + Map notificationHandlers = new HashMap<>(); notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (exchange, params) -> Mono.empty()); @@ -407,7 +407,7 @@ private static class AsyncServerImpl extends McpAsyncServer { notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); - mcpTransportProvider.setSessionFactory(transport -> new ServerMcpSession( + mcpTransportProvider.setSessionFactory(transport -> new McpServerSession( UUID.randomUUID().toString(), transport, this::asyncInitializeRequestHandler, @@ -513,7 +513,7 @@ public Mono listRoots(String cursor) { return Mono.error(new RuntimeException("Not implemented")); } - private ServerMcpSession.NotificationHandler asyncRootsListChangedNotificationHandler( + private McpServerSession.NotificationHandler asyncRootsListChangedNotificationHandler( List, Mono>> rootsChangeConsumers) { return (exchange, params) -> listRoots().flatMap(listRootsResult -> Flux.fromIterable(rootsChangeConsumers) @@ -599,7 +599,7 @@ public Mono notifyToolsListChanged() { return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_TOOLS_LIST_CHANGED, null); } - private ServerMcpSession.RequestHandler toolsListRequestHandler() { + private McpServerSession.RequestHandler toolsListRequestHandler() { return (exchange, params) -> { List tools = this.tools.stream().map(McpServerFeatures.AsyncToolRegistration::tool).toList(); @@ -607,7 +607,7 @@ private ServerMcpSession.RequestHandler toolsListRequ }; } - private ServerMcpSession.RequestHandler toolsCallRequestHandler() { + private McpServerSession.RequestHandler toolsCallRequestHandler() { return (exchange, params) -> { McpSchema.CallToolRequest callToolRequest = objectMapper.convertValue(params, new TypeReference() { @@ -693,7 +693,7 @@ public Mono notifyResourcesListChanged() { return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED, null); } - private ServerMcpSession.RequestHandler resourcesListRequestHandler() { + private McpServerSession.RequestHandler resourcesListRequestHandler() { return (exchange, params) -> { var resourceList = this.resources.values() .stream() @@ -703,12 +703,12 @@ private ServerMcpSession.RequestHandler resources }; } - private ServerMcpSession.RequestHandler resourceTemplateListRequestHandler() { + private McpServerSession.RequestHandler resourceTemplateListRequestHandler() { return (exchange, params) -> Mono.just(new McpSchema.ListResourceTemplatesResult(this.resourceTemplates, null)); } - private ServerMcpSession.RequestHandler resourcesReadRequestHandler() { + private McpServerSession.RequestHandler resourcesReadRequestHandler() { return (exchange, params) -> { McpSchema.ReadResourceRequest resourceRequest = objectMapper.convertValue(params, new TypeReference() { @@ -796,7 +796,7 @@ public Mono notifyPromptsListChanged() { return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED, null); } - private ServerMcpSession.RequestHandler promptsListRequestHandler() { + private McpServerSession.RequestHandler promptsListRequestHandler() { return (exchange, params) -> { // TODO: Implement pagination // McpSchema.PaginatedRequest request = objectMapper.convertValue(params, @@ -812,7 +812,7 @@ private ServerMcpSession.RequestHandler promptsList }; } - private ServerMcpSession.RequestHandler promptsGetRequestHandler() { + private McpServerSession.RequestHandler promptsGetRequestHandler() { return (exchange, params) -> { McpSchema.GetPromptRequest promptRequest = objectMapper.convertValue(params, new TypeReference() { @@ -860,7 +860,7 @@ public Mono loggingNotification(LoggingMessageNotification loggingMessageN * not be sent. * @return A handler that processes logging level change requests */ - private ServerMcpSession.RequestHandler setLoggerRequestHandler() { + private McpServerSession.RequestHandler setLoggerRequestHandler() { return (exchange, params) -> { this.minLoggingLevel = objectMapper.convertValue(params, new TypeReference() { }); diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpExchange.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerExchange.java similarity index 95% rename from mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpExchange.java rename to mcp/src/main/java/io/modelcontextprotocol/spec/McpServerExchange.java index 86d4175a..a8f54a2d 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpExchange.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerExchange.java @@ -3,19 +3,19 @@ import com.fasterxml.jackson.core.type.TypeReference; import reactor.core.publisher.Mono; -public class ServerMcpExchange { +public class McpServerExchange { // map(roots) // map(resource_subscription) // initialization state - private final ServerMcpSession session; + private final McpServerSession session; private final McpSchema.ClientCapabilities clientCapabilities; private final McpSchema.Implementation clientInfo; - public ServerMcpExchange(ServerMcpSession session, McpSchema.ClientCapabilities clientCapabilities, + public McpServerExchange(McpServerSession session, McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo) { this.session = session; this.clientCapabilities = clientCapabilities; diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java similarity index 94% rename from mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpSession.java rename to mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index 07d48693..0edd20b6 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -14,9 +14,9 @@ import reactor.core.publisher.MonoSink; import reactor.core.publisher.Sinks; -public class ServerMcpSession implements McpSession { +public class McpServerSession implements McpSession { - private static final Logger logger = LoggerFactory.getLogger(ServerMcpSession.class); + private static final Logger logger = LoggerFactory.getLogger(McpServerSession.class); private final ConcurrentHashMap> pendingResponses = new ConcurrentHashMap<>(); @@ -34,7 +34,7 @@ public class ServerMcpSession implements McpSession { private final McpServerTransport transport; - private final Sinks.One exchangeSink = Sinks.one(); + private final Sinks.One exchangeSink = Sinks.one(); private final AtomicReference clientCapabilities = new AtomicReference<>(); private final AtomicReference clientInfo = new AtomicReference<>(); @@ -45,7 +45,7 @@ public class ServerMcpSession implements McpSession { private final AtomicInteger state = new AtomicInteger(UNINITIALIZED); - public ServerMcpSession(String id, McpServerTransport transport, + public McpServerSession(String id, McpServerTransport transport, InitRequestHandler initHandler, InitNotificationHandler initNotificationHandler, Map> requestHandlers, @@ -195,7 +195,7 @@ private Mono handleIncomingNotification(McpSchema.JSONRPCNotification noti return Mono.defer(() -> { if (McpSchema.METHOD_NOTIFICATION_INITIALIZED.equals(notification.method())) { this.state.lazySet(INITIALIZED); - exchangeSink.tryEmitValue(new ServerMcpExchange(this, clientCapabilities.get(), clientInfo.get())); + exchangeSink.tryEmitValue(new McpServerExchange(this, clientCapabilities.get(), clientInfo.get())); return this.initNotificationHandler.handle(); } @@ -242,16 +242,16 @@ public interface InitNotificationHandler { } public interface NotificationHandler { - Mono handle(ServerMcpExchange exchange, Object params); + Mono handle(McpServerExchange exchange, Object params); } public interface RequestHandler { - Mono handle(ServerMcpExchange exchange, Object params); + Mono handle(McpServerExchange exchange, Object params); } @FunctionalInterface public interface Factory { - ServerMcpSession create(McpServerTransport sessionTransport); + McpServerSession create(McpServerTransport sessionTransport); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java index f7208c4d..77ecc043 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java @@ -7,7 +7,7 @@ public interface McpServerTransportProvider { // TODO: Consider adding a ProviderFactory that gets the Session Factory - void setSessionFactory(ServerMcpSession.Factory sessionFactory); + void setSessionFactory(McpServerSession.Factory sessionFactory); Mono notifyClients(String method, Map params); From 15e432fb43bed6cbb7da243b0722a546d47d25f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Thu, 13 Mar 2025 09:46:59 +0100 Subject: [PATCH 05/20] Restore WebFluxSseServerTransport in order to deprecate it MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Dariusz Jędrzejczyk --- .../transport/WebFluxSseServerTransport.java | 413 ++++++++++++++++++ .../WebFluxSseServerTransportProvider.java | 68 ++- .../WebFluxSseIntegrationTests.java | 6 +- .../server/WebFluxSseMcpAsyncServerTests.java | 6 +- .../server/WebFluxSseMcpSyncServerTests.java | 8 +- .../server/AbstractMcpAsyncServerTests.java | 3 +- .../server/McpAsyncServer.java | 206 +++++---- .../server/McpServer.java | 16 +- .../spec/McpServerSession.java | 44 +- .../spec/McpServerTransportProvider.java | 1 + .../server/AbstractMcpAsyncServerTests.java | 3 +- 11 files changed, 603 insertions(+), 171 deletions(-) create mode 100644 mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransport.java diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransport.java new file mode 100644 index 00000000..fb0b581e --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransport.java @@ -0,0 +1,413 @@ +package io.modelcontextprotocol.server.transport; + +import java.io.IOException; +import java.time.Duration; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Function; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.util.Assert; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; + +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.codec.ServerSentEvent; +import org.springframework.web.reactive.function.server.RouterFunction; +import org.springframework.web.reactive.function.server.RouterFunctions; +import org.springframework.web.reactive.function.server.ServerRequest; +import org.springframework.web.reactive.function.server.ServerResponse; + +/** + * Server-side implementation of the MCP (Model Context Protocol) HTTP transport using + * Server-Sent Events (SSE). This implementation provides a bidirectional communication + * channel between MCP clients and servers using HTTP POST for client-to-server messages + * and SSE for server-to-client messages. + * + *

    + * Key features: + *

      + *
    • Implements the {@link ServerMcpTransport} interface for MCP server transport + * functionality
    • + *
    • Uses WebFlux for non-blocking request handling and SSE support
    • + *
    • Maintains client sessions for reliable message delivery
    • + *
    • Supports graceful shutdown with session cleanup
    • + *
    • Thread-safe message broadcasting to multiple clients
    • + *
    + * + *

    + * The transport sets up two main endpoints: + *

      + *
    • SSE endpoint (/sse) - For establishing SSE connections with clients
    • + *
    • Message endpoint (configurable) - For receiving JSON-RPC messages from clients
    • + *
    + * + *

    + * This implementation is thread-safe and can handle multiple concurrent client + * connections. It uses {@link ConcurrentHashMap} for session management and Reactor's + * {@link Sinks} for thread-safe message broadcasting. + * + * @author Christian Tzolov + * @author Alexandros Pappas + * @see ServerMcpTransport + * @see ServerSentEvent + * @deprecated This class will be removed in 0.9.0. Use + * {@link WebFluxSseServerTransportProvider}. + */ +@Deprecated +public class WebFluxSseServerTransport implements ServerMcpTransport { + + private static final Logger logger = LoggerFactory.getLogger(WebFluxSseServerTransport.class); + + /** + * Event type for JSON-RPC messages sent through the SSE connection. + */ + public static final String MESSAGE_EVENT_TYPE = "message"; + + /** + * Event type for sending the message endpoint URI to clients. + */ + public static final String ENDPOINT_EVENT_TYPE = "endpoint"; + + /** + * Default SSE endpoint path as specified by the MCP transport specification. + */ + public static final String DEFAULT_SSE_ENDPOINT = "/sse"; + + private final ObjectMapper objectMapper; + + private final String messageEndpoint; + + private final String sseEndpoint; + + private final RouterFunction routerFunction; + + /** + * Map of active client sessions, keyed by session ID. + */ + private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + + /** + * Flag indicating if the transport is shutting down. + */ + private volatile boolean isClosing = false; + + private Function, Mono> connectHandler; + + /** + * Constructs a new WebFlux SSE server transport instance. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of MCP messages. Must not be null. + * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC + * messages. This endpoint will be communicated to clients during SSE connection + * setup. Must not be null. + * @throws IllegalArgumentException if either parameter is null + */ + public WebFluxSseServerTransport(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + Assert.notNull(messageEndpoint, "Message endpoint must not be null"); + Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); + + this.objectMapper = objectMapper; + this.messageEndpoint = messageEndpoint; + this.sseEndpoint = sseEndpoint; + this.routerFunction = RouterFunctions.route() + .GET(this.sseEndpoint, this::handleSseConnection) + .POST(this.messageEndpoint, this::handleMessage) + .build(); + } + + /** + * Constructs a new WebFlux SSE server transport instance with the default SSE + * endpoint. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of MCP messages. Must not be null. + * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC + * messages. This endpoint will be communicated to clients during SSE connection + * setup. Must not be null. + * @throws IllegalArgumentException if either parameter is null + */ + public WebFluxSseServerTransport(ObjectMapper objectMapper, String messageEndpoint) { + this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); + } + + /** + * Configures the message handler for this transport. In the WebFlux SSE + * implementation, this method stores the handler for processing incoming messages but + * doesn't establish any connections since the server accepts connections rather than + * initiating them. + * @param handler A function that processes incoming JSON-RPC messages and returns + * responses. This handler will be called for each message received through the + * message endpoint. + * @return An empty Mono since the server doesn't initiate connections + */ + @Override + public Mono connect(Function, Mono> handler) { + this.connectHandler = handler; + // Server-side transport doesn't initiate connections + return Mono.empty().then(); + } + + /** + * Broadcasts a JSON-RPC message to all connected clients through their SSE + * connections. The message is serialized to JSON and sent as a server-sent event to + * each active session. + * + *

    + * The method: + *

      + *
    • Serializes the message to JSON
    • + *
    • Creates a server-sent event with the message data
    • + *
    • Attempts to send the event to all active sessions
    • + *
    • Tracks and reports any delivery failures
    • + *
    + * @param message The JSON-RPC message to broadcast + * @return A Mono that completes when the message has been sent to all sessions, or + * errors if any session fails to receive the message + */ + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + if (sessions.isEmpty()) { + logger.debug("No active sessions to broadcast message to"); + return Mono.empty(); + } + + return Mono.create(sink -> { + try {// @formatter:off + String jsonText = objectMapper.writeValueAsString(message); + ServerSentEvent event = ServerSentEvent.builder() + .event(MESSAGE_EVENT_TYPE) + .data(jsonText) + .build(); + + logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); + + List failedSessions = sessions.values().stream() + .filter(session -> session.messageSink.tryEmitNext(event).isFailure()) + .map(session -> session.id) + .toList(); + + if (failedSessions.isEmpty()) { + logger.debug("Successfully broadcast message to all sessions"); + sink.success(); + } + else { + String error = "Failed to broadcast message to sessions: " + String.join(", ", failedSessions); + logger.error(error); + sink.error(new RuntimeException(error)); + } // @formatter:on + } + catch (IOException e) { + logger.error("Failed to serialize message: {}", e.getMessage()); + sink.error(e); + } + }); + } + + /** + * Converts data from one type to another using the configured ObjectMapper. This + * method is primarily used for converting between different representations of + * JSON-RPC message data. + * @param The target type to convert to + * @param data The source data to convert + * @param typeRef Type reference describing the target type + * @return The converted data + * @throws IllegalArgumentException if the conversion fails + */ + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return this.objectMapper.convertValue(data, typeRef); + } + + /** + * Initiates a graceful shutdown of the transport. This method ensures all active + * sessions are properly closed and cleaned up. + * + *

    + * The shutdown process: + *

      + *
    • Marks the transport as closing to prevent new connections
    • + *
    • Closes each active session
    • + *
    • Removes closed sessions from the sessions map
    • + *
    • Times out after 5 seconds if shutdown takes too long
    • + *
    + * @return A Mono that completes when all sessions have been closed + */ + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + isClosing = true; + logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size()); + }).then(Mono.when(sessions.values().stream().map(session -> { + String sessionId = session.id; + return Mono.fromRunnable(() -> session.close()) + .then(Mono.delay(Duration.ofMillis(100))) + .then(Mono.fromRunnable(() -> sessions.remove(sessionId))); + }).toList())) + .timeout(Duration.ofSeconds(5)) + .doOnSuccess(v -> logger.debug("Graceful shutdown completed")) + .doOnError(e -> logger.error("Error during graceful shutdown: {}", e.getMessage())); + } + + /** + * Returns the WebFlux router function that defines the transport's HTTP endpoints. + * This router function should be integrated into the application's web configuration. + * + *

    + * The router function defines two endpoints: + *

      + *
    • GET {sseEndpoint} - For establishing SSE connections
    • + *
    • POST {messageEndpoint} - For receiving client messages
    • + *
    + * @return The configured {@link RouterFunction} for handling HTTP requests + */ + public RouterFunction getRouterFunction() { + return this.routerFunction; + } + + /** + * Handles new SSE connection requests from clients. Creates a new session for each + * connection and sets up the SSE event stream. + * + *

    + * The handler performs the following steps: + *

      + *
    • Generates a unique session ID
    • + *
    • Creates a new ClientSession instance
    • + *
    • Sends the message endpoint URI as an initial event
    • + *
    • Sets up message forwarding for the session
    • + *
    • Handles connection cleanup on completion or errors
    • + *
    + * @param request The incoming server request + * @return A response with the SSE event stream + */ + private Mono handleSseConnection(ServerRequest request) { + if (isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); + } + String sessionId = UUID.randomUUID().toString(); + logger.debug("Creating new SSE connection for session: {}", sessionId); + ClientSession session = new ClientSession(sessionId); + this.sessions.put(sessionId, session); + + return ServerResponse.ok() + .contentType(MediaType.TEXT_EVENT_STREAM) + .body(Flux.>create(sink -> { + // Send initial endpoint event + logger.debug("Sending initial endpoint event to session: {}", sessionId); + sink.next(ServerSentEvent.builder().event(ENDPOINT_EVENT_TYPE).data(messageEndpoint).build()); + + // Subscribe to session messages + session.messageSink.asFlux() + .doOnSubscribe(s -> logger.debug("Session {} subscribed to message sink", sessionId)) + .doOnComplete(() -> { + logger.debug("Session {} completed", sessionId); + sessions.remove(sessionId); + }) + .doOnError(error -> { + logger.error("Error in session {}: {}", sessionId, error.getMessage()); + sessions.remove(sessionId); + }) + .doOnCancel(() -> { + logger.debug("Session {} cancelled", sessionId); + sessions.remove(sessionId); + }) + .subscribe(event -> { + logger.debug("Forwarding event to session {}: {}", sessionId, event); + sink.next(event); + }, sink::error, sink::complete); + + sink.onCancel(() -> { + logger.debug("Session {} cancelled", sessionId); + sessions.remove(sessionId); + }); + }), ServerSentEvent.class); + } + + /** + * Handles incoming JSON-RPC messages from clients. Deserializes the message and + * processes it through the configured message handler. + * + *

    + * The handler: + *

      + *
    • Deserializes the incoming JSON-RPC message
    • + *
    • Passes it through the message handler chain
    • + *
    • Returns appropriate HTTP responses based on processing results
    • + *
    • Handles various error conditions with appropriate error responses
    • + *
    + * @param request The incoming server request containing the JSON-RPC message + * @return A response indicating the message processing result + */ + private Mono handleMessage(ServerRequest request) { + if (isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); + } + + return request.bodyToMono(String.class).flatMap(body -> { + try { + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); + return Mono.just(message) + .transform(this.connectHandler) + .flatMap(response -> ServerResponse.ok().build()) + .onErrorResume(error -> { + logger.error("Error processing message: {}", error.getMessage()); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR) + .bodyValue(new McpError(error.getMessage())); + }); + } + catch (IllegalArgumentException | IOException e) { + logger.error("Failed to deserialize message: {}", e.getMessage()); + return ServerResponse.badRequest().bodyValue(new McpError("Invalid message format")); + } + }); + } + + /** + * Represents an active client SSE connection session. Manages the message sink for + * sending events to the client and handles session lifecycle. + * + *

    + * Each session: + *

      + *
    • Has a unique identifier
    • + *
    • Maintains its own message sink for event broadcasting
    • + *
    • Supports clean shutdown through the close method
    • + *
    + */ + private static class ClientSession { + + private final String id; + + private final Sinks.Many> messageSink; + + ClientSession(String id) { + this.id = id; + logger.debug("Creating new session: {}", id); + this.messageSink = Sinks.many().replay().latest(); + logger.debug("Session {} initialized with replay sink", id); + } + + void close() { + logger.debug("Closing session: {}", id); + Sinks.EmitResult result = messageSink.tryEmitComplete(); + if (result.isFailure()) { + logger.warn("Failed to complete message sink for session {}: {}", id, result); + } + else { + logger.debug("Successfully completed message sink for session {}", id); + } + } + + } + +} \ No newline at end of file diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index e17135ac..9138d6f4 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -65,8 +65,7 @@ */ public class WebFluxSseServerTransportProvider implements McpServerTransportProvider { - private static final Logger logger = LoggerFactory.getLogger( - WebFluxSseServerTransportProvider.class); + private static final Logger logger = LoggerFactory.getLogger(WebFluxSseServerTransportProvider.class); /** * Event type for JSON-RPC messages sent through the SSE connection. @@ -174,8 +173,7 @@ public Mono notifyClients(String method, Map params) { return Flux.fromStream(sessions.values().stream()) .flatMap(session -> session.sendNotification(method, params) - .doOnError(e -> logger.error("Failed to " + "send message to session " + - "{}: {}", session.getId(), + .doOnError(e -> logger.error("Failed to " + "send message to session " + "{}: {}", session.getId(), e.getMessage())) .onErrorComplete()) .then(); @@ -307,39 +305,35 @@ private Mono handleMessage(ServerRequest request) { } /* - Current: - - framework layer: - var transport = new WebFluxSseServerTransport(objectMapper, "/mcp", "/sse"); - McpServer.async(ServerMcpTransport transport) - - client connects -> - WebFluxSseServerTransport creates a: - - var sessionTransport = WebFluxMcpSessionTransport - - ServerMcpSession(sessionId, sessionTransport) - - WebFluxSseServerTransport IS_A ServerMcpTransport IS_A McpTransport - WebFluxMcpSessionTransport IS_A ServerMcpSessionTransport IS_A McpTransport - - McpTransport contains connect() which should be removed - ClientMcpTransport should have connect() - ServerMcpTransport should have setSessionFactory() - - Possible Future: - var transportProvider = new WebFluxSseServerTransport(objectMapper, "/mcp", "/sse"); - WebFluxSseServerTransport IS_A ServerMcpTransportProvider ? - ServerMcpTransportProvider creates ServerMcpTransport - - // disadvantage - too much breaks, e.g. - McpServer.async(ServerMcpTransportProvider transportProvider) - - // advantage - - ClientMcpTransport and ServerMcpTransport BOTH represent 1:1 relationship - - - - + * Current: + * + * framework layer: var transport = new WebFluxSseServerTransport(objectMapper, + * "/mcp", "/sse"); McpServer.async(ServerMcpTransport transport) + * + * client connects -> WebFluxSseServerTransport creates a: - var sessionTransport = + * WebFluxMcpSessionTransport - ServerMcpSession(sessionId, sessionTransport) + * + * WebFluxSseServerTransport IS_A ServerMcpTransport IS_A McpTransport + * WebFluxMcpSessionTransport IS_A ServerMcpSessionTransport IS_A McpTransport + * + * McpTransport contains connect() which should be removed ClientMcpTransport should + * have connect() ServerMcpTransport should have setSessionFactory() + * + * Possible Future: var transportProvider = new + * WebFluxSseServerTransport(objectMapper, "/mcp", "/sse"); WebFluxSseServerTransport + * IS_A ServerMcpTransportProvider ? ServerMcpTransportProvider creates + * ServerMcpTransport + * + * // disadvantage - too much breaks, e.g. McpServer.async(ServerMcpTransportProvider + * transportProvider) + * + * // advantage + * + * ClientMcpTransport and ServerMcpTransport BOTH represent 1:1 relationship + * + * + * + * */ private class WebFluxMcpSessionTransport implements McpServerTransport { diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 3df80db8..4cd24c62 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -16,7 +16,7 @@ import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; import io.modelcontextprotocol.server.McpServer; import io.modelcontextprotocol.server.McpServerFeatures; -import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransport; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; @@ -55,14 +55,14 @@ public class WebFluxSseIntegrationTests { private DisposableServer httpServer; - private WebFluxSseServerTransportProvider mcpServerTransport; + private WebFluxSseServerTransport mcpServerTransport; ConcurrentHashMap clientBulders = new ConcurrentHashMap<>(); @BeforeEach public void before() { - this.mcpServerTransport = new WebFluxSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT); + this.mcpServerTransport = new WebFluxSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); HttpHandler httpHandler = RouterFunctions.toHttpHandler(mcpServerTransport.getRouterFunction()); ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java index 34f4b689..1ed0d99b 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java @@ -5,7 +5,7 @@ package io.modelcontextprotocol.server; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransport; import io.modelcontextprotocol.spec.ServerMcpTransport; import org.junit.jupiter.api.Timeout; import reactor.netty.DisposableServer; @@ -16,7 +16,7 @@ import org.springframework.web.reactive.function.server.RouterFunctions; /** - * Tests for {@link McpAsyncServer} using {@link WebFluxSseServerTransportProvider}. + * Tests for {@link McpAsyncServer} using {@link WebFluxSseServerTransport}. * * @author Christian Tzolov */ @@ -31,7 +31,7 @@ class WebFluxSseMcpAsyncServerTests extends AbstractMcpAsyncServerTests { @Override protected ServerMcpTransport createMcpTransport() { - var transport = new WebFluxSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT); + var transport = new WebFluxSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); HttpHandler httpHandler = RouterFunctions.toHttpHandler(transport.getRouterFunction()); ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java index 2cf1087d..4db00dd4 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java @@ -5,7 +5,7 @@ package io.modelcontextprotocol.server; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransport; import io.modelcontextprotocol.spec.ServerMcpTransport; import org.junit.jupiter.api.Timeout; import reactor.netty.DisposableServer; @@ -16,7 +16,7 @@ import org.springframework.web.reactive.function.server.RouterFunctions; /** - * Tests for {@link McpSyncServer} using {@link WebFluxSseServerTransportProvider}. + * Tests for {@link McpSyncServer} using {@link WebFluxSseServerTransport}. * * @author Christian Tzolov */ @@ -29,11 +29,11 @@ class WebFluxSseMcpSyncServerTests extends AbstractMcpSyncServerTests { private DisposableServer httpServer; - private WebFluxSseServerTransportProvider transport; + private WebFluxSseServerTransport transport; @Override protected ServerMcpTransport createMcpTransport() { - transport = new WebFluxSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT); + transport = new WebFluxSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); return transport; } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java index ca5783d0..725a2167 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java @@ -66,7 +66,8 @@ void tearDown() { @Test void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> McpServer.async(null)).isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy(() -> McpServer.async((ServerMcpTransport) null)) + .isInstanceOf(IllegalArgumentException.class) .hasMessage("Transport must not be null"); assertThatThrownBy(() -> McpServer.async(createMcpTransport()).serverInfo((McpSchema.Implementation) null)) diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index 91862821..5d2c8f69 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -96,11 +96,11 @@ public class McpAsyncServer { /** * Create a new McpAsyncServer with the given transport and capabilities. - * @param mcpTransportProvider The transport layer implementation for MCP communication. + * @param mcpTransportProvider The transport layer implementation for MCP + * communication. * @param features The MCP server supported features. */ - McpAsyncServer(McpServerTransportProvider mcpTransportProvider, - ObjectMapper objectMapper, + McpAsyncServer(McpServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper, McpServerFeatures.Async features) { this.delegate = new AsyncServerImpl(mcpTransportProvider, objectMapper, features); } @@ -319,6 +319,7 @@ void setProtocolVersions(List protocolVersions) { } private static class AsyncServerImpl extends McpAsyncServer { + private final McpServerTransportProvider mcpTransportProvider; private final ObjectMapper objectMapper; @@ -347,11 +348,11 @@ private static class AsyncServerImpl extends McpAsyncServer { /** * Create a new McpAsyncServer with the given transport and capabilities. - * @param mcpTransportProvider The transport layer implementation for MCP communication. + * @param mcpTransportProvider The transport layer implementation for MCP + * communication. * @param features The MCP server supported features. */ - AsyncServerImpl(McpServerTransportProvider mcpTransportProvider, - ObjectMapper objectMapper, + AsyncServerImpl(McpServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper, McpServerFeatures.Async features) { this.mcpTransportProvider = mcpTransportProvider; this.objectMapper = objectMapper; @@ -400,37 +401,36 @@ private static class AsyncServerImpl extends McpAsyncServer { List, Mono>> rootsChangeConsumers = features.rootsChangeConsumers(); if (Utils.isEmpty(rootsChangeConsumers)) { - rootsChangeConsumers = List.of((roots) -> Mono.fromRunnable(() -> logger - .warn("Roots list changed notification, but no consumers provided. Roots list changed: {}", roots))); + rootsChangeConsumers = List.of((roots) -> Mono.fromRunnable(() -> logger.warn( + "Roots list changed notification, but no consumers provided. Roots list changed: {}", roots))); } notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); - mcpTransportProvider.setSessionFactory(transport -> new McpServerSession( - UUID.randomUUID().toString(), - transport, - this::asyncInitializeRequestHandler, - Mono::empty, - requestHandlers, - notificationHandlers)); + mcpTransportProvider + .setSessionFactory(transport -> new McpServerSession(UUID.randomUUID().toString(), transport, + this::asyncInitializeRequestHandler, Mono::empty, requestHandlers, notificationHandlers)); } // --------------------------------------- // Lifecycle Management // --------------------------------------- - private Mono asyncInitializeRequestHandler(McpSchema.InitializeRequest initializeRequest) { + private Mono asyncInitializeRequestHandler( + McpSchema.InitializeRequest initializeRequest) { return Mono.defer(() -> { logger.info("Client initialize request - Protocol: {}, Capabilities: {}, Info: {}", initializeRequest.protocolVersion(), initializeRequest.capabilities(), initializeRequest.clientInfo()); - // The server MUST respond with the highest protocol version it supports if + // The server MUST respond with the highest protocol version it supports + // if // it does not support the requested (e.g. Client) version. String serverProtocolVersion = this.protocolVersions.get(this.protocolVersions.size() - 1); if (this.protocolVersions.contains(initializeRequest.protocolVersion())) { - // If the server supports the requested protocol version, it MUST respond + // If the server supports the requested protocol version, it MUST + // respond // with the same version. serverProtocolVersion = initializeRequest.protocolVersion(); } @@ -446,7 +446,8 @@ private Mono asyncInitializeRequestHandler(McpSchema } /** - * Get the server capabilities that define the supported features and functionality. + * Get the server capabilities that define the supported features and + * functionality. * @return The server capabilities */ public McpSchema.ServerCapabilities getServerCapabilities() { @@ -462,7 +463,8 @@ public McpSchema.Implementation getServerInfo() { } /** - * Get the client capabilities that define the supported features and functionality. + * Get the client capabilities that define the supported features and + * functionality. * @return The client capabilities */ @Deprecated @@ -517,12 +519,12 @@ private McpServerSession.NotificationHandler asyncRootsListChangedNotificationHa List, Mono>> rootsChangeConsumers) { return (exchange, params) -> listRoots().flatMap(listRootsResult -> Flux.fromIterable(rootsChangeConsumers) - .flatMap(consumer -> consumer.apply(listRootsResult.roots())) - .onErrorResume(error -> { - logger.error("Error handling roots list change notification", error); - return Mono.empty(); - }) - .then()); + .flatMap(consumer -> consumer.apply(listRootsResult.roots())) + .onErrorResume(error -> { + logger.error("Error handling roots list change notification", error); + return Mono.empty(); + }) + .then()); } // --------------------------------------- @@ -552,7 +554,7 @@ public Mono addTool(McpServerFeatures.AsyncToolRegistration toolRegistrati // Check for duplicate tool names if (this.tools.stream().anyMatch(th -> th.tool().name().equals(toolRegistration.tool().name()))) { return Mono - .error(new McpError("Tool with name '" + toolRegistration.tool().name() + "' already exists")); + .error(new McpError("Tool with name '" + toolRegistration.tool().name() + "' already exists")); } this.tools.add(toolRegistration); @@ -579,7 +581,8 @@ public Mono removeTool(String toolName) { } return Mono.defer(() -> { - boolean removed = this.tools.removeIf(toolRegistration -> toolRegistration.tool().name().equals(toolName)); + boolean removed = this.tools + .removeIf(toolRegistration -> toolRegistration.tool().name().equals(toolName)); if (removed) { logger.debug("Removed tool handler: {}", toolName); if (this.serverCapabilities.tools().listChanged()) { @@ -614,15 +617,15 @@ private McpServerSession.RequestHandler toolsCallRequestHandler( }); Optional toolRegistration = this.tools.stream() - .filter(tr -> callToolRequest.name().equals(tr.tool().name())) - .findAny(); + .filter(tr -> callToolRequest.name().equals(tr.tool().name())) + .findAny(); if (toolRegistration.isEmpty()) { return Mono.error(new McpError("Tool not found: " + callToolRequest.name())); } return toolRegistration.map(tool -> tool.call().apply(callToolRequest.arguments())) - .orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name()))); + .orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name()))); }; } @@ -646,8 +649,8 @@ public Mono addResource(McpServerFeatures.AsyncResourceRegistration resour return Mono.defer(() -> { if (this.resources.putIfAbsent(resourceHandler.resource().uri(), resourceHandler) != null) { - return Mono - .error(new McpError("Resource with URI '" + resourceHandler.resource().uri() + "' already exists")); + return Mono.error(new McpError( + "Resource with URI '" + resourceHandler.resource().uri() + "' already exists")); } logger.debug("Added resource handler: {}", resourceHandler.resource().uri()); if (this.serverCapabilities.resources().listChanged()) { @@ -688,23 +691,24 @@ public Mono removeResource(String resourceUri) { * @return A Mono that completes when all clients have been notified */ public Mono notifyResourcesListChanged() { - McpSchema.JSONRPCNotification jsonrpcNotification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, - McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED, null); + McpSchema.JSONRPCNotification jsonrpcNotification = new McpSchema.JSONRPCNotification( + McpSchema.JSONRPC_VERSION, McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED, null); return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED, null); } private McpServerSession.RequestHandler resourcesListRequestHandler() { return (exchange, params) -> { var resourceList = this.resources.values() - .stream() - .map(McpServerFeatures.AsyncResourceRegistration::resource) - .toList(); + .stream() + .map(McpServerFeatures.AsyncResourceRegistration::resource) + .toList(); return Mono.just(new McpSchema.ListResourcesResult(resourceList, null)); }; } private McpServerSession.RequestHandler resourceTemplateListRequestHandler() { - return (exchange, params) -> Mono.just(new McpSchema.ListResourceTemplatesResult(this.resourceTemplates, null)); + return (exchange, params) -> Mono + .just(new McpSchema.ListResourceTemplatesResult(this.resourceTemplates, null)); } @@ -741,10 +745,10 @@ public Mono addPrompt(McpServerFeatures.AsyncPromptRegistration promptRegi return Mono.defer(() -> { McpServerFeatures.AsyncPromptRegistration registration = this.prompts - .putIfAbsent(promptRegistration.prompt().name(), promptRegistration); + .putIfAbsent(promptRegistration.prompt().name(), promptRegistration); if (registration != null) { - return Mono.error( - new McpError("Prompt with name '" + promptRegistration.prompt().name() + "' already exists")); + return Mono.error(new McpError( + "Prompt with name '" + promptRegistration.prompt().name() + "' already exists")); } logger.debug("Added prompt handler: {}", promptRegistration.prompt().name()); @@ -804,9 +808,9 @@ private McpServerSession.RequestHandler promptsList // }); var promptList = this.prompts.values() - .stream() - .map(McpServerFeatures.AsyncPromptRegistration::prompt) - .toList(); + .stream() + .map(McpServerFeatures.AsyncPromptRegistration::prompt) + .toList(); return Mono.just(new McpSchema.ListPromptsResult(promptList, null)); }; @@ -833,8 +837,8 @@ private McpServerSession.RequestHandler promptsGetReq // --------------------------------------- /** - * Send a logging message notification to all connected clients. Messages below the - * current minimum logging level will be filtered out. + * Send a logging message notification to all connected clients. Messages below + * the current minimum logging level will be filtered out. * @param loggingMessageNotification The logging message to send * @return A Mono that completes when the notification has been sent */ @@ -856,8 +860,8 @@ public Mono loggingNotification(LoggingMessageNotification loggingMessageN } /** - * Handles requests to set the minimum logging level. Messages below this level will - * not be sent. + * Handles requests to set the minimum logging level. Messages below this level + * will not be sent. * @return A handler that processes logging level change requests */ private McpServerSession.RequestHandler setLoggerRequestHandler() { @@ -876,11 +880,11 @@ private McpServerSession.RequestHandler setLoggerRequestHandler() { /** * Create a new message using the sampling capabilities of the client. The Model * Context Protocol (MCP) provides a standardized way for servers to request LLM - * sampling (“completions” or “generations”) from language models via clients. This - * flow allows clients to maintain control over model access, selection, and - * permissions while enabling servers to leverage AI capabilities—with no server API - * keys necessary. Servers can request text or image-based interactions and optionally - * include context from MCP servers in their prompts. + * sampling (“completions” or “generations”) from language models via clients. + * This flow allows clients to maintain control over model access, selection, and + * permissions while enabling servers to leverage AI capabilities—with no server + * API keys necessary. Servers can request text or image-based interactions and + * optionally include context from MCP servers in their prompts. * @param createMessageRequest The request to create a new message * @return A Mono that completes when the message has been created * @throws McpError if the client has not been initialized or does not support @@ -898,19 +902,21 @@ public Mono createMessage(McpSchema.CreateMessage } /** - * This method is package-private and used for test only. Should not be called by user - * code. + * This method is package-private and used for test only. Should not be called by + * user code. * @param protocolVersions the Client supported protocol versions. */ void setProtocolVersions(List protocolVersions) { this.protocolVersions = protocolVersions; } + } private static final class LegacyAsyncServer extends McpAsyncServer { + /** - * The MCP session implementation that manages bidirectional JSON-RPC communication - * between clients and servers. + * The MCP session implementation that manages bidirectional JSON-RPC + * communication between clients and servers. */ private final DefaultMcpSession mcpSession; @@ -995,8 +1001,8 @@ private static final class LegacyAsyncServer extends McpAsyncServer { List, Mono>> rootsChangeConsumers = features.rootsChangeConsumers(); if (Utils.isEmpty(rootsChangeConsumers)) { - rootsChangeConsumers = List.of((roots) -> Mono.fromRunnable(() -> logger - .warn("Roots list changed notification, but no consumers provided. Roots list changed: {}", roots))); + rootsChangeConsumers = List.of((roots) -> Mono.fromRunnable(() -> logger.warn( + "Roots list changed notification, but no consumers provided. Roots list changed: {}", roots))); } notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, @@ -1021,12 +1027,14 @@ private DefaultMcpSession.RequestHandler asyncInitia initializeRequest.protocolVersion(), initializeRequest.capabilities(), initializeRequest.clientInfo()); - // The server MUST respond with the highest protocol version it supports if + // The server MUST respond with the highest protocol version it supports + // if // it does not support the requested (e.g. Client) version. String serverProtocolVersion = this.protocolVersions.get(this.protocolVersions.size() - 1); if (this.protocolVersions.contains(initializeRequest.protocolVersion())) { - // If the server supports the requested protocol version, it MUST respond + // If the server supports the requested protocol version, it MUST + // respond // with the same version. serverProtocolVersion = initializeRequest.protocolVersion(); } @@ -1042,7 +1050,8 @@ private DefaultMcpSession.RequestHandler asyncInitia } /** - * Get the server capabilities that define the supported features and functionality. + * Get the server capabilities that define the supported features and + * functionality. * @return The server capabilities */ public McpSchema.ServerCapabilities getServerCapabilities() { @@ -1058,7 +1067,8 @@ public McpSchema.Implementation getServerInfo() { } /** - * Get the client capabilities that define the supported features and functionality. + * Get the client capabilities that define the supported features and + * functionality. * @return The client capabilities */ public ClientCapabilities getClientCapabilities() { @@ -1112,12 +1122,12 @@ public Mono listRoots(String cursor) { private DefaultMcpSession.NotificationHandler asyncRootsListChangedNotificationHandler( List, Mono>> rootsChangeConsumers) { return params -> listRoots().flatMap(listRootsResult -> Flux.fromIterable(rootsChangeConsumers) - .flatMap(consumer -> consumer.apply(listRootsResult.roots())) - .onErrorResume(error -> { - logger.error("Error handling roots list change notification", error); - return Mono.empty(); - }) - .then()); + .flatMap(consumer -> consumer.apply(listRootsResult.roots())) + .onErrorResume(error -> { + logger.error("Error handling roots list change notification", error); + return Mono.empty(); + }) + .then()); } // --------------------------------------- @@ -1147,7 +1157,7 @@ public Mono addTool(McpServerFeatures.AsyncToolRegistration toolRegistrati // Check for duplicate tool names if (this.tools.stream().anyMatch(th -> th.tool().name().equals(toolRegistration.tool().name()))) { return Mono - .error(new McpError("Tool with name '" + toolRegistration.tool().name() + "' already exists")); + .error(new McpError("Tool with name '" + toolRegistration.tool().name() + "' already exists")); } this.tools.add(toolRegistration); @@ -1174,7 +1184,8 @@ public Mono removeTool(String toolName) { } return Mono.defer(() -> { - boolean removed = this.tools.removeIf(toolRegistration -> toolRegistration.tool().name().equals(toolName)); + boolean removed = this.tools + .removeIf(toolRegistration -> toolRegistration.tool().name().equals(toolName)); if (removed) { logger.debug("Removed tool handler: {}", toolName); if (this.serverCapabilities.tools().listChanged()) { @@ -1209,15 +1220,15 @@ private DefaultMcpSession.RequestHandler toolsCallRequestHandler }); Optional toolRegistration = this.tools.stream() - .filter(tr -> callToolRequest.name().equals(tr.tool().name())) - .findAny(); + .filter(tr -> callToolRequest.name().equals(tr.tool().name())) + .findAny(); if (toolRegistration.isEmpty()) { return Mono.error(new McpError("Tool not found: " + callToolRequest.name())); } return toolRegistration.map(tool -> tool.call().apply(callToolRequest.arguments())) - .orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name()))); + .orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name()))); }; } @@ -1241,8 +1252,8 @@ public Mono addResource(McpServerFeatures.AsyncResourceRegistration resour return Mono.defer(() -> { if (this.resources.putIfAbsent(resourceHandler.resource().uri(), resourceHandler) != null) { - return Mono - .error(new McpError("Resource with URI '" + resourceHandler.resource().uri() + "' already exists")); + return Mono.error(new McpError( + "Resource with URI '" + resourceHandler.resource().uri() + "' already exists")); } logger.debug("Added resource handler: {}", resourceHandler.resource().uri()); if (this.serverCapabilities.resources().listChanged()) { @@ -1289,9 +1300,9 @@ public Mono notifyResourcesListChanged() { private DefaultMcpSession.RequestHandler resourcesListRequestHandler() { return params -> { var resourceList = this.resources.values() - .stream() - .map(McpServerFeatures.AsyncResourceRegistration::resource) - .toList(); + .stream() + .map(McpServerFeatures.AsyncResourceRegistration::resource) + .toList(); return Mono.just(new McpSchema.ListResourcesResult(resourceList, null)); }; } @@ -1334,10 +1345,10 @@ public Mono addPrompt(McpServerFeatures.AsyncPromptRegistration promptRegi return Mono.defer(() -> { McpServerFeatures.AsyncPromptRegistration registration = this.prompts - .putIfAbsent(promptRegistration.prompt().name(), promptRegistration); + .putIfAbsent(promptRegistration.prompt().name(), promptRegistration); if (registration != null) { - return Mono.error( - new McpError("Prompt with name '" + promptRegistration.prompt().name() + "' already exists")); + return Mono.error(new McpError( + "Prompt with name '" + promptRegistration.prompt().name() + "' already exists")); } logger.debug("Added prompt handler: {}", promptRegistration.prompt().name()); @@ -1397,9 +1408,9 @@ private DefaultMcpSession.RequestHandler promptsLis // }); var promptList = this.prompts.values() - .stream() - .map(McpServerFeatures.AsyncPromptRegistration::prompt) - .toList(); + .stream() + .map(McpServerFeatures.AsyncPromptRegistration::prompt) + .toList(); return Mono.just(new McpSchema.ListPromptsResult(promptList, null)); }; @@ -1426,8 +1437,8 @@ private DefaultMcpSession.RequestHandler promptsGetRe // --------------------------------------- /** - * Send a logging message notification to all connected clients. Messages below the - * current minimum logging level will be filtered out. + * Send a logging message notification to all connected clients. Messages below + * the current minimum logging level will be filtered out. * @param loggingMessageNotification The logging message to send * @return A Mono that completes when the notification has been sent */ @@ -1449,8 +1460,8 @@ public Mono loggingNotification(LoggingMessageNotification loggingMessageN } /** - * Handles requests to set the minimum logging level. Messages below this level will - * not be sent. + * Handles requests to set the minimum logging level. Messages below this level + * will not be sent. * @return A handler that processes logging level change requests */ private DefaultMcpSession.RequestHandler setLoggerRequestHandler() { @@ -1471,11 +1482,11 @@ private DefaultMcpSession.RequestHandler setLoggerRequestHandler() { /** * Create a new message using the sampling capabilities of the client. The Model * Context Protocol (MCP) provides a standardized way for servers to request LLM - * sampling (“completions” or “generations”) from language models via clients. This - * flow allows clients to maintain control over model access, selection, and - * permissions while enabling servers to leverage AI capabilities—with no server API - * keys necessary. Servers can request text or image-based interactions and optionally - * include context from MCP servers in their prompts. + * sampling (“completions” or “generations”) from language models via clients. + * This flow allows clients to maintain control over model access, selection, and + * permissions while enabling servers to leverage AI capabilities—with no server + * API keys necessary. Servers can request text or image-based interactions and + * optionally include context from MCP servers in their prompts. * @param createMessageRequest The request to create a new message * @return A Mono that completes when the message has been created * @throws McpError if the client has not been initialized or does not support @@ -1500,8 +1511,8 @@ public Mono createMessage(McpSchema.CreateMessage } /** - * This method is package-private and used for test only. Should not be called by user - * code. + * This method is package-private and used for test only. Should not be called by + * user code. * @param protocolVersions the Client supported protocol versions. */ void setProtocolVersions(List protocolVersions) { @@ -1509,4 +1520,5 @@ void setProtocolVersions(List protocolVersions) { } } + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java index cff897dd..c2431665 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java @@ -152,6 +152,7 @@ class AsyncSpec { "1.0.0"); private final ServerMcpTransport transport; + private final McpServerTransportProvider transportProvider; private McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; @@ -522,13 +523,13 @@ public AsyncSpec rootsChangeConsumers( * settings */ public McpAsyncServer build() { - var features = new McpServerFeatures.Async(this.serverInfo, - this.serverCapabilities, this.tools, this.resources, - this.resourceTemplates, this.prompts, this.rootsChangeConsumers); + var features = new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, + this.resources, this.resourceTemplates, this.prompts, this.rootsChangeConsumers); if (this.transportProvider != null) { // FIXME: provide ObjectMapper configuration return new McpAsyncServer(this.transportProvider, new ObjectMapper(), features); - } else { + } + else { return new McpAsyncServer(this.transport, features); } } @@ -544,6 +545,7 @@ class SyncSpec { "1.0.0"); private final ServerMcpTransport transport; + private final McpServerTransportProvider transportProvider; private McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; @@ -917,9 +919,9 @@ public SyncSpec rootsChangeConsumers(Consumer>... consumers public McpSyncServer build() { McpServerFeatures.Sync syncFeatures = new McpServerFeatures.Sync(this.serverInfo, this.serverCapabilities, this.tools, this.resources, this.resourceTemplates, this.prompts, this.rootsChangeConsumers); - McpServerFeatures.Async asyncFeatures = - McpServerFeatures.Async.fromSync(syncFeatures); - var asyncServer = this.transportProvider != null ? new McpAsyncServer(this.transportProvider, new ObjectMapper(), asyncFeatures) + McpServerFeatures.Async asyncFeatures = McpServerFeatures.Async.fromSync(syncFeatures); + var asyncServer = this.transportProvider != null + ? new McpAsyncServer(this.transportProvider, new ObjectMapper(), asyncFeatures) : new McpAsyncServer(this.transport, asyncFeatures); return new McpSyncServer(asyncServer); diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index 0edd20b6..e95cf5bd 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -34,21 +34,22 @@ public class McpServerSession implements McpSession { private final McpServerTransport transport; - private final Sinks.One exchangeSink = Sinks.one(); + private final Sinks.One exchangeSink = Sinks.one(); + private final AtomicReference clientCapabilities = new AtomicReference<>(); + private final AtomicReference clientInfo = new AtomicReference<>(); - // 0 = uninitialized, 1 = initializing, 2 = initialized - private static final int UNINITIALIZED = 0; - private static final int INITIALIZING = 1; - private static final int INITIALIZED = 2; + private static final int STATE_UNINITIALIZED = 0; + + private static final int STATE_INITIALIZING = 1; + + private static final int STATE_INITIALIZED = 2; - private final AtomicInteger state = new AtomicInteger(UNINITIALIZED); + private final AtomicInteger state = new AtomicInteger(STATE_UNINITIALIZED); - public McpServerSession(String id, McpServerTransport transport, - InitRequestHandler initHandler, - InitNotificationHandler initNotificationHandler, - Map> requestHandlers, + public McpServerSession(String id, McpServerTransport transport, InitRequestHandler initHandler, + InitNotificationHandler initNotificationHandler, Map> requestHandlers, Map notificationHandlers) { this.id = id; this.transport = transport; @@ -155,18 +156,17 @@ private Mono handleIncomingRequest(McpSchema.JSONRPCR Mono resultMono; if (McpSchema.METHOD_INITIALIZE.equals(request.method())) { // TODO handle situation where already initialized! - McpSchema.InitializeRequest initializeRequest = - transport.unmarshalFrom(request.params(), + McpSchema.InitializeRequest initializeRequest = transport.unmarshalFrom(request.params(), new TypeReference() { }); - this.state.lazySet(INITIALIZING); + this.state.lazySet(STATE_INITIALIZING); this.init(initializeRequest.capabilities(), initializeRequest.clientInfo()); resultMono = this.initRequestHandler.handle(initializeRequest); } else { // TODO handle errors for communication to this session without - // initialization happening first + // initialization happening first var handler = this.requestHandlers.get(request.method()); if (handler == null) { MethodNotFoundError error = getMethodNotFoundError(request.method()); @@ -194,7 +194,7 @@ private Mono handleIncomingRequest(McpSchema.JSONRPCR private Mono handleIncomingNotification(McpSchema.JSONRPCNotification notification) { return Mono.defer(() -> { if (McpSchema.METHOD_NOTIFICATION_INITIALIZED.equals(notification.method())) { - this.state.lazySet(INITIALIZED); + this.state.lazySet(STATE_INITIALIZED); exchangeSink.tryEmitValue(new McpServerExchange(this, clientCapabilities.get(), clientInfo.get())); return this.initNotificationHandler.handle(); } @@ -204,9 +204,7 @@ private Mono handleIncomingNotification(McpSchema.JSONRPCNotification noti logger.error("No handler registered for notification method: {}", notification.method()); return Mono.empty(); } - return this.exchangeSink.asMono() - .flatMap(exchange -> - handler.handle(exchange, notification.params())); + return this.exchangeSink.asMono().flatMap(exchange -> handler.handle(exchange, notification.params())); }); } @@ -234,24 +232,34 @@ public void close() { } public interface InitRequestHandler { + Mono handle(McpSchema.InitializeRequest initializeRequest); + } public interface InitNotificationHandler { + Mono handle(); + } public interface NotificationHandler { + Mono handle(McpServerExchange exchange, Object params); + } public interface RequestHandler { + Mono handle(McpServerExchange exchange, Object params); + } @FunctionalInterface public interface Factory { + McpServerSession create(McpServerTransport sessionTransport); + } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java index 77ecc043..41b07fdb 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java @@ -29,4 +29,5 @@ default void close() { * @return a {@link Mono} that completes when the connection has been closed. */ Mono closeGracefully(); + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java index dcc103b5..e8b24c7c 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java @@ -67,7 +67,8 @@ void tearDown() { @Test void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> McpServer.async(null)).isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy(() -> McpServer.async((ServerMcpTransport) null)) + .isInstanceOf(IllegalArgumentException.class) .hasMessage("Transport must not be null"); assertThatThrownBy(() -> McpServer.async(createMcpTransport()).serverInfo((McpSchema.Implementation) null)) From 517ee3c5770c6e0ca2239c97ed880e9ba75cadf3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Thu, 13 Mar 2025 15:24:27 +0100 Subject: [PATCH 06/20] Support specifying new handlers in McpServer spec MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Dariusz Jędrzejczyk --- .../WebFluxSseServerTransportProvider.java | 3 + .../WebFluxSseIntegrationTests.java | 276 ++++++----- .../legacy/WebFluxSseIntegrationTests.java | 459 ++++++++++++++++++ .../server/AbstractMcpSyncServerTests.java | 2 +- .../server/McpAsyncServer.java | 36 +- .../McpAsyncServerExchange.java} | 13 +- .../server/McpServer.java | 113 ++++- .../server/McpServerFeatures.java | 16 +- .../server/McpSyncServerExchange.java | 29 ++ .../spec/McpServerSession.java | 9 +- .../server/AbstractMcpSyncServerTests.java | 2 +- .../server/BaseMcpAsyncServerTests.java | 5 + 12 files changed, 788 insertions(+), 175 deletions(-) create mode 100644 mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/legacy/WebFluxSseIntegrationTests.java rename mcp/src/main/java/io/modelcontextprotocol/{spec/McpServerExchange.java => server/McpAsyncServerExchange.java} (89%) create mode 100644 mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/BaseMcpAsyncServerTests.java diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index 9138d6f4..13f5da31 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -293,6 +293,9 @@ private Mono handleMessage(ServerRequest request) { McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); return session.handle(message).flatMap(response -> ServerResponse.ok().build()).onErrorResume(error -> { logger.error("Error processing message: {}", error.getMessage()); + // TODO: instead of signalling the error, just respond with 200 OK + // - the error is signalled on the SSE connection + // return ServerResponse.ok().build(); return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR) .bodyValue(new McpError(error.getMessage())); }); diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 4cd24c62..d8f56d04 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -17,6 +17,7 @@ import io.modelcontextprotocol.server.McpServer; import io.modelcontextprotocol.server.McpServerFeatures; import io.modelcontextprotocol.server.transport.WebFluxSseServerTransport; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; @@ -55,16 +56,16 @@ public class WebFluxSseIntegrationTests { private DisposableServer httpServer; - private WebFluxSseServerTransport mcpServerTransport; + private WebFluxSseServerTransportProvider mcpServerTransportProvider; ConcurrentHashMap clientBulders = new ConcurrentHashMap<>(); @BeforeEach public void before() { - this.mcpServerTransport = new WebFluxSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); + this.mcpServerTransportProvider = new WebFluxSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT); - HttpHandler httpHandler = RouterFunctions.toHttpHandler(mcpServerTransport.getRouterFunction()); + HttpHandler httpHandler = RouterFunctions.toHttpHandler(mcpServerTransportProvider.getRouterFunction()); ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); @@ -84,89 +85,109 @@ public void after() { // --------------------------------------- // Sampling Tests // --------------------------------------- - @Test - void testCreateMessageWithoutInitialization() { - var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); - - var messages = List - .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))); - var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); - - var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, - McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); - - StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized. Call the initialize method first!"); - }); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testCreateMessageWithoutSamplingCapabilities(String clientType) { - - var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); - - var clientBuilder = clientBulders.get(clientType); - - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build(); - - InitializeResult initResult = client.initialize(); - assertThat(initResult).isNotNull(); - - var messages = List - .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))); - var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); - - var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, - McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); - - StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Client must be configured with sampling capabilities"); - }); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testCreateMessageSuccess(String clientType) throws InterruptedException { - - var clientBuilder = clientBulders.get(clientType); - - var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); - - Function samplingHandler = request -> { - assertThat(request.messages()).hasSize(1); - assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); - - return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", - CreateMessageResult.StopReason.STOP_SEQUENCE); - }; - - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build(); - - InitializeResult initResult = client.initialize(); - assertThat(initResult).isNotNull(); - - var messages = List - .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))); - var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); - - var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, - McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); - - StepVerifier.create(mcpAsyncServer.createMessage(request)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }).verifyComplete(); - } + // TODO implement within a tool execution + // @Test + // void testCreateMessageWithoutInitialization() { + // var mcpAsyncServer = + // McpServer.async(mcpServerTransportProvider).serverInfo("test-server", + // "1.0.0").build(); + // + // var messages = List + // .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new + // McpSchema.TextContent("Test message"))); + // var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); + // + // var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, + // McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), + // Map.of()); + // + // StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error + // -> { + // assertThat(error).isInstanceOf(McpError.class) + // .hasMessage("Client must be initialized. Call the initialize method first!"); + // }); + // } + // + // @ParameterizedTest(name = "{0} : {displayName} ") + // @ValueSource(strings = { "httpclient", "webflux" }) + // void testCreateMessageWithoutSamplingCapabilities(String clientType) { + // + // var mcpAsyncServer = + // McpServer.async(mcpServerTransportProvider).serverInfo("test-server", + // "1.0.0").build(); + // + // var clientBuilder = clientBulders.get(clientType); + // + // var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", + // "0.0.0")).build(); + // + // InitializeResult initResult = client.initialize(); + // assertThat(initResult).isNotNull(); + // + // var messages = List + // .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new + // McpSchema.TextContent("Test message"))); + // var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); + // + // var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, + // McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), + // Map.of()); + // + // StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error + // -> { + // assertThat(error).isInstanceOf(McpError.class) + // .hasMessage("Client must be configured with sampling capabilities"); + // }); + // } + // + // @ParameterizedTest(name = "{0} : {displayName} ") + // @ValueSource(strings = { "httpclient", "webflux" }) + // void testCreateMessageSuccess(String clientType) throws InterruptedException { + // + // var clientBuilder = clientBulders.get(clientType); + // + // var mcpAsyncServer = + // McpServer.async(mcpServerTransportProvider).serverInfo("test-server", + // "1.0.0").build(); + // + // Function samplingHandler = request -> { + // assertThat(request.messages()).hasSize(1); + // assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + // + // return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test + // message"), "MockModelName", + // CreateMessageResult.StopReason.STOP_SEQUENCE); + // }; + // + // var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", + // "0.0.0")) + // .capabilities(ClientCapabilities.builder().sampling().build()) + // .sampling(samplingHandler) + // .build(); + // + // InitializeResult initResult = client.initialize(); + // assertThat(initResult).isNotNull(); + // + // var messages = List + // .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new + // McpSchema.TextContent("Test message"))); + // var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); + // + // var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, + // McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), + // Map.of()); + // + // StepVerifier.create(mcpAsyncServer.createMessage(request)).consumeNextWith(result + // -> { + // assertThat(result).isNotNull(); + // assertThat(result.role()).isEqualTo(Role.USER); + // assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + // assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test + // message"); + // assertThat(result.model()).isEqualTo("MockModelName"); + // assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + // }).verifyComplete(); + // } // --------------------------------------- // Roots Tests @@ -179,8 +200,8 @@ void testRootsSuccess(String clientType) { List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); AtomicReference> rootsRef = new AtomicReference<>(); - var mcpServer = McpServer.sync(mcpServerTransport) - .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate)) + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) .build(); var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) @@ -192,8 +213,6 @@ void testRootsSuccess(String clientType) { assertThat(rootsRef.get()).isNull(); - assertThat(mcpServer.listRoots().roots()).containsAll(roots); - mcpClient.rootsListChangedNotification(); await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { @@ -219,30 +238,39 @@ void testRootsSuccess(String clientType) { mcpServer.close(); } - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testRootsWithoutCapability(String clientType) { - var clientBuilder = clientBulders.get(clientType); - - var mcpServer = McpServer.sync(mcpServerTransport).rootsChangeConsumer(rootsUpdate -> { - }).build(); - - // Create client without roots capability - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()) // No - // roots - // capability - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Attempt to list roots should fail - assertThatThrownBy(() -> mcpServer.listRoots().roots()).isInstanceOf(McpError.class) - .hasMessage("Roots not supported"); - - mcpClient.close(); - mcpServer.close(); - } + // @ParameterizedTest(name = "{0} : {displayName} ") + // @ValueSource(strings = { "httpclient", "webflux" }) + // void testRootsWithoutCapability(String clientType) { + // var clientBuilder = clientBulders.get(clientType); + // AtomicReference errorRef = new AtomicReference<>(); + // + // var mcpServer = + // McpServer.sync(mcpServerTransportProvider) + // // TODO: implement tool handling and try to list roots + // .tool(tool, (exchange, args) -> { + // try { + // exchange.listRoots(); + // } catch (Exception e) { + // errorRef.set(e); + // } + // }).build(); + // + // // Create client without roots capability + // var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()) // + // No + // // roots + // // capability + // .build(); + // + // InitializeResult initResult = mcpClient.initialize(); + // assertThat(initResult).isNotNull(); + // + // assertThat(errorRef.get()).isInstanceOf(McpError.class).hasMessage("Roots not + // supported"); + // + // mcpClient.close(); + // mcpServer.close(); + // } @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "httpclient", "webflux" }) @@ -250,8 +278,8 @@ void testRootsWithEmptyRootsList(String clientType) { var clientBuilder = clientBulders.get(clientType); AtomicReference> rootsRef = new AtomicReference<>(); - var mcpServer = McpServer.sync(mcpServerTransport) - .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate)) + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) .build(); var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) @@ -273,7 +301,7 @@ void testRootsWithEmptyRootsList(String clientType) { @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "httpclient", "webflux" }) - void testRootsWithMultipleConsumers(String clientType) { + void testRootsWithMultipleHandlers(String clientType) { var clientBuilder = clientBulders.get(clientType); List roots = List.of(new Root("uri1://", "root1")); @@ -281,9 +309,9 @@ void testRootsWithMultipleConsumers(String clientType) { AtomicReference> rootsRef1 = new AtomicReference<>(); AtomicReference> rootsRef2 = new AtomicReference<>(); - var mcpServer = McpServer.sync(mcpServerTransport) - .rootsChangeConsumer(rootsUpdate -> rootsRef1.set(rootsUpdate)) - .rootsChangeConsumer(rootsUpdate -> rootsRef2.set(rootsUpdate)) + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef1.set(rootsUpdate)) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef2.set(rootsUpdate)) .build(); var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) @@ -313,8 +341,8 @@ void testRootsServerCloseWithActiveSubscription(String clientType) { List roots = List.of(new Root("uri1://", "root1")); AtomicReference> rootsRef = new AtomicReference<>(); - var mcpServer = McpServer.sync(mcpServerTransport) - .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate)) + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) .build(); var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) @@ -368,7 +396,7 @@ void testToolCallSuccess(String clientType) { return callResponse; }); - var mcpServer = McpServer.sync(mcpServerTransport) + var mcpServer = McpServer.sync(mcpServerTransportProvider) .capabilities(ServerCapabilities.builder().tools(true).build()) .tools(tool1) .build(); @@ -408,7 +436,7 @@ void testToolListChangeHandlingSuccess(String clientType) { return callResponse; }); - var mcpServer = McpServer.sync(mcpServerTransport) + var mcpServer = McpServer.sync(mcpServerTransportProvider) .capabilities(ServerCapabilities.builder().tools(true).build()) .tools(tool1) .build(); diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/legacy/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/legacy/WebFluxSseIntegrationTests.java new file mode 100644 index 00000000..981e114c --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/legacy/WebFluxSseIntegrationTests.java @@ -0,0 +1,459 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ +package io.modelcontextprotocol.server.legacy; + +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransport; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.Role; +import io.modelcontextprotocol.spec.McpSchema.Root; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; +import reactor.test.StepVerifier; + +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.client.RestClient; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.server.RouterFunctions; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.awaitility.Awaitility.await; + +public class WebFluxSseIntegrationTests { + + private static final int PORT = 8182; + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private DisposableServer httpServer; + + private WebFluxSseServerTransport mcpServerTransport; + + ConcurrentHashMap clientBulders = new ConcurrentHashMap<>(); + + @BeforeEach + public void before() { + + this.mcpServerTransport = new WebFluxSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); + + HttpHandler httpHandler = RouterFunctions.toHttpHandler(mcpServerTransport.getRouterFunction()); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + + clientBulders.put("httpclient", McpClient.sync(new HttpClientSseClientTransport("http://localhost:" + PORT))); + clientBulders.put("webflux", + McpClient.sync(new WebFluxSseClientTransport(WebClient.builder().baseUrl("http://localhost:" + PORT)))); + + } + + @AfterEach + public void after() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + + // --------------------------------------- + // Sampling Tests + // --------------------------------------- + @Test + void testCreateMessageWithoutInitialization() { + var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); + + var messages = List.of(new McpSchema.SamplingMessage(Role.USER, new McpSchema.TextContent("Test message"))); + var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); + + var request = new CreateMessageRequest(messages, modelPrefs, null, + CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); + + StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized. Call the initialize method first!"); + }); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateMessageWithoutSamplingCapabilities(String clientType) { + + var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); + + var clientBuilder = clientBulders.get(clientType); + + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build(); + + InitializeResult initResult = client.initialize(); + assertThat(initResult).isNotNull(); + + var messages = List.of(new McpSchema.SamplingMessage(Role.USER, new McpSchema.TextContent("Test message"))); + var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); + + var request = new CreateMessageRequest(messages, modelPrefs, null, + CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); + + StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with sampling capabilities"); + }); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateMessageSuccess(String clientType) throws InterruptedException { + + var clientBuilder = clientBulders.get(clientType); + + var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build(); + + InitializeResult initResult = client.initialize(); + assertThat(initResult).isNotNull(); + + var messages = List.of(new McpSchema.SamplingMessage(Role.USER, new McpSchema.TextContent("Test message"))); + var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); + + var request = new CreateMessageRequest(messages, modelPrefs, null, + CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); + + StepVerifier.create(mcpAsyncServer.createMessage(request)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }).verifyComplete(); + } + + // --------------------------------------- + // Roots Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testRootsSuccess(String clientType) { + var clientBuilder = clientBulders.get(clientType); + + List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); + + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransport) + .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate)) + .build(); + + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(rootsRef.get()).isNull(); + + assertThat(mcpServer.listRoots().roots()).containsAll(roots); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + + // Remove a root + mcpClient.removeRoot(roots.get(0).uri()); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); + }); + + // Add a new root + var root3 = new Root("uri3://", "root3"); + mcpClient.addRoot(root3); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); + }); + + mcpClient.close(); + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testRootsWithoutCapability(String clientType) { + var clientBuilder = clientBulders.get(clientType); + + var mcpServer = McpServer.sync(mcpServerTransport).rootsChangeConsumer(rootsUpdate -> { + }).build(); + + // Create client without roots capability + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()) // No + // roots + // capability + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Attempt to list roots should fail + assertThatThrownBy(() -> mcpServer.listRoots().roots()).isInstanceOf(McpError.class) + .hasMessage("Roots not supported"); + + mcpClient.close(); + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testRootsWithEmptyRootsList(String clientType) { + var clientBuilder = clientBulders.get(clientType); + + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransport) + .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate)) + .build(); + + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(List.of()) // Empty roots list + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); + + mcpClient.close(); + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testRootsWithMultipleConsumers(String clientType) { + var clientBuilder = clientBulders.get(clientType); + + List roots = List.of(new Root("uri1://", "root1")); + + AtomicReference> rootsRef1 = new AtomicReference<>(); + AtomicReference> rootsRef2 = new AtomicReference<>(); + + var mcpServer = McpServer.sync(mcpServerTransport) + .rootsChangeConsumer(rootsUpdate -> rootsRef1.set(rootsUpdate)) + .rootsChangeConsumer(rootsUpdate -> rootsRef2.set(rootsUpdate)) + .build(); + + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef1.get()).containsAll(roots); + assertThat(rootsRef2.get()).containsAll(roots); + }); + + mcpClient.close(); + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testRootsServerCloseWithActiveSubscription(String clientType) { + + var clientBuilder = clientBulders.get(clientType); + + List roots = List.of(new Root("uri1://", "root1")); + + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransport) + .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate)) + .build(); + + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + + // Close server while subscription is active + mcpServer.close(); + + // Verify client can handle server closure gracefully + mcpClient.close(); + } + + // --------------------------------------- + // Tools Tests + // --------------------------------------- + + String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testToolCallSuccess(String clientType) { + + var clientBuilder = clientBulders.get(clientType); + + var callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + McpServerFeatures.SyncToolRegistration tool1 = new McpServerFeatures.SyncToolRegistration( + new Tool("tool1", "tool1 description", emptyJsonSchema), request -> { + // perform a blocking call to a remote service + String response = RestClient.create() + .get() + .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .retrieve() + .body(String.class); + assertThat(response).isNotBlank(); + return callResponse; + }); + + var mcpServer = McpServer.sync(mcpServerTransport) + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + var mcpClient = clientBuilder.build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.close(); + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testToolListChangeHandlingSuccess(String clientType) { + + var clientBuilder = clientBulders.get(clientType); + + var callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + McpServerFeatures.SyncToolRegistration tool1 = new McpServerFeatures.SyncToolRegistration( + new Tool("tool1", "tool1 description", emptyJsonSchema), request -> { + // perform a blocking call to a remote service + String response = RestClient.create() + .get() + .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .retrieve() + .body(String.class); + assertThat(response).isNotBlank(); + return callResponse; + }); + + var mcpServer = McpServer.sync(mcpServerTransport) + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { + // perform a blocking call to a remote service + String response = RestClient.create() + .get() + .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .retrieve() + .body(String.class); + assertThat(response).isNotBlank(); + rootsRef.set(toolsUpdate); + }).build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(rootsRef.get()).isNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + mcpServer.notifyToolsListChanged(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(tool1.tool())); + }); + + // Remove a tool + mcpServer.removeTool("tool1"); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); + + // Add a new tool + McpServerFeatures.SyncToolRegistration tool2 = new McpServerFeatures.SyncToolRegistration( + new Tool("tool2", "tool2 description", emptyJsonSchema), request -> callResponse); + + mcpServer.addTool(tool2); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(tool2.tool())); + }); + + mcpClient.close(); + mcpServer.close(); + } + +} diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java index f8b95750..af147f9d 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java @@ -64,7 +64,7 @@ void tearDown() { @Test void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> McpServer.sync(null)).isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy(() -> McpServer.sync((ServerMcpTransport) null)).isInstanceOf(IllegalArgumentException.class) .hasMessage("Transport must not be null"); assertThatThrownBy(() -> McpServer.sync(createMcpTransport()).serverInfo(null)) diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index 5d2c8f69..c51d7b9c 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -12,7 +12,9 @@ import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CopyOnWriteArrayList; +import java.util.function.BiFunction; import java.util.function.Function; +import java.util.stream.Collectors; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; @@ -398,11 +400,14 @@ private static class AsyncServerImpl extends McpAsyncServer { notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (exchange, params) -> Mono.empty()); - List, Mono>> rootsChangeConsumers = features.rootsChangeConsumers(); + List, Mono>> rootsChangeConsumers = features + .rootsChangeConsumers(); if (Utils.isEmpty(rootsChangeConsumers)) { - rootsChangeConsumers = List.of((roots) -> Mono.fromRunnable(() -> logger.warn( - "Roots list changed notification, but no consumers provided. Roots list changed: {}", roots))); + rootsChangeConsumers = List.of((exchange, + roots) -> Mono.fromRunnable(() -> logger.warn( + "Roots list changed notification, but no consumers provided. Roots list changed: {}", + roots))); } notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, @@ -516,15 +521,15 @@ public Mono listRoots(String cursor) { } private McpServerSession.NotificationHandler asyncRootsListChangedNotificationHandler( - List, Mono>> rootsChangeConsumers) { - return (exchange, - params) -> listRoots().flatMap(listRootsResult -> Flux.fromIterable(rootsChangeConsumers) - .flatMap(consumer -> consumer.apply(listRootsResult.roots())) - .onErrorResume(error -> { - logger.error("Error handling roots list change notification", error); - return Mono.empty(); - }) - .then()); + List, Mono>> rootsChangeConsumers) { + return (exchange, params) -> exchange.listRoots() + .flatMap(listRootsResult -> Flux.fromIterable(rootsChangeConsumers) + .flatMap(consumer -> consumer.apply(exchange, listRootsResult.roots())) + .onErrorResume(error -> { + logger.error("Error handling roots list change notification", error); + return Mono.empty(); + }) + .then()); } // --------------------------------------- @@ -998,7 +1003,12 @@ private static final class LegacyAsyncServer extends McpAsyncServer { notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (params) -> Mono.empty()); - List, Mono>> rootsChangeConsumers = features.rootsChangeConsumers(); + List, Mono>> rootsChangeHandlers = features + .rootsChangeConsumers(); + + List, Mono>> rootsChangeConsumers = rootsChangeHandlers.stream() + .map(handler -> (Function, Mono>) (roots) -> handler.apply(null, roots)) + .toList(); if (Utils.isEmpty(rootsChangeConsumers)) { rootsChangeConsumers = List.of((roots) -> Mono.fromRunnable(() -> logger.warn( diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java similarity index 89% rename from mcp/src/main/java/io/modelcontextprotocol/spec/McpServerExchange.java rename to mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java index a8f54a2d..8959c293 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerExchange.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java @@ -1,13 +1,12 @@ -package io.modelcontextprotocol.spec; +package io.modelcontextprotocol.server; import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerSession; import reactor.core.publisher.Mono; -public class McpServerExchange { - - // map(roots) - // map(resource_subscription) - // initialization state +public class McpAsyncServerExchange { private final McpServerSession session; @@ -15,7 +14,7 @@ public class McpServerExchange { private final McpSchema.Implementation clientInfo; - public McpServerExchange(McpServerSession session, McpSchema.ClientCapabilities clientCapabilities, + public McpAsyncServerExchange(McpServerSession session, McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo) { this.session = session; this.clientCapabilities = clientCapabilities; diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java index c2431665..840631a0 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java @@ -5,11 +5,15 @@ package io.modelcontextprotocol.server; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.BiConsumer; +import java.util.function.BiFunction; import java.util.function.Consumer; import java.util.function.Function; +import java.util.stream.Collectors; import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.spec.McpSchema; @@ -121,18 +125,27 @@ public interface McpServer { * concurrent operations. * @param transport The transport layer implementation for MCP communication * @return A new instance of {@link SyncSpec} for configuring the server. + * @deprecated This method will be removed in 0.9.0. Use + * {@link #sync(McpServerTransportProvider)} instead. */ + @Deprecated static SyncSpec sync(ServerMcpTransport transport) { return new SyncSpec(transport); } + static SyncSpec sync(McpServerTransportProvider transportProvider) { + return new SyncSpec(transportProvider); + } + /** * Starts building an asynchronous MCP server that provides blocking operations. * Asynchronous servers can handle multiple requests concurrently using a functional * paradigm with non-blocking server transports, making them more efficient for * high-concurrency scenarios but more complex to implement. * @param transport The transport layer implementation for MCP communication - * @return A new instance of {@link SyncSpec} for configuring the server. + * @return A new instance of {@link AsyncSpec} for configuring the server. + * @deprecated This method will be removed in 0.9.0. Use + * {@link #async(McpServerTransportProvider)} instead. */ @Deprecated static AsyncSpec async(ServerMcpTransport transport) { @@ -155,6 +168,8 @@ class AsyncSpec { private final McpServerTransportProvider transportProvider; + private ObjectMapper objectMapper; + private McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; private McpSchema.ServerCapabilities serverCapabilities; @@ -188,7 +203,7 @@ class AsyncSpec { */ private final Map prompts = new HashMap<>(); - private final List, Mono>> rootsChangeConsumers = new ArrayList<>(); + private final List, Mono>> rootsChangeHandlers = new ArrayList<>(); private AsyncSpec(McpServerTransportProvider transportProvider) { Assert.notNull(transportProvider, "Transport provider must not be null"); @@ -480,10 +495,19 @@ public AsyncSpec prompts(McpServerFeatures.AsyncPromptRegistration... prompts) { * @param consumer The consumer to register. Must not be null. * @return This builder instance for method chaining * @throws IllegalArgumentException if consumer is null + * @deprecated This method will be removed in 0.9.0. Use + * {@link #rootsChangeHandler(BiFunction)} instead. */ + @Deprecated public AsyncSpec rootsChangeConsumer(Function, Mono> consumer) { Assert.notNull(consumer, "Consumer must not be null"); - this.rootsChangeConsumers.add(consumer); + return this.rootsChangeHandler((exchange, roots) -> consumer.apply(roots)); + } + + public AsyncSpec rootsChangeHandler( + BiFunction, Mono> handler) { + Assert.notNull(handler, "Consumer must not be null"); + this.rootsChangeHandlers.add(handler); return this; } @@ -494,10 +518,22 @@ public AsyncSpec rootsChangeConsumer(Function, Mono> * @param consumers The list of consumers to register. Must not be null. * @return This builder instance for method chaining * @throws IllegalArgumentException if consumers is null + * @deprecated This method will be removed in 0.9.0. Use + * {@link #rootsChangeHandlers(List)} instead. */ + @Deprecated public AsyncSpec rootsChangeConsumers(List, Mono>> consumers) { Assert.notNull(consumers, "Consumers list must not be null"); - this.rootsChangeConsumers.addAll(consumers); + return this.rootsChangeHandlers(consumers.stream() + .map(consumer -> (BiFunction, Mono>) ( + McpAsyncServerExchange exchange, List roots) -> consumer.apply(roots)) + .collect(Collectors.toList())); + } + + public AsyncSpec rootsChangeHandlers( + List, Mono>> handlers) { + Assert.notNull(handlers, "Handlers list must not be null"); + this.rootsChangeHandlers.addAll(handlers); return this; } @@ -508,12 +544,22 @@ public AsyncSpec rootsChangeConsumers(List, Mono, Mono>... consumers) { - for (Function, Mono> consumer : consumers) { - this.rootsChangeConsumers.add(consumer); - } + return this.rootsChangeConsumers(Arrays.asList(consumers)); + } + + public AsyncSpec rootsChangeHandlers( + @SuppressWarnings("unchecked") BiFunction, Mono>... handlers) { + return this.rootsChangeHandlers(Arrays.asList(handlers)); + } + + public AsyncSpec objectMapper(ObjectMapper objectMapper) { + this.objectMapper = objectMapper; return this; } @@ -524,10 +570,10 @@ public AsyncSpec rootsChangeConsumers( */ public McpAsyncServer build() { var features = new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, - this.resources, this.resourceTemplates, this.prompts, this.rootsChangeConsumers); + this.resources, this.resourceTemplates, this.prompts, this.rootsChangeHandlers); if (this.transportProvider != null) { - // FIXME: provide ObjectMapper configuration - return new McpAsyncServer(this.transportProvider, new ObjectMapper(), features); + var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); + return new McpAsyncServer(this.transportProvider, mapper, features); } else { return new McpAsyncServer(this.transport, features); @@ -548,6 +594,8 @@ class SyncSpec { private final McpServerTransportProvider transportProvider; + private ObjectMapper objectMapper; + private McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; private McpSchema.ServerCapabilities serverCapabilities; @@ -581,7 +629,7 @@ class SyncSpec { */ private final Map prompts = new HashMap<>(); - private final List>> rootsChangeConsumers = new ArrayList<>(); + private final List>> rootsChangeHandlers = new ArrayList<>(); private SyncSpec(McpServerTransportProvider transportProvider) { Assert.notNull(transportProvider, "Transport provider must not be null"); @@ -875,10 +923,18 @@ public SyncSpec prompts(McpServerFeatures.SyncPromptRegistration... prompts) { * @param consumer The consumer to register. Must not be null. * @return This builder instance for method chaining * @throws IllegalArgumentException if consumer is null + * @deprecated This method will be removed in 0.9.0. Use + * {@link #rootsChangeHandler(BiConsumer)}. */ + @Deprecated public SyncSpec rootsChangeConsumer(Consumer> consumer) { Assert.notNull(consumer, "Consumer must not be null"); - this.rootsChangeConsumers.add(consumer); + return this.rootsChangeHandler((exchange, roots) -> consumer.accept(roots)); + } + + public SyncSpec rootsChangeHandler(BiConsumer> handler) { + Assert.notNull(handler, "Consumer must not be null"); + this.rootsChangeHandlers.add(handler); return this; } @@ -889,10 +945,21 @@ public SyncSpec rootsChangeConsumer(Consumer> consumer) { * @param consumers The list of consumers to register. Must not be null. * @return This builder instance for method chaining * @throws IllegalArgumentException if consumers is null + * @deprecated This method will be removed in 0.9.0. Use + * {@link #rootsChangeHandlers(List)}. */ + @Deprecated public SyncSpec rootsChangeConsumers(List>> consumers) { Assert.notNull(consumers, "Consumers list must not be null"); - this.rootsChangeConsumers.addAll(consumers); + return this.rootsChangeHandlers(consumers.stream() + .map(consumer -> (BiConsumer>) (exchange, roots) -> consumer + .accept(roots)) + .collect(Collectors.toList())); + } + + public SyncSpec rootsChangeHandlers(List>> handlers) { + Assert.notNull(handlers, "Handlers list must not be null"); + this.rootsChangeHandlers.addAll(handlers); return this; } @@ -903,11 +970,20 @@ public SyncSpec rootsChangeConsumers(List>> consum * @param consumers The consumers to register. Must not be null. * @return This builder instance for method chaining * @throws IllegalArgumentException if consumers is null + * @deprecated This method will * be removed in 0.9.0. Use + * {@link #rootsChangeHandlers(BiConsumer[])}. */ + @Deprecated public SyncSpec rootsChangeConsumers(Consumer>... consumers) { - for (Consumer> consumer : consumers) { - this.rootsChangeConsumers.add(consumer); - } + return this.rootsChangeConsumers(Arrays.asList(consumers)); + } + + public SyncSpec rootsChangeHandlers(BiConsumer>... handlers) { + return this.rootsChangeHandlers(List.of(handlers)); + } + + public SyncSpec objectMapper(ObjectMapper objectMapper) { + this.objectMapper = objectMapper; return this; } @@ -918,10 +994,11 @@ public SyncSpec rootsChangeConsumers(Consumer>... consumers */ public McpSyncServer build() { McpServerFeatures.Sync syncFeatures = new McpServerFeatures.Sync(this.serverInfo, this.serverCapabilities, - this.tools, this.resources, this.resourceTemplates, this.prompts, this.rootsChangeConsumers); + this.tools, this.resources, this.resourceTemplates, this.prompts, this.rootsChangeHandlers); McpServerFeatures.Async asyncFeatures = McpServerFeatures.Async.fromSync(syncFeatures); + var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); var asyncServer = this.transportProvider != null - ? new McpAsyncServer(this.transportProvider, new ObjectMapper(), asyncFeatures) + ? new McpAsyncServer(this.transportProvider, mapper, asyncFeatures) : new McpAsyncServer(this.transport, asyncFeatures); return new McpSyncServer(asyncServer); diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java index c8f8399a..7e4e140f 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java @@ -8,7 +8,8 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.function.Consumer; +import java.util.function.BiConsumer; +import java.util.function.BiFunction; import java.util.function.Function; import io.modelcontextprotocol.spec.McpSchema; @@ -40,7 +41,7 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s List tools, Map resources, List resourceTemplates, Map prompts, - List, Mono>> rootsChangeConsumers) { + List, Mono>> rootsChangeConsumers) { /** * Create an instance and validate the arguments. @@ -57,7 +58,7 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s List tools, Map resources, List resourceTemplates, Map prompts, - List, Mono>> rootsChangeConsumers) { + List, Mono>> rootsChangeConsumers) { Assert.notNull(serverInfo, "Server info must not be null"); @@ -104,10 +105,11 @@ static Async fromSync(Sync syncSpec) { prompts.put(key, AsyncPromptRegistration.fromSync(prompt)); }); - List, Mono>> rootChangeConsumers = new ArrayList<>(); + List, Mono>> rootChangeConsumers = new ArrayList<>(); for (var rootChangeConsumer : syncSpec.rootsChangeConsumers()) { - rootChangeConsumers.add(list -> Mono.fromRunnable(() -> rootChangeConsumer.accept(list)) + rootChangeConsumers.add((exchange, list) -> Mono + .fromRunnable(() -> rootChangeConsumer.accept(new McpSyncServerExchange(exchange), list)) .subscribeOn(Schedulers.boundedElastic())); } @@ -133,7 +135,7 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se Map resources, List resourceTemplates, Map prompts, - List>> rootsChangeConsumers) { + List>> rootsChangeConsumers) { /** * Create an instance and validate the arguments. @@ -151,7 +153,7 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se Map resources, List resourceTemplates, Map prompts, - List>> rootsChangeConsumers) { + List>> rootsChangeConsumers) { Assert.notNull(serverInfo, "Server info must not be null"); diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java new file mode 100644 index 00000000..09d87111 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java @@ -0,0 +1,29 @@ +package io.modelcontextprotocol.server; + +import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.spec.McpSchema; + +public class McpSyncServerExchange { + + private final McpAsyncServerExchange exchange; + + public McpSyncServerExchange(McpAsyncServerExchange exchange) { + this.exchange = exchange; + } + + public McpSchema.CreateMessageResult createMessage(McpSchema.CreateMessageRequest createMessageRequest) { + return this.exchange.createMessage(createMessageRequest).block(); + } + + private static final TypeReference LIST_ROOTS_RESULT_TYPE_REF = new TypeReference<>() { + }; + + public McpSchema.ListRootsResult listRoots() { + return this.exchange.listRoots().block(); + } + + public McpSchema.ListRootsResult listRoots(String cursor) { + return this.exchange.listRoots(cursor).block(); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index e95cf5bd..8304abd6 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -8,6 +8,7 @@ import java.util.concurrent.atomic.AtomicReference; import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.server.McpAsyncServerExchange; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Mono; @@ -34,7 +35,7 @@ public class McpServerSession implements McpSession { private final McpServerTransport transport; - private final Sinks.One exchangeSink = Sinks.one(); + private final Sinks.One exchangeSink = Sinks.one(); private final AtomicReference clientCapabilities = new AtomicReference<>(); @@ -195,7 +196,7 @@ private Mono handleIncomingNotification(McpSchema.JSONRPCNotification noti return Mono.defer(() -> { if (McpSchema.METHOD_NOTIFICATION_INITIALIZED.equals(notification.method())) { this.state.lazySet(STATE_INITIALIZED); - exchangeSink.tryEmitValue(new McpServerExchange(this, clientCapabilities.get(), clientInfo.get())); + exchangeSink.tryEmitValue(new McpAsyncServerExchange(this, clientCapabilities.get(), clientInfo.get())); return this.initNotificationHandler.handle(); } @@ -245,13 +246,13 @@ public interface InitNotificationHandler { public interface NotificationHandler { - Mono handle(McpServerExchange exchange, Object params); + Mono handle(McpAsyncServerExchange exchange, Object params); } public interface RequestHandler { - Mono handle(McpServerExchange exchange, Object params); + Mono handle(McpAsyncServerExchange exchange, Object params); } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java index bdcd7ae3..d76cf8e5 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java @@ -65,7 +65,7 @@ void tearDown() { @Test void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> McpServer.sync(null)).isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy(() -> McpServer.sync((ServerMcpTransport) null)).isInstanceOf(IllegalArgumentException.class) .hasMessage("Transport must not be null"); assertThatThrownBy(() -> McpServer.sync(createMcpTransport()).serverInfo(null)) diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/BaseMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/BaseMcpAsyncServerTests.java new file mode 100644 index 00000000..208bcb71 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/BaseMcpAsyncServerTests.java @@ -0,0 +1,5 @@ +package io.modelcontextprotocol.server; + +public abstract class BaseMcpAsyncServerTests { + +} From ed336978ac9906a12b5e08e8181996d4e7948923 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Fri, 14 Mar 2025 12:37:38 +0100 Subject: [PATCH 07/20] Replacing Registration classes with Specification classes --- .../server/McpAsyncServer.java | 144 ++- .../server/McpServer.java | 947 ++++++++++++++++-- .../server/McpServerFeatures.java | 325 +++++- .../server/AbstractMcpAsyncServerTests.java | 1 + 4 files changed, 1236 insertions(+), 181 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index c51d7b9c..bc1e5e12 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -14,7 +14,6 @@ import java.util.concurrent.CopyOnWriteArrayList; import java.util.function.BiFunction; import java.util.function.Function; -import java.util.stream.Collectors; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; @@ -333,13 +332,13 @@ private static class AsyncServerImpl extends McpAsyncServer { /** * Thread-safe list of tool handlers that can be modified at runtime. */ - private final CopyOnWriteArrayList tools = new CopyOnWriteArrayList<>(); + private final CopyOnWriteArrayList tools = new CopyOnWriteArrayList<>(); private final CopyOnWriteArrayList resourceTemplates = new CopyOnWriteArrayList<>(); - private final ConcurrentHashMap resources = new ConcurrentHashMap<>(); + private final ConcurrentHashMap resources = new ConcurrentHashMap<>(); - private final ConcurrentHashMap prompts = new ConcurrentHashMap<>(); + private final ConcurrentHashMap prompts = new ConcurrentHashMap<>(); private LoggingLevel minLoggingLevel = LoggingLevel.DEBUG; @@ -540,15 +539,27 @@ private McpServerSession.NotificationHandler asyncRootsListChangedNotificationHa * Add a new tool registration at runtime. * @param toolRegistration The tool registration to add * @return Mono that completes when clients have been notified of the change + * @deprecated This method will be removed in 0.9.0. Use + * {@link #addTool(McpServerFeatures.AsyncToolSpecification)}. */ + @Deprecated public Mono addTool(McpServerFeatures.AsyncToolRegistration toolRegistration) { - if (toolRegistration == null) { - return Mono.error(new McpError("Tool registration must not be null")); + return this.addTool(toolRegistration.toSpecification()); + } + + /** + * Add a new tool registration at runtime. + * @param toolSpecification The tool registration to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addTool(McpServerFeatures.AsyncToolSpecification toolSpecification) { + if (toolSpecification == null) { + return Mono.error(new McpError("Tool specification must not be null")); } - if (toolRegistration.tool() == null) { + if (toolSpecification.tool() == null) { return Mono.error(new McpError("Tool must not be null")); } - if (toolRegistration.call() == null) { + if (toolSpecification.call() == null) { return Mono.error(new McpError("Tool call handler must not be null")); } if (this.serverCapabilities.tools() == null) { @@ -557,13 +568,13 @@ public Mono addTool(McpServerFeatures.AsyncToolRegistration toolRegistrati return Mono.defer(() -> { // Check for duplicate tool names - if (this.tools.stream().anyMatch(th -> th.tool().name().equals(toolRegistration.tool().name()))) { + if (this.tools.stream().anyMatch(th -> th.tool().name().equals(toolSpecification.tool().name()))) { return Mono - .error(new McpError("Tool with name '" + toolRegistration.tool().name() + "' already exists")); + .error(new McpError("Tool with name '" + toolSpecification.tool().name() + "' already exists")); } - this.tools.add(toolRegistration); - logger.debug("Added tool handler: {}", toolRegistration.tool().name()); + this.tools.add(toolSpecification); + logger.debug("Added tool handler: {}", toolSpecification.tool().name()); if (this.serverCapabilities.tools().listChanged()) { return notifyToolsListChanged(); @@ -609,7 +620,7 @@ public Mono notifyToolsListChanged() { private McpServerSession.RequestHandler toolsListRequestHandler() { return (exchange, params) -> { - List tools = this.tools.stream().map(McpServerFeatures.AsyncToolRegistration::tool).toList(); + List tools = this.tools.stream().map(McpServerFeatures.AsyncToolSpecification::tool).toList(); return Mono.just(new McpSchema.ListToolsResult(tools, null)); }; @@ -621,15 +632,15 @@ private McpServerSession.RequestHandler toolsCallRequestHandler( new TypeReference() { }); - Optional toolRegistration = this.tools.stream() + Optional toolSpecification = this.tools.stream() .filter(tr -> callToolRequest.name().equals(tr.tool().name())) .findAny(); - if (toolRegistration.isEmpty()) { + if (toolSpecification.isEmpty()) { return Mono.error(new McpError("Tool not found: " + callToolRequest.name())); } - return toolRegistration.map(tool -> tool.call().apply(callToolRequest.arguments())) + return toolSpecification.map(tool -> tool.call().apply(exchange, callToolRequest.arguments())) .orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name()))); }; } @@ -642,9 +653,21 @@ private McpServerSession.RequestHandler toolsCallRequestHandler( * Add a new resource handler at runtime. * @param resourceHandler The resource handler to add * @return Mono that completes when clients have been notified of the change + * @deprecated This method will be removed in 0.9.0. Use + * {@link #addResource(McpServerFeatures.AsyncResourceSpecification)}. */ + @Deprecated public Mono addResource(McpServerFeatures.AsyncResourceRegistration resourceHandler) { - if (resourceHandler == null || resourceHandler.resource() == null) { + return this.addResource(resourceHandler.toSpecification()); + } + + /** + * Add a new resource handler at runtime. + * @param resourceSpecification The resource handler to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addResource(McpServerFeatures.AsyncResourceSpecification resourceSpecification) { + if (resourceSpecification == null || resourceSpecification.resource() == null) { return Mono.error(new McpError("Resource must not be null")); } @@ -653,11 +676,11 @@ public Mono addResource(McpServerFeatures.AsyncResourceRegistration resour } return Mono.defer(() -> { - if (this.resources.putIfAbsent(resourceHandler.resource().uri(), resourceHandler) != null) { + if (this.resources.putIfAbsent(resourceSpecification.resource().uri(), resourceSpecification) != null) { return Mono.error(new McpError( - "Resource with URI '" + resourceHandler.resource().uri() + "' already exists")); + "Resource with URI '" + resourceSpecification.resource().uri() + "' already exists")); } - logger.debug("Added resource handler: {}", resourceHandler.resource().uri()); + logger.debug("Added resource handler: {}", resourceSpecification.resource().uri()); if (this.serverCapabilities.resources().listChanged()) { return notifyResourcesListChanged(); } @@ -679,7 +702,7 @@ public Mono removeResource(String resourceUri) { } return Mono.defer(() -> { - McpServerFeatures.AsyncResourceRegistration removed = this.resources.remove(resourceUri); + McpServerFeatures.AsyncResourceSpecification removed = this.resources.remove(resourceUri); if (removed != null) { logger.debug("Removed resource handler: {}", resourceUri); if (this.serverCapabilities.resources().listChanged()) { @@ -696,8 +719,6 @@ public Mono removeResource(String resourceUri) { * @return A Mono that completes when all clients have been notified */ public Mono notifyResourcesListChanged() { - McpSchema.JSONRPCNotification jsonrpcNotification = new McpSchema.JSONRPCNotification( - McpSchema.JSONRPC_VERSION, McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED, null); return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED, null); } @@ -705,7 +726,7 @@ private McpServerSession.RequestHandler resources return (exchange, params) -> { var resourceList = this.resources.values() .stream() - .map(McpServerFeatures.AsyncResourceRegistration::resource) + .map(McpServerFeatures.AsyncResourceSpecification::resource) .toList(); return Mono.just(new McpSchema.ListResourcesResult(resourceList, null)); }; @@ -723,9 +744,9 @@ private McpServerSession.RequestHandler resourcesR new TypeReference() { }); var resourceUri = resourceRequest.uri(); - McpServerFeatures.AsyncResourceRegistration registration = this.resources.get(resourceUri); + McpServerFeatures.AsyncResourceSpecification registration = this.resources.get(resourceUri); if (registration != null) { - return registration.readHandler().apply(resourceRequest); + return registration.readHandler().apply(exchange, resourceRequest); } return Mono.error(new McpError("Resource not found: " + resourceUri)); }; @@ -739,9 +760,21 @@ private McpServerSession.RequestHandler resourcesR * Add a new prompt handler at runtime. * @param promptRegistration The prompt handler to add * @return Mono that completes when clients have been notified of the change + * @deprecated This method will be removed in 0.9.0. Use + * {@link #addPrompt(McpServerFeatures.AsyncPromptSpecification)}. */ + @Deprecated public Mono addPrompt(McpServerFeatures.AsyncPromptRegistration promptRegistration) { - if (promptRegistration == null) { + return this.addPrompt(promptRegistration.toSpecification()); + } + + /** + * Add a new prompt handler at runtime. + * @param promptSpecification The prompt handler to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addPrompt(McpServerFeatures.AsyncPromptSpecification promptSpecification) { + if (promptSpecification == null) { return Mono.error(new McpError("Prompt registration must not be null")); } if (this.serverCapabilities.prompts() == null) { @@ -749,14 +782,14 @@ public Mono addPrompt(McpServerFeatures.AsyncPromptRegistration promptRegi } return Mono.defer(() -> { - McpServerFeatures.AsyncPromptRegistration registration = this.prompts - .putIfAbsent(promptRegistration.prompt().name(), promptRegistration); - if (registration != null) { + McpServerFeatures.AsyncPromptSpecification specification = this.prompts + .putIfAbsent(promptSpecification.prompt().name(), promptSpecification); + if (specification != null) { return Mono.error(new McpError( - "Prompt with name '" + promptRegistration.prompt().name() + "' already exists")); + "Prompt with name '" + promptSpecification.prompt().name() + "' already exists")); } - logger.debug("Added prompt handler: {}", promptRegistration.prompt().name()); + logger.debug("Added prompt handler: {}", promptSpecification.prompt().name()); // Servers that declared the listChanged capability SHOULD send a // notification, @@ -782,7 +815,7 @@ public Mono removePrompt(String promptName) { } return Mono.defer(() -> { - McpServerFeatures.AsyncPromptRegistration removed = this.prompts.remove(promptName); + McpServerFeatures.AsyncPromptSpecification removed = this.prompts.remove(promptName); if (removed != null) { logger.debug("Removed prompt handler: {}", promptName); @@ -814,7 +847,7 @@ private McpServerSession.RequestHandler promptsList var promptList = this.prompts.values() .stream() - .map(McpServerFeatures.AsyncPromptRegistration::prompt) + .map(McpServerFeatures.AsyncPromptSpecification::prompt) .toList(); return Mono.just(new McpSchema.ListPromptsResult(promptList, null)); @@ -828,12 +861,12 @@ private McpServerSession.RequestHandler promptsGetReq }); // Implement prompt retrieval logic here - McpServerFeatures.AsyncPromptRegistration registration = this.prompts.get(promptRequest.name()); - if (registration == null) { + McpServerFeatures.AsyncPromptSpecification specification = this.prompts.get(promptRequest.name()); + if (specification == null) { return Mono.error(new McpError("Prompt not found: " + promptRequest.name())); } - return registration.promptHandler().apply(promptRequest); + return specification.promptHandler().apply(exchange, promptRequest); }; } @@ -938,13 +971,13 @@ private static final class LegacyAsyncServer extends McpAsyncServer { /** * Thread-safe list of tool handlers that can be modified at runtime. */ - private final CopyOnWriteArrayList tools = new CopyOnWriteArrayList<>(); + private final CopyOnWriteArrayList tools = new CopyOnWriteArrayList<>(); private final CopyOnWriteArrayList resourceTemplates = new CopyOnWriteArrayList<>(); - private final ConcurrentHashMap resources = new ConcurrentHashMap<>(); + private final ConcurrentHashMap resources = new ConcurrentHashMap<>(); - private final ConcurrentHashMap prompts = new ConcurrentHashMap<>(); + private final ConcurrentHashMap prompts = new ConcurrentHashMap<>(); private LoggingLevel minLoggingLevel = LoggingLevel.DEBUG; @@ -1170,7 +1203,7 @@ public Mono addTool(McpServerFeatures.AsyncToolRegistration toolRegistrati .error(new McpError("Tool with name '" + toolRegistration.tool().name() + "' already exists")); } - this.tools.add(toolRegistration); + this.tools.add(toolRegistration.toSpecification()); logger.debug("Added tool handler: {}", toolRegistration.tool().name()); if (this.serverCapabilities.tools().listChanged()) { @@ -1217,7 +1250,7 @@ public Mono notifyToolsListChanged() { private DefaultMcpSession.RequestHandler toolsListRequestHandler() { return params -> { - List tools = this.tools.stream().map(McpServerFeatures.AsyncToolRegistration::tool).toList(); + List tools = this.tools.stream().map(McpServerFeatures.AsyncToolSpecification::tool).toList(); return Mono.just(new McpSchema.ListToolsResult(tools, null)); }; @@ -1229,7 +1262,7 @@ private DefaultMcpSession.RequestHandler toolsCallRequestHandler new TypeReference() { }); - Optional toolRegistration = this.tools.stream() + Optional toolRegistration = this.tools.stream() .filter(tr -> callToolRequest.name().equals(tr.tool().name())) .findAny(); @@ -1237,7 +1270,7 @@ private DefaultMcpSession.RequestHandler toolsCallRequestHandler return Mono.error(new McpError("Tool not found: " + callToolRequest.name())); } - return toolRegistration.map(tool -> tool.call().apply(callToolRequest.arguments())) + return toolRegistration.map(tool -> tool.call().apply(null, callToolRequest.arguments())) .orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name()))); }; } @@ -1261,7 +1294,8 @@ public Mono addResource(McpServerFeatures.AsyncResourceRegistration resour } return Mono.defer(() -> { - if (this.resources.putIfAbsent(resourceHandler.resource().uri(), resourceHandler) != null) { + if (this.resources.putIfAbsent(resourceHandler.resource().uri(), + resourceHandler.toSpecification()) != null) { return Mono.error(new McpError( "Resource with URI '" + resourceHandler.resource().uri() + "' already exists")); } @@ -1287,7 +1321,7 @@ public Mono removeResource(String resourceUri) { } return Mono.defer(() -> { - McpServerFeatures.AsyncResourceRegistration removed = this.resources.remove(resourceUri); + McpServerFeatures.AsyncResourceSpecification removed = this.resources.remove(resourceUri); if (removed != null) { logger.debug("Removed resource handler: {}", resourceUri); if (this.serverCapabilities.resources().listChanged()) { @@ -1311,7 +1345,7 @@ private DefaultMcpSession.RequestHandler resource return params -> { var resourceList = this.resources.values() .stream() - .map(McpServerFeatures.AsyncResourceRegistration::resource) + .map(McpServerFeatures.AsyncResourceSpecification::resource) .toList(); return Mono.just(new McpSchema.ListResourcesResult(resourceList, null)); }; @@ -1328,9 +1362,9 @@ private DefaultMcpSession.RequestHandler resources new TypeReference() { }); var resourceUri = resourceRequest.uri(); - McpServerFeatures.AsyncResourceRegistration registration = this.resources.get(resourceUri); + McpServerFeatures.AsyncResourceSpecification registration = this.resources.get(resourceUri); if (registration != null) { - return registration.readHandler().apply(resourceRequest); + return registration.readHandler().apply(null, resourceRequest); } return Mono.error(new McpError("Resource not found: " + resourceUri)); }; @@ -1354,8 +1388,8 @@ public Mono addPrompt(McpServerFeatures.AsyncPromptRegistration promptRegi } return Mono.defer(() -> { - McpServerFeatures.AsyncPromptRegistration registration = this.prompts - .putIfAbsent(promptRegistration.prompt().name(), promptRegistration); + McpServerFeatures.AsyncPromptSpecification registration = this.prompts + .putIfAbsent(promptRegistration.prompt().name(), promptRegistration.toSpecification()); if (registration != null) { return Mono.error(new McpError( "Prompt with name '" + promptRegistration.prompt().name() + "' already exists")); @@ -1387,7 +1421,7 @@ public Mono removePrompt(String promptName) { } return Mono.defer(() -> { - McpServerFeatures.AsyncPromptRegistration removed = this.prompts.remove(promptName); + McpServerFeatures.AsyncPromptSpecification removed = this.prompts.remove(promptName); if (removed != null) { logger.debug("Removed prompt handler: {}", promptName); @@ -1419,7 +1453,7 @@ private DefaultMcpSession.RequestHandler promptsLis var promptList = this.prompts.values() .stream() - .map(McpServerFeatures.AsyncPromptRegistration::prompt) + .map(McpServerFeatures.AsyncPromptSpecification::prompt) .toList(); return Mono.just(new McpSchema.ListPromptsResult(promptList, null)); @@ -1433,12 +1467,12 @@ private DefaultMcpSession.RequestHandler promptsGetRe }); // Implement prompt retrieval logic here - McpServerFeatures.AsyncPromptRegistration registration = this.prompts.get(promptRequest.name()); + McpServerFeatures.AsyncPromptSpecification registration = this.prompts.get(promptRequest.name()); if (registration == null) { return Mono.error(new McpError("Prompt not found: " + promptRequest.name())); } - return registration.promptHandler().apply(promptRequest); + return registration.promptHandler().apply(null, promptRequest); }; } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java index 840631a0..7c4eb6dc 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java @@ -133,8 +133,8 @@ static SyncSpec sync(ServerMcpTransport transport) { return new SyncSpec(transport); } - static SyncSpec sync(McpServerTransportProvider transportProvider) { - return new SyncSpec(transportProvider); + static SyncSpecification sync(McpServerTransportProvider transportProvider) { + return new SyncSpecification(transportProvider); } /** @@ -152,13 +152,802 @@ static AsyncSpec async(ServerMcpTransport transport) { return new AsyncSpec(transport); } - static AsyncSpec async(McpServerTransportProvider transportProvider) { - return new AsyncSpec(transportProvider); + static AsyncSpecification async(McpServerTransportProvider transportProvider) { + return new AsyncSpecification(transportProvider); } /** * Asynchronous server specification. */ + class AsyncSpecification { + + private static final McpSchema.Implementation DEFAULT_SERVER_INFO = new McpSchema.Implementation("mcp-server", + "1.0.0"); + + private final McpServerTransportProvider transportProvider; + + private ObjectMapper objectMapper; + + private McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; + + private McpSchema.ServerCapabilities serverCapabilities; + + /** + * The Model Context Protocol (MCP) allows servers to expose tools that can be + * invoked by language models. Tools enable models to interact with external + * systems, such as querying databases, calling APIs, or performing computations. + * Each tool is uniquely identified by a name and includes metadata describing its + * schema. + */ + private final List tools = new ArrayList<>(); + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose resources to clients. Resources allow servers to share data that + * provides context to language models, such as files, database schemas, or + * application-specific information. Each resource is uniquely identified by a + * URI. + */ + private final Map resources = new HashMap<>(); + + private final List resourceTemplates = new ArrayList<>(); + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose prompt templates to clients. Prompts allow servers to provide structured + * messages and instructions for interacting with language models. Clients can + * discover available prompts, retrieve their contents, and provide arguments to + * customize them. + */ + private final Map prompts = new HashMap<>(); + + private final List, Mono>> rootsChangeHandlers = new ArrayList<>(); + + private AsyncSpecification(McpServerTransportProvider transportProvider) { + Assert.notNull(transportProvider, "Transport provider must not be null"); + this.transportProvider = transportProvider; + } + + /** + * Sets the server implementation information that will be shared with clients + * during connection initialization. This helps with version compatibility, + * debugging, and server identification. + * @param serverInfo The server implementation details including name and version. + * Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if serverInfo is null + */ + public AsyncSpecification serverInfo(McpSchema.Implementation serverInfo) { + Assert.notNull(serverInfo, "Server info must not be null"); + this.serverInfo = serverInfo; + return this; + } + + /** + * Sets the server implementation information using name and version strings. This + * is a convenience method alternative to + * {@link #serverInfo(McpSchema.Implementation)}. + * @param name The server name. Must not be null or empty. + * @param version The server version. Must not be null or empty. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if name or version is null or empty + * @see #serverInfo(McpSchema.Implementation) + */ + public AsyncSpecification serverInfo(String name, String version) { + Assert.hasText(name, "Name must not be null or empty"); + Assert.hasText(version, "Version must not be null or empty"); + this.serverInfo = new McpSchema.Implementation(name, version); + return this; + } + + /** + * Sets the server capabilities that will be advertised to clients during + * connection initialization. Capabilities define what features the server + * supports, such as: + *
      + *
    • Tool execution + *
    • Resource access + *
    • Prompt handling + *
    • Streaming responses + *
    • Batch operations + *
    + * @param serverCapabilities The server capabilities configuration. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if serverCapabilities is null + */ + public AsyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabilities) { + this.serverCapabilities = serverCapabilities; + return this; + } + + /** + * Adds a single tool with its implementation handler to the server. This is a + * convenience method for registering individual tools without creating a + * {@link McpServerFeatures.AsyncToolSpecification} explicitly. + * + *

    + * Example usage:

    {@code
    +		 * .tool(
    +		 *     new Tool("calculator", "Performs calculations", schema),
    +		 *     args -> Mono.just(new CallToolResult("Result: " + calculate(args)))
    +		 * )
    +		 * }
    + * @param tool The tool definition including name, description, and schema. Must + * not be null. + * @param handler The function that implements the tool's logic. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if tool or handler is null + */ + public AsyncSpecification tool(McpSchema.Tool tool, + BiFunction, Mono> handler) { + Assert.notNull(tool, "Tool must not be null"); + Assert.notNull(handler, "Handler must not be null"); + + this.tools.add(new McpServerFeatures.AsyncToolSpecification(tool, handler)); + + return this; + } + + /** + * Adds multiple tools with their handlers to the server using a List. This method + * is useful when tools are dynamically generated or loaded from a configuration + * source. + * @param toolSpecifications The list of tool specifications to add. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if toolSpecifications is null + * @see #tools(McpServerFeatures.AsyncToolSpecification...) + */ + public AsyncSpecification tools(List toolSpecifications) { + Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); + this.tools.addAll(toolSpecifications); + return this; + } + + /** + * Adds multiple tools with their handlers to the server using varargs. This + * method provides a convenient way to register multiple tools inline. + * + *

    + * Example usage:

    {@code
    +		 * .tools(
    +		 *     new McpServerFeatures.AsyncToolSpecification(calculatorTool, calculatorHandler),
    +		 *     new McpServerFeatures.AsyncToolSpecification(weatherTool, weatherHandler),
    +		 *     new McpServerFeatures.AsyncToolSpecification(fileManagerTool, fileManagerHandler)
    +		 * )
    +		 * }
    + * @param toolSpecifications The tool specifications to add. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if toolSpecifications is null + * @see #tools(List) + */ + public AsyncSpecification tools(McpServerFeatures.AsyncToolSpecification... toolSpecifications) { + for (McpServerFeatures.AsyncToolSpecification tool : toolSpecifications) { + this.tools.add(tool); + } + return this; + } + + /** + * Registers multiple resources with their handlers using a Map. This method is + * useful when resources are dynamically generated or loaded from a configuration + * source. + * @param resourceSpecifications Map of resource name to specification. Must not + * be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + * @see #resources(McpServerFeatures.AsyncResourceSpecification...) + */ + public AsyncSpecification resources( + Map resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers map must not be null"); + this.resources.putAll(resourceSpecifications); + return this; + } + + /** + * Registers multiple resources with their handlers using a List. This method is + * useful when resources need to be added in bulk from a collection. + * @param resourceSpecifications List of resource specifications. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + * @see #resources(McpServerFeatures.AsyncResourceSpecification...) + */ + public AsyncSpecification resources(List resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); + for (McpServerFeatures.AsyncResourceSpecification resource : resourceSpecifications) { + this.resources.put(resource.resource().uri(), resource); + } + return this; + } + + /** + * Registers multiple resources with their handlers using varargs. This method + * provides a convenient way to register multiple resources inline. + * + *

    + * Example usage:

    {@code
    +		 * .resources(
    +		 *     new McpServerFeatures.AsyncResourceSpecification(fileResource, fileHandler),
    +		 *     new McpServerFeatures.AsyncResourceSpecification(dbResource, dbHandler),
    +		 *     new McpServerFeatures.AsyncResourceSpecification(apiResource, apiHandler)
    +		 * )
    +		 * }
    + * @param resourceSpecifications The resource specifications to add. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + */ + public AsyncSpecification resources(McpServerFeatures.AsyncResourceSpecification... resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); + for (McpServerFeatures.AsyncResourceSpecification resource : resourceSpecifications) { + this.resources.put(resource.resource().uri(), resource); + } + return this; + } + + /** + * Sets the resource templates that define patterns for dynamic resource access. + * Templates use URI patterns with placeholders that can be filled at runtime. + * + *

    + * Example usage:

    {@code
    +		 * .resourceTemplates(
    +		 *     new ResourceTemplate("file://{path}", "Access files by path"),
    +		 *     new ResourceTemplate("db://{table}/{id}", "Access database records")
    +		 * )
    +		 * }
    + * @param resourceTemplates List of resource templates. If null, clears existing + * templates. + * @return This builder instance for method chaining + * @see #resourceTemplates(ResourceTemplate...) + */ + public AsyncSpecification resourceTemplates(List resourceTemplates) { + this.resourceTemplates.addAll(resourceTemplates); + return this; + } + + /** + * Sets the resource templates using varargs for convenience. This is an + * alternative to {@link #resourceTemplates(List)}. + * @param resourceTemplates The resource templates to set. + * @return This builder instance for method chaining + * @see #resourceTemplates(List) + */ + public AsyncSpecification resourceTemplates(ResourceTemplate... resourceTemplates) { + for (ResourceTemplate resourceTemplate : resourceTemplates) { + this.resourceTemplates.add(resourceTemplate); + } + return this; + } + + /** + * Registers multiple prompts with their handlers using a Map. This method is + * useful when prompts are dynamically generated or loaded from a configuration + * source. + * + *

    + * Example usage:

    {@code
    +		 * .prompts(Map.of("analysis", new McpServerFeatures.AsyncPromptSpecification(
    +		 *     new Prompt("analysis", "Code analysis template"),
    +		 *     request -> Mono.just(new GetPromptResult(generateAnalysisPrompt(request)))
    +		 * )));
    +		 * }
    + * @param prompts Map of prompt name to specification. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + */ + public AsyncSpecification prompts(Map prompts) { + this.prompts.putAll(prompts); + return this; + } + + /** + * Registers multiple prompts with their handlers using a List. This method is + * useful when prompts need to be added in bulk from a collection. + * @param prompts List of prompt specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + * @see #prompts(McpServerFeatures.AsyncPromptSpecification...) + */ + public AsyncSpecification prompts(List prompts) { + for (McpServerFeatures.AsyncPromptSpecification prompt : prompts) { + this.prompts.put(prompt.prompt().name(), prompt); + } + return this; + } + + /** + * Registers multiple prompts with their handlers using varargs. This method + * provides a convenient way to register multiple prompts inline. + * + *

    + * Example usage:

    {@code
    +		 * .prompts(
    +		 *     new McpServerFeatures.AsyncPromptSpecification(analysisPrompt, analysisHandler),
    +		 *     new McpServerFeatures.AsyncPromptSpecification(summaryPrompt, summaryHandler),
    +		 *     new McpServerFeatures.AsyncPromptSpecification(reviewPrompt, reviewHandler)
    +		 * )
    +		 * }
    + * @param prompts The prompt specifications to add. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + */ + public AsyncSpecification prompts(McpServerFeatures.AsyncPromptSpecification... prompts) { + for (McpServerFeatures.AsyncPromptSpecification prompt : prompts) { + this.prompts.put(prompt.prompt().name(), prompt); + } + return this; + } + + /** + * Registers a consumer that will be notified when the list of roots changes. This + * is useful for updating resource availability dynamically, such as when new + * files are added or removed. + * @param handler The handler to register. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if consumer is null + */ + public AsyncSpecification rootsChangeHandler( + BiFunction, Mono> handler) { + Assert.notNull(handler, "Consumer must not be null"); + this.rootsChangeHandlers.add(handler); + return this; + } + + /** + * Registers multiple consumers that will be notified when the list of roots + * changes. This method is useful when multiple consumers need to be registered at + * once. + * @param handlers The list of handlers to register. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if consumers is null + */ + public AsyncSpecification rootsChangeHandlers( + List, Mono>> handlers) { + Assert.notNull(handlers, "Handlers list must not be null"); + this.rootsChangeHandlers.addAll(handlers); + return this; + } + + /** + * Registers multiple consumers that will be notified when the list of roots + * changes using varargs. This method provides a convenient way to register + * multiple consumers inline. + * @param handlers The handlers to register. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if consumers is null + */ + public AsyncSpecification rootsChangeHandlers( + @SuppressWarnings("unchecked") BiFunction, Mono>... handlers) { + return this.rootsChangeHandlers(Arrays.asList(handlers)); + } + + public AsyncSpecification objectMapper(ObjectMapper objectMapper) { + this.objectMapper = objectMapper; + return this; + } + + /** + * Builds an asynchronous MCP server that provides non-blocking operations. + * @return A new instance of {@link McpAsyncServer} configured with this builder's + * settings + */ + public McpAsyncServer build() { + var features = new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, + this.resources, this.resourceTemplates, this.prompts, this.rootsChangeHandlers); + var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); + return new McpAsyncServer(this.transportProvider, mapper, features); + } + + } + + /** + * Synchronous server specification. + */ + class SyncSpecification { + + private static final McpSchema.Implementation DEFAULT_SERVER_INFO = new McpSchema.Implementation("mcp-server", + "1.0.0"); + + private final McpServerTransportProvider transportProvider; + + private ObjectMapper objectMapper; + + private McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; + + private McpSchema.ServerCapabilities serverCapabilities; + + /** + * The Model Context Protocol (MCP) allows servers to expose tools that can be + * invoked by language models. Tools enable models to interact with external + * systems, such as querying databases, calling APIs, or performing computations. + * Each tool is uniquely identified by a name and includes metadata describing its + * schema. + */ + private final List tools = new ArrayList<>(); + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose resources to clients. Resources allow servers to share data that + * provides context to language models, such as files, database schemas, or + * application-specific information. Each resource is uniquely identified by a + * URI. + */ + private final Map resources = new HashMap<>(); + + private final List resourceTemplates = new ArrayList<>(); + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose prompt templates to clients. Prompts allow servers to provide structured + * messages and instructions for interacting with language models. Clients can + * discover available prompts, retrieve their contents, and provide arguments to + * customize them. + */ + private final Map prompts = new HashMap<>(); + + private final List>> rootsChangeHandlers = new ArrayList<>(); + + private SyncSpecification(McpServerTransportProvider transportProvider) { + Assert.notNull(transportProvider, "Transport provider must not be null"); + this.transportProvider = transportProvider; + } + + /** + * Sets the server implementation information that will be shared with clients + * during connection initialization. This helps with version compatibility, + * debugging, and server identification. + * @param serverInfo The server implementation details including name and version. + * Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if serverInfo is null + */ + public SyncSpecification serverInfo(McpSchema.Implementation serverInfo) { + Assert.notNull(serverInfo, "Server info must not be null"); + this.serverInfo = serverInfo; + return this; + } + + /** + * Sets the server implementation information using name and version strings. This + * is a convenience method alternative to + * {@link #serverInfo(McpSchema.Implementation)}. + * @param name The server name. Must not be null or empty. + * @param version The server version. Must not be null or empty. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if name or version is null or empty + * @see #serverInfo(McpSchema.Implementation) + */ + public SyncSpecification serverInfo(String name, String version) { + Assert.hasText(name, "Name must not be null or empty"); + Assert.hasText(version, "Version must not be null or empty"); + this.serverInfo = new McpSchema.Implementation(name, version); + return this; + } + + /** + * Sets the server capabilities that will be advertised to clients during + * connection initialization. Capabilities define what features the server + * supports, such as: + *
      + *
    • Tool execution + *
    • Resource access + *
    • Prompt handling + *
    • Streaming responses + *
    • Batch operations + *
    + * @param serverCapabilities The server capabilities configuration. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if serverCapabilities is null + */ + public SyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabilities) { + this.serverCapabilities = serverCapabilities; + return this; + } + + /** + * Adds a single tool with its implementation handler to the server. This is a + * convenience method for registering individual tools without creating a + * {@link McpServerFeatures.SyncToolSpecification} explicitly. + * + *

    + * Example usage:

    {@code
    +		 * .tool(
    +		 *     new Tool("calculator", "Performs calculations", schema),
    +		 *     args -> new CallToolResult("Result: " + calculate(args))
    +		 * )
    +		 * }
    + * @param tool The tool definition including name, description, and schema. Must + * not be null. + * @param handler The function that implements the tool's logic. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if tool or handler is null + */ + public SyncSpecification tool(McpSchema.Tool tool, + BiFunction, McpSchema.CallToolResult> handler) { + Assert.notNull(tool, "Tool must not be null"); + Assert.notNull(handler, "Handler must not be null"); + + this.tools.add(new McpServerFeatures.SyncToolSpecification(tool, handler)); + + return this; + } + + /** + * Adds multiple tools with their handlers to the server using a List. This method + * is useful when tools are dynamically generated or loaded from a configuration + * source. + * @param toolSpecifications The list of tool specifications to add. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if toolSpecifications is null + * @see #tools(McpServerFeatures.SyncToolSpecification...) + */ + public SyncSpecification tools(List toolSpecifications) { + Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); + this.tools.addAll(toolSpecifications); + return this; + } + + /** + * Adds multiple tools with their handlers to the server using varargs. This + * method provides a convenient way to register multiple tools inline. + * + *

    + * Example usage:

    {@code
    +		 * .tools(
    +		 *     new ToolSpecification(calculatorTool, calculatorHandler),
    +		 *     new ToolSpecification(weatherTool, weatherHandler),
    +		 *     new ToolSpecification(fileManagerTool, fileManagerHandler)
    +		 * )
    +		 * }
    + * @param toolSpecifications The tool specifications to add. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if toolSpecifications is null + * @see #tools(List) + */ + public SyncSpecification tools(McpServerFeatures.SyncToolSpecification... toolSpecifications) { + for (McpServerFeatures.SyncToolSpecification tool : toolSpecifications) { + this.tools.add(tool); + } + return this; + } + + /** + * Registers multiple resources with their handlers using a Map. This method is + * useful when resources are dynamically generated or loaded from a configuration + * source. + * @param resourceSpecifications Map of resource name to specification. Must not + * be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + * @see #resources(McpServerFeatures.SyncResourceSpecification...) + */ + public SyncSpecification resources( + Map resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers map must not be null"); + this.resources.putAll(resourceSpecifications); + return this; + } + + /** + * Registers multiple resources with their handlers using a List. This method is + * useful when resources need to be added in bulk from a collection. + * @param resourceSpecifications List of resource specifications. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + * @see #resources(McpServerFeatures.SyncResourceSpecification...) + */ + public SyncSpecification resources(List resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); + for (McpServerFeatures.SyncResourceSpecification resource : resourceSpecifications) { + this.resources.put(resource.resource().uri(), resource); + } + return this; + } + + /** + * Registers multiple resources with their handlers using varargs. This method + * provides a convenient way to register multiple resources inline. + * + *

    + * Example usage:

    {@code
    +		 * .resources(
    +		 *     new ResourceSpecification(fileResource, fileHandler),
    +		 *     new ResourceSpecification(dbResource, dbHandler),
    +		 *     new ResourceSpecification(apiResource, apiHandler)
    +		 * )
    +		 * }
    + * @param resourceSpecifications The resource specifications to add. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + */ + public SyncSpecification resources(McpServerFeatures.SyncResourceSpecification... resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); + for (McpServerFeatures.SyncResourceSpecification resource : resourceSpecifications) { + this.resources.put(resource.resource().uri(), resource); + } + return this; + } + + /** + * Sets the resource templates that define patterns for dynamic resource access. + * Templates use URI patterns with placeholders that can be filled at runtime. + * + *

    + * Example usage:

    {@code
    +		 * .resourceTemplates(
    +		 *     new ResourceTemplate("file://{path}", "Access files by path"),
    +		 *     new ResourceTemplate("db://{table}/{id}", "Access database records")
    +		 * )
    +		 * }
    + * @param resourceTemplates List of resource templates. If null, clears existing + * templates. + * @return This builder instance for method chaining + * @see #resourceTemplates(ResourceTemplate...) + */ + public SyncSpecification resourceTemplates(List resourceTemplates) { + this.resourceTemplates.addAll(resourceTemplates); + return this; + } + + /** + * Sets the resource templates using varargs for convenience. This is an + * alternative to {@link #resourceTemplates(List)}. + * @param resourceTemplates The resource templates to set. + * @return This builder instance for method chaining + * @see #resourceTemplates(List) + */ + public SyncSpecification resourceTemplates(ResourceTemplate... resourceTemplates) { + for (ResourceTemplate resourceTemplate : resourceTemplates) { + this.resourceTemplates.add(resourceTemplate); + } + return this; + } + + /** + * Registers multiple prompts with their handlers using a Map. This method is + * useful when prompts are dynamically generated or loaded from a configuration + * source. + * + *

    + * Example usage:

    {@code
    +		 * Map prompts = new HashMap<>();
    +		 * prompts.put("analysis", new PromptSpecification(
    +		 *     new Prompt("analysis", "Code analysis template"),
    +		 *     request -> new GetPromptResult(generateAnalysisPrompt(request))
    +		 * ));
    +		 * .prompts(prompts)
    +		 * }
    + * @param prompts Map of prompt name to specification. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + */ + public SyncSpecification prompts(Map prompts) { + this.prompts.putAll(prompts); + return this; + } + + /** + * Registers multiple prompts with their handlers using a List. This method is + * useful when prompts need to be added in bulk from a collection. + * @param prompts List of prompt specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + * @see #prompts(McpServerFeatures.SyncPromptSpecification...) + */ + public SyncSpecification prompts(List prompts) { + for (McpServerFeatures.SyncPromptSpecification prompt : prompts) { + this.prompts.put(prompt.prompt().name(), prompt); + } + return this; + } + + /** + * Registers multiple prompts with their handlers using varargs. This method + * provides a convenient way to register multiple prompts inline. + * + *

    + * Example usage:

    {@code
    +		 * .prompts(
    +		 *     new PromptSpecification(analysisPrompt, analysisHandler),
    +		 *     new PromptSpecification(summaryPrompt, summaryHandler),
    +		 *     new PromptSpecification(reviewPrompt, reviewHandler)
    +		 * )
    +		 * }
    + * @param prompts The prompt specifications to add. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + */ + public SyncSpecification prompts(McpServerFeatures.SyncPromptSpecification... prompts) { + for (McpServerFeatures.SyncPromptSpecification prompt : prompts) { + this.prompts.put(prompt.prompt().name(), prompt); + } + return this; + } + + /** + * Registers a consumer that will be notified when the list of roots changes. This + * is useful for updating resource availability dynamically, such as when new + * files are added or removed. + * @param handler The handler to register. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if consumer is null + * @deprecated This method will be removed in 0.9.0. Use + * {@link #rootsChangeHandler(BiConsumer)}. + */ + public SyncSpecification rootsChangeHandler(BiConsumer> handler) { + Assert.notNull(handler, "Consumer must not be null"); + this.rootsChangeHandlers.add(handler); + return this; + } + + /** + * Registers multiple consumers that will be notified when the list of roots + * changes. This method is useful when multiple consumers need to be registered at + * once. + * @param handlers The list of handlers to register. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if consumers is null + * @deprecated This method will be removed in 0.9.0. Use + * {@link #rootsChangeHandlers(List)}. + */ + public SyncSpecification rootsChangeHandlers( + List>> handlers) { + Assert.notNull(handlers, "Handlers list must not be null"); + this.rootsChangeHandlers.addAll(handlers); + return this; + } + + /** + * Registers multiple consumers that will be notified when the list of roots + * changes using varargs. This method provides a convenient way to register + * multiple consumers inline. + * @param handlers The handlers to register. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if consumers is null + * @deprecated This method will * be removed in 0.9.0. Use + * {@link #rootsChangeHandlers(BiConsumer[])}. + */ + public SyncSpecification rootsChangeHandlers( + BiConsumer>... handlers) { + return this.rootsChangeHandlers(List.of(handlers)); + } + + public SyncSpecification objectMapper(ObjectMapper objectMapper) { + this.objectMapper = objectMapper; + return this; + } + + /** + * Builds a synchronous MCP server that provides blocking operations. + * @return A new instance of {@link McpSyncServer} configured with this builder's + * settings + */ + public McpSyncServer build() { + McpServerFeatures.Sync syncFeatures = new McpServerFeatures.Sync(this.serverInfo, this.serverCapabilities, + this.tools, this.resources, this.resourceTemplates, this.prompts, this.rootsChangeHandlers); + McpServerFeatures.Async asyncFeatures = McpServerFeatures.Async.fromSync(syncFeatures); + var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); + var asyncServer = new McpAsyncServer(this.transportProvider, mapper, asyncFeatures); + + return new McpSyncServer(asyncServer); + } + + } + + /** + * Asynchronous server specification. + * + * @deprecated + */ + @Deprecated class AsyncSpec { private static final McpSchema.Implementation DEFAULT_SERVER_INFO = new McpSchema.Implementation("mcp-server", @@ -166,8 +955,6 @@ class AsyncSpec { private final ServerMcpTransport transport; - private final McpServerTransportProvider transportProvider; - private ObjectMapper objectMapper; private McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; @@ -203,18 +990,11 @@ class AsyncSpec { */ private final Map prompts = new HashMap<>(); - private final List, Mono>> rootsChangeHandlers = new ArrayList<>(); - - private AsyncSpec(McpServerTransportProvider transportProvider) { - Assert.notNull(transportProvider, "Transport provider must not be null"); - this.transport = null; - this.transportProvider = transportProvider; - } + private final List, Mono>> rootsChangeConsumers = new ArrayList<>(); private AsyncSpec(ServerMcpTransport transport) { Assert.notNull(transport, "Transport must not be null"); this.transport = transport; - this.transportProvider = null; } /** @@ -495,19 +1275,10 @@ public AsyncSpec prompts(McpServerFeatures.AsyncPromptRegistration... prompts) { * @param consumer The consumer to register. Must not be null. * @return This builder instance for method chaining * @throws IllegalArgumentException if consumer is null - * @deprecated This method will be removed in 0.9.0. Use - * {@link #rootsChangeHandler(BiFunction)} instead. */ - @Deprecated public AsyncSpec rootsChangeConsumer(Function, Mono> consumer) { Assert.notNull(consumer, "Consumer must not be null"); - return this.rootsChangeHandler((exchange, roots) -> consumer.apply(roots)); - } - - public AsyncSpec rootsChangeHandler( - BiFunction, Mono> handler) { - Assert.notNull(handler, "Consumer must not be null"); - this.rootsChangeHandlers.add(handler); + this.rootsChangeConsumers.add(consumer); return this; } @@ -518,22 +1289,10 @@ public AsyncSpec rootsChangeHandler( * @param consumers The list of consumers to register. Must not be null. * @return This builder instance for method chaining * @throws IllegalArgumentException if consumers is null - * @deprecated This method will be removed in 0.9.0. Use - * {@link #rootsChangeHandlers(List)} instead. */ - @Deprecated public AsyncSpec rootsChangeConsumers(List, Mono>> consumers) { Assert.notNull(consumers, "Consumers list must not be null"); - return this.rootsChangeHandlers(consumers.stream() - .map(consumer -> (BiFunction, Mono>) ( - McpAsyncServerExchange exchange, List roots) -> consumer.apply(roots)) - .collect(Collectors.toList())); - } - - public AsyncSpec rootsChangeHandlers( - List, Mono>> handlers) { - Assert.notNull(handlers, "Handlers list must not be null"); - this.rootsChangeHandlers.addAll(handlers); + this.rootsChangeConsumers.addAll(consumers); return this; } @@ -544,22 +1303,12 @@ public AsyncSpec rootsChangeHandlers( * @param consumers The consumers to register. Must not be null. * @return This builder instance for method chaining * @throws IllegalArgumentException if consumers is null - * @deprecated This method will be removed in 0.9.0. Use - * {@link #rootsChangeHandlers(BiFunction...)} instead. */ - @Deprecated public AsyncSpec rootsChangeConsumers( @SuppressWarnings("unchecked") Function, Mono>... consumers) { - return this.rootsChangeConsumers(Arrays.asList(consumers)); - } - - public AsyncSpec rootsChangeHandlers( - @SuppressWarnings("unchecked") BiFunction, Mono>... handlers) { - return this.rootsChangeHandlers(Arrays.asList(handlers)); - } - - public AsyncSpec objectMapper(ObjectMapper objectMapper) { - this.objectMapper = objectMapper; + for (Function, Mono> consumer : consumers) { + this.rootsChangeConsumers.add(consumer); + } return this; } @@ -569,22 +1318,37 @@ public AsyncSpec objectMapper(ObjectMapper objectMapper) { * settings */ public McpAsyncServer build() { - var features = new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, - this.resources, this.resourceTemplates, this.prompts, this.rootsChangeHandlers); - if (this.transportProvider != null) { - var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); - return new McpAsyncServer(this.transportProvider, mapper, features); - } - else { - return new McpAsyncServer(this.transport, features); - } + var tools = this.tools.stream().map(McpServerFeatures.AsyncToolRegistration::toSpecification).toList(); + + var resources = this.resources.entrySet() + .stream() + .map(entry -> Map.entry(entry.getKey(), entry.getValue().toSpecification())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + var prompts = this.prompts.entrySet() + .stream() + .map(entry -> Map.entry(entry.getKey(), entry.getValue().toSpecification())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + var rootsChangeHandlers = this.rootsChangeConsumers.stream() + .map(consumer -> (BiFunction, Mono>) (exchange, + roots) -> consumer.apply(roots)) + .toList(); + + var features = new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, tools, resources, + this.resourceTemplates, prompts, rootsChangeHandlers); + + return new McpAsyncServer(this.transport, features); } } /** * Synchronous server specification. + * + * @deprecated */ + @Deprecated class SyncSpec { private static final McpSchema.Implementation DEFAULT_SERVER_INFO = new McpSchema.Implementation("mcp-server", @@ -629,7 +1393,7 @@ class SyncSpec { */ private final Map prompts = new HashMap<>(); - private final List>> rootsChangeHandlers = new ArrayList<>(); + private final List>> rootsChangeConsumers = new ArrayList<>(); private SyncSpec(McpServerTransportProvider transportProvider) { Assert.notNull(transportProvider, "Transport provider must not be null"); @@ -923,18 +1687,10 @@ public SyncSpec prompts(McpServerFeatures.SyncPromptRegistration... prompts) { * @param consumer The consumer to register. Must not be null. * @return This builder instance for method chaining * @throws IllegalArgumentException if consumer is null - * @deprecated This method will be removed in 0.9.0. Use - * {@link #rootsChangeHandler(BiConsumer)}. */ - @Deprecated public SyncSpec rootsChangeConsumer(Consumer> consumer) { Assert.notNull(consumer, "Consumer must not be null"); - return this.rootsChangeHandler((exchange, roots) -> consumer.accept(roots)); - } - - public SyncSpec rootsChangeHandler(BiConsumer> handler) { - Assert.notNull(handler, "Consumer must not be null"); - this.rootsChangeHandlers.add(handler); + this.rootsChangeConsumers.add(consumer); return this; } @@ -945,21 +1701,10 @@ public SyncSpec rootsChangeHandler(BiConsumer>> consumers) { Assert.notNull(consumers, "Consumers list must not be null"); - return this.rootsChangeHandlers(consumers.stream() - .map(consumer -> (BiConsumer>) (exchange, roots) -> consumer - .accept(roots)) - .collect(Collectors.toList())); - } - - public SyncSpec rootsChangeHandlers(List>> handlers) { - Assert.notNull(handlers, "Handlers list must not be null"); - this.rootsChangeHandlers.addAll(handlers); + this.rootsChangeConsumers.addAll(consumers); return this; } @@ -970,20 +1715,11 @@ public SyncSpec rootsChangeHandlers(List>... consumers) { - return this.rootsChangeConsumers(Arrays.asList(consumers)); - } - - public SyncSpec rootsChangeHandlers(BiConsumer>... handlers) { - return this.rootsChangeHandlers(List.of(handlers)); - } - - public SyncSpec objectMapper(ObjectMapper objectMapper) { - this.objectMapper = objectMapper; + for (Consumer> consumer : consumers) { + this.rootsChangeConsumers.add(consumer); + } return this; } @@ -993,13 +1729,28 @@ public SyncSpec objectMapper(ObjectMapper objectMapper) { * settings */ public McpSyncServer build() { + var tools = this.tools.stream().map(McpServerFeatures.SyncToolRegistration::toSpecification).toList(); + + var resources = this.resources.entrySet() + .stream() + .map(entry -> Map.entry(entry.getKey(), entry.getValue().toSpecification())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + var prompts = this.prompts.entrySet() + .stream() + .map(entry -> Map.entry(entry.getKey(), entry.getValue().toSpecification())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + var rootsChangeHandlers = this.rootsChangeConsumers.stream() + .map(consumer -> (BiConsumer>) (exchange, roots) -> consumer + .accept(roots)) + .toList(); + McpServerFeatures.Sync syncFeatures = new McpServerFeatures.Sync(this.serverInfo, this.serverCapabilities, - this.tools, this.resources, this.resourceTemplates, this.prompts, this.rootsChangeHandlers); + tools, resources, this.resourceTemplates, prompts, rootsChangeHandlers); + McpServerFeatures.Async asyncFeatures = McpServerFeatures.Async.fromSync(syncFeatures); - var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); - var asyncServer = this.transportProvider != null - ? new McpAsyncServer(this.transportProvider, mapper, asyncFeatures) - : new McpAsyncServer(this.transport, asyncFeatures); + var asyncServer = new McpAsyncServer(this.transport, asyncFeatures); return new McpSyncServer(asyncServer); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java index 7e4e140f..d3c9ea63 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java @@ -30,34 +30,34 @@ public class McpServerFeatures { * * @param serverInfo The server implementation details * @param serverCapabilities The server capabilities - * @param tools The list of tool registrations - * @param resources The map of resource registrations + * @param tools The list of tool specifications + * @param resources The map of resource specifications * @param resourceTemplates The list of resource templates - * @param prompts The map of prompt registrations + * @param prompts The map of prompt specifications * @param rootsChangeConsumers The list of consumers that will be notified when the * roots list changes */ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, - List tools, Map resources, + List tools, Map resources, List resourceTemplates, - Map prompts, + Map prompts, List, Mono>> rootsChangeConsumers) { /** * Create an instance and validate the arguments. * @param serverInfo The server implementation details * @param serverCapabilities The server capabilities - * @param tools The list of tool registrations - * @param resources The map of resource registrations + * @param tools The list of tool specifications + * @param resources The map of resource specifications * @param resourceTemplates The list of resource templates - * @param prompts The map of prompt registrations + * @param prompts The map of prompt specifications * @param rootsChangeConsumers The list of consumers that will be notified when * the roots list changes */ Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, - List tools, Map resources, + List tools, Map resources, List resourceTemplates, - Map prompts, + Map prompts, List, Mono>> rootsChangeConsumers) { Assert.notNull(serverInfo, "Server info must not be null"); @@ -90,19 +90,19 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s * user. */ static Async fromSync(Sync syncSpec) { - List tools = new ArrayList<>(); + List tools = new ArrayList<>(); for (var tool : syncSpec.tools()) { - tools.add(AsyncToolRegistration.fromSync(tool)); + tools.add(AsyncToolSpecification.fromSync(tool)); } - Map resources = new HashMap<>(); + Map resources = new HashMap<>(); syncSpec.resources().forEach((key, resource) -> { - resources.put(key, AsyncResourceRegistration.fromSync(resource)); + resources.put(key, AsyncResourceSpecification.fromSync(resource)); }); - Map prompts = new HashMap<>(); + Map prompts = new HashMap<>(); syncSpec.prompts().forEach((key, prompt) -> { - prompts.put(key, AsyncPromptRegistration.fromSync(prompt)); + prompts.put(key, AsyncPromptSpecification.fromSync(prompt)); }); List, Mono>> rootChangeConsumers = new ArrayList<>(); @@ -123,36 +123,36 @@ static Async fromSync(Sync syncSpec) { * * @param serverInfo The server implementation details * @param serverCapabilities The server capabilities - * @param tools The list of tool registrations - * @param resources The map of resource registrations + * @param tools The list of tool specifications + * @param resources The map of resource specifications * @param resourceTemplates The list of resource templates - * @param prompts The map of prompt registrations + * @param prompts The map of prompt specifications * @param rootsChangeConsumers The list of consumers that will be notified when the * roots list changes */ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, - List tools, - Map resources, + List tools, + Map resources, List resourceTemplates, - Map prompts, + Map prompts, List>> rootsChangeConsumers) { /** * Create an instance and validate the arguments. * @param serverInfo The server implementation details * @param serverCapabilities The server capabilities - * @param tools The list of tool registrations - * @param resources The map of resource registrations + * @param tools The list of tool specifications + * @param resources The map of resource specifications * @param resourceTemplates The list of resource templates - * @param prompts The map of prompt registrations + * @param prompts The map of prompt specifications * @param rootsChangeConsumers The list of consumers that will be notified when * the roots list changes */ Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, - List tools, - Map resources, + List tools, + Map resources, List resourceTemplates, - Map prompts, + Map prompts, List>> rootsChangeConsumers) { Assert.notNull(serverInfo, "Server info must not be null"); @@ -178,6 +178,236 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se } + /** + * Specification of a tool with its asynchronous handler function. Tools are the + * primary way for MCP servers to expose functionality to AI models. Each tool + * represents a specific capability, such as: + *
      + *
    • Performing calculations + *
    • Accessing external APIs + *
    • Querying databases + *
    • Manipulating files + *
    • Executing system commands + *
    + * + *

    + * Example tool specification:

    {@code
    +	 * new McpServerFeatures.AsyncToolSpecification(
    +	 *     new Tool(
    +	 *         "calculator",
    +	 *         "Performs mathematical calculations",
    +	 *         new JsonSchemaObject()
    +	 *             .required("expression")
    +	 *             .property("expression", JsonSchemaType.STRING)
    +	 *     ),
    +	 *     (exchange, args) -> {
    +	 *         String expr = (String) args.get("expression");
    +	 *         return Mono.just(new CallToolResult("Result: " + evaluate(expr)));
    +	 *     }
    +	 * )
    +	 * }
    + * + * @param tool The tool definition including name, description, and parameter schema + * @param call The function that implements the tool's logic, receiving arguments and + * returning results + */ + public record AsyncToolSpecification(McpSchema.Tool tool, + BiFunction, Mono> call) { + + static AsyncToolSpecification fromSync(SyncToolSpecification tool) { + // FIXME: This is temporary, proper validation should be implemented + if (tool == null) { + return null; + } + return new AsyncToolSpecification(tool.tool(), + (exchange, map) -> Mono.fromCallable(() -> tool.call().apply(exchange, map)) + .subscribeOn(Schedulers.boundedElastic())); + } + } + + /** + * Specification of a resource with its asynchronous handler function. Resources + * provide context to AI models by exposing data such as: + *
      + *
    • File contents + *
    • Database records + *
    • API responses + *
    • System information + *
    • Application state + *
    + * + *

    + * Example resource specification:

    {@code
    +	 * new McpServerFeatures.AsyncResourceSpecification(
    +	 *     new Resource("docs", "Documentation files", "text/markdown"),
    +	 *     (exchange, request) -> {
    +	 *         String content = readFile(request.getPath());
    +	 *         return Mono.just(new ReadResourceResult(content));
    +	 *     }
    +	 * )
    +	 * }
    + * + * @param resource The resource definition including name, description, and MIME type + * @param readHandler The function that handles resource read requests + */ + public record AsyncResourceSpecification(McpSchema.Resource resource, + BiFunction> readHandler) { + + static AsyncResourceSpecification fromSync(SyncResourceSpecification resource) { + // FIXME: This is temporary, proper validation should be implemented + if (resource == null) { + return null; + } + return new AsyncResourceSpecification(resource.resource(), + (exchange, req) -> Mono.fromCallable(() -> resource.readHandler().apply(exchange, req)) + .subscribeOn(Schedulers.boundedElastic())); + } + } + + /** + * Specification of a prompt template with its asynchronous handler function. Prompts + * provide structured templates for AI model interactions, supporting: + *
      + *
    • Consistent message formatting + *
    • Parameter substitution + *
    • Context injection + *
    • Response formatting + *
    • Instruction templating + *
    + * + *

    + * Example prompt specification:

    {@code
    +	 * new McpServerFeatures.AsyncPromptSpecification(
    +	 *     new Prompt("analyze", "Code analysis template"),
    +	 *     (exchange, request) -> {
    +	 *         String code = request.getArguments().get("code");
    +	 *         return Mono.just(new GetPromptResult(
    +	 *             "Analyze this code:\n\n" + code + "\n\nProvide feedback on:"
    +	 *         ));
    +	 *     }
    +	 * )
    +	 * }
    + * + * @param prompt The prompt definition including name and description + * @param promptHandler The function that processes prompt requests and returns + * formatted templates + */ + public record AsyncPromptSpecification(McpSchema.Prompt prompt, + BiFunction> promptHandler) { + + static AsyncPromptSpecification fromSync(SyncPromptSpecification prompt) { + // FIXME: This is temporary, proper validation should be implemented + if (prompt == null) { + return null; + } + return new AsyncPromptSpecification(prompt.prompt(), + (exchange, req) -> Mono.fromCallable(() -> prompt.promptHandler().apply(exchange, req)) + .subscribeOn(Schedulers.boundedElastic())); + } + } + + /** + * Specification of a tool with its synchronous handler function. Tools are the + * primary way for MCP servers to expose functionality to AI models. Each tool + * represents a specific capability, such as: + *
      + *
    • Performing calculations + *
    • Accessing external APIs + *
    • Querying databases + *
    • Manipulating files + *
    • Executing system commands + *
    + * + *

    + * Example tool specification:

    {@code
    +	 * new McpServerFeatures.SyncToolSpecification(
    +	 *     new Tool(
    +	 *         "calculator",
    +	 *         "Performs mathematical calculations",
    +	 *         new JsonSchemaObject()
    +	 *             .required("expression")
    +	 *             .property("expression", JsonSchemaType.STRING)
    +	 *     ),
    +	 *     (exchange, args) -> {
    +	 *         String expr = (String) args.get("expression");
    +	 *         return new CallToolResult("Result: " + evaluate(expr));
    +	 *     }
    +	 * )
    +	 * }
    + * + * @param tool The tool definition including name, description, and parameter schema + * @param call The function that implements the tool's logic, receiving arguments and + * returning results + */ + public record SyncToolSpecification(McpSchema.Tool tool, + BiFunction, McpSchema.CallToolResult> call) { + } + + /** + * Specification of a resource with its synchronous handler function. Resources + * provide context to AI models by exposing data such as: + *
      + *
    • File contents + *
    • Database records + *
    • API responses + *
    • System information + *
    • Application state + *
    + * + *

    + * Example resource specification:

    {@code
    +	 * new McpServerFeatures.SyncResourceSpecification(
    +	 *     new Resource("docs", "Documentation files", "text/markdown"),
    +	 *     (exchange, request) -> {
    +	 *         String content = readFile(request.getPath());
    +	 *         return new ReadResourceResult(content);
    +	 *     }
    +	 * )
    +	 * }
    + * + * @param resource The resource definition including name, description, and MIME type + * @param readHandler The function that handles resource read requests + */ + public record SyncResourceSpecification(McpSchema.Resource resource, + BiFunction readHandler) { + } + + /** + * Specification of a prompt template with its synchronous handler function. Prompts + * provide structured templates for AI model interactions, supporting: + *
      + *
    • Consistent message formatting + *
    • Parameter substitution + *
    • Context injection + *
    • Response formatting + *
    • Instruction templating + *
    + * + *

    + * Example prompt specification:

    {@code
    +	 * new McpServerFeatures.SyncPromptSpecification(
    +	 *     new Prompt("analyze", "Code analysis template"),
    +	 *     (exchange, request) -> {
    +	 *         String code = request.getArguments().get("code");
    +	 *         return new GetPromptResult(
    +	 *             "Analyze this code:\n\n" + code + "\n\nProvide feedback on:"
    +	 *         );
    +	 *     }
    +	 * )
    +	 * }
    + * + * @param prompt The prompt definition including name and description + * @param promptHandler The function that processes prompt requests and returns + * formatted templates + */ + public record SyncPromptSpecification(McpSchema.Prompt prompt, + BiFunction promptHandler) { + } + + // --------------------------------------- + // Deprecated registrations + // --------------------------------------- + /** * Registration of a tool with its asynchronous handler function. Tools are the * primary way for MCP servers to expose functionality to AI models. Each tool @@ -210,7 +440,10 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se * @param tool The tool definition including name, description, and parameter schema * @param call The function that implements the tool's logic, receiving arguments and * returning results + * @deprecated This class is deprecated and will be removed in 0.9.0. Use + * {@link AsyncToolSpecification}. */ + @Deprecated public record AsyncToolRegistration(McpSchema.Tool tool, Function, Mono> call) { @@ -222,6 +455,10 @@ static AsyncToolRegistration fromSync(SyncToolRegistration tool) { return new AsyncToolRegistration(tool.tool(), map -> Mono.fromCallable(() -> tool.call().apply(map)).subscribeOn(Schedulers.boundedElastic())); } + + AsyncToolSpecification toSpecification() { + return new AsyncToolSpecification(tool(), (exchange, map) -> call.apply(map)); + } } /** @@ -248,7 +485,10 @@ static AsyncToolRegistration fromSync(SyncToolRegistration tool) { * * @param resource The resource definition including name, description, and MIME type * @param readHandler The function that handles resource read requests + * @deprecated This class is deprecated and will be removed in 0.9.0. Use + * {@link AsyncResourceSpecification}. */ + @Deprecated public record AsyncResourceRegistration(McpSchema.Resource resource, Function> readHandler) { @@ -261,6 +501,10 @@ static AsyncResourceRegistration fromSync(SyncResourceRegistration resource) { req -> Mono.fromCallable(() -> resource.readHandler().apply(req)) .subscribeOn(Schedulers.boundedElastic())); } + + AsyncResourceSpecification toSpecification() { + return new AsyncResourceSpecification(resource(), (exchange, request) -> readHandler.apply(request)); + } } /** @@ -290,7 +534,10 @@ static AsyncResourceRegistration fromSync(SyncResourceRegistration resource) { * @param prompt The prompt definition including name and description * @param promptHandler The function that processes prompt requests and returns * formatted templates + * @deprecated This class is deprecated and will be removed in 0.9.0. Use + * {@link AsyncPromptSpecification}. */ + @Deprecated public record AsyncPromptRegistration(McpSchema.Prompt prompt, Function> promptHandler) { @@ -303,6 +550,10 @@ static AsyncPromptRegistration fromSync(SyncPromptRegistration prompt) { req -> Mono.fromCallable(() -> prompt.promptHandler().apply(req)) .subscribeOn(Schedulers.boundedElastic())); } + + AsyncPromptSpecification toSpecification() { + return new AsyncPromptSpecification(prompt(), (exchange, request) -> promptHandler.apply(request)); + } } /** @@ -337,9 +588,15 @@ static AsyncPromptRegistration fromSync(SyncPromptRegistration prompt) { * @param tool The tool definition including name, description, and parameter schema * @param call The function that implements the tool's logic, receiving arguments and * returning results + * @deprecated This class is deprecated and will be removed in 0.9.0. Use + * {@link SyncToolSpecification}. */ + @Deprecated public record SyncToolRegistration(McpSchema.Tool tool, Function, McpSchema.CallToolResult> call) { + SyncToolSpecification toSpecification() { + return new SyncToolSpecification(tool, (exchange, map) -> call.apply(map)); + } } /** @@ -366,9 +623,15 @@ public record SyncToolRegistration(McpSchema.Tool tool, * * @param resource The resource definition including name, description, and MIME type * @param readHandler The function that handles resource read requests + * @deprecated This class is deprecated and will be removed in 0.9.0. Use + * {@link SyncResourceSpecification}. */ + @Deprecated public record SyncResourceRegistration(McpSchema.Resource resource, Function readHandler) { + SyncResourceSpecification toSpecification() { + return new SyncResourceSpecification(resource, (exchange, request) -> readHandler.apply(request)); + } } /** @@ -398,9 +661,15 @@ public record SyncResourceRegistration(McpSchema.Resource resource, * @param prompt The prompt definition including name and description * @param promptHandler The function that processes prompt requests and returns * formatted templates + * @deprecated This class is deprecated and will be removed in 0.9.0. Use + * {@link SyncPromptSpecification}. */ + @Deprecated public record SyncPromptRegistration(McpSchema.Prompt prompt, Function promptHandler) { + SyncPromptSpecification toSpecification() { + return new SyncPromptSpecification(prompt, (exchange, request) -> promptHandler.apply(request)); + } } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java index e8b24c7c..568a655d 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java @@ -22,6 +22,7 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; From e1f1ced9e4c2fd176a2605f891ce3c8f9c84e3de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Fri, 14 Mar 2025 14:29:26 +0100 Subject: [PATCH 08/20] IT method usage fix --- .../modelcontextprotocol/WebFluxSseIntegrationTests.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index d8f56d04..43253338 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -384,8 +384,8 @@ void testToolCallSuccess(String clientType) { var clientBuilder = clientBulders.get(clientType); var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - McpServerFeatures.SyncToolRegistration tool1 = new McpServerFeatures.SyncToolRegistration( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), request -> { + McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { // perform a blocking call to a remote service String response = RestClient.create() .get() @@ -424,8 +424,8 @@ void testToolListChangeHandlingSuccess(String clientType) { var clientBuilder = clientBulders.get(clientType); var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - McpServerFeatures.SyncToolRegistration tool1 = new McpServerFeatures.SyncToolRegistration( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), request -> { + McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { // perform a blocking call to a remote service String response = RestClient.create() .get() From 962cffad38d1aa434d0cd72aa68bcff3a68b0dcc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Fri, 14 Mar 2025 23:31:46 +0100 Subject: [PATCH 09/20] Incorporate dynamic addition of tools, resources, and prompts MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Dariusz Jędrzejczyk --- .../server/McpAsyncServer.java | 122 ++++++++++++------ .../server/McpSyncServer.java | 33 +++++ 2 files changed, 112 insertions(+), 43 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index bc1e5e12..0610e7bf 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -185,11 +185,23 @@ public Mono listRoots(String cursor) { * Add a new tool registration at runtime. * @param toolRegistration The tool registration to add * @return Mono that completes when clients have been notified of the change + * @deprecated This method will be removed in 0.9.0. Use + * {@link #addTool(McpServerFeatures.AsyncToolSpecification)}. */ + @Deprecated public Mono addTool(McpServerFeatures.AsyncToolRegistration toolRegistration) { return this.delegate.addTool(toolRegistration); } + /** + * Add a new tool specification at runtime. + * @param toolSpecification The tool specification to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addTool(McpServerFeatures.AsyncToolSpecification toolSpecification) { + return this.delegate.addTool(toolSpecification); + } + /** * Remove a tool handler at runtime. * @param toolName The name of the tool handler to remove @@ -215,11 +227,23 @@ public Mono notifyToolsListChanged() { * Add a new resource handler at runtime. * @param resourceHandler The resource handler to add * @return Mono that completes when clients have been notified of the change + * @deprecated This method will be removed in 0.9.0. Use + * {@link #addResource(McpServerFeatures.AsyncResourceSpecification)}. */ + @Deprecated public Mono addResource(McpServerFeatures.AsyncResourceRegistration resourceHandler) { return this.delegate.addResource(resourceHandler); } + /** + * Add a new resource handler at runtime. + * @param resourceHandler The resource handler to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addResource(McpServerFeatures.AsyncResourceSpecification resourceHandler) { + return this.delegate.addResource(resourceHandler); + } + /** * Remove a resource handler at runtime. * @param resourceUri The URI of the resource handler to remove @@ -245,11 +269,23 @@ public Mono notifyResourcesListChanged() { * Add a new prompt handler at runtime. * @param promptRegistration The prompt handler to add * @return Mono that completes when clients have been notified of the change + * @deprecated This method will be removed in 0.9.0. Use + * {@link #addPrompt(McpServerFeatures.AsyncPromptSpecification)}. */ + @Deprecated public Mono addPrompt(McpServerFeatures.AsyncPromptRegistration promptRegistration) { return this.delegate.addPrompt(promptRegistration); } + /** + * Add a new prompt handler at runtime. + * @param promptSpecification The prompt handler to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addPrompt(McpServerFeatures.AsyncPromptSpecification promptSpecification) { + return this.delegate.addPrompt(promptSpecification); + } + /** * Remove a prompt handler at runtime. * @param promptName The name of the prompt handler to remove @@ -536,20 +572,8 @@ private McpServerSession.NotificationHandler asyncRootsListChangedNotificationHa // --------------------------------------- /** - * Add a new tool registration at runtime. - * @param toolRegistration The tool registration to add - * @return Mono that completes when clients have been notified of the change - * @deprecated This method will be removed in 0.9.0. Use - * {@link #addTool(McpServerFeatures.AsyncToolSpecification)}. - */ - @Deprecated - public Mono addTool(McpServerFeatures.AsyncToolRegistration toolRegistration) { - return this.addTool(toolRegistration.toSpecification()); - } - - /** - * Add a new tool registration at runtime. - * @param toolSpecification The tool registration to add + * Add a new tool specification at runtime. + * @param toolSpecification The tool specification to add * @return Mono that completes when clients have been notified of the change */ public Mono addTool(McpServerFeatures.AsyncToolSpecification toolSpecification) { @@ -583,6 +607,11 @@ public Mono addTool(McpServerFeatures.AsyncToolSpecification toolSpecifica }); } + @Override + public Mono addTool(McpServerFeatures.AsyncToolRegistration toolRegistration) { + return this.addTool(toolRegistration.toSpecification()); + } + /** * Remove a tool handler at runtime. * @param toolName The name of the tool handler to remove @@ -598,7 +627,7 @@ public Mono removeTool(String toolName) { return Mono.defer(() -> { boolean removed = this.tools - .removeIf(toolRegistration -> toolRegistration.tool().name().equals(toolName)); + .removeIf(toolSpecification -> toolSpecification.tool().name().equals(toolName)); if (removed) { logger.debug("Removed tool handler: {}", toolName); if (this.serverCapabilities.tools().listChanged()) { @@ -649,18 +678,6 @@ private McpServerSession.RequestHandler toolsCallRequestHandler( // Resource Management // --------------------------------------- - /** - * Add a new resource handler at runtime. - * @param resourceHandler The resource handler to add - * @return Mono that completes when clients have been notified of the change - * @deprecated This method will be removed in 0.9.0. Use - * {@link #addResource(McpServerFeatures.AsyncResourceSpecification)}. - */ - @Deprecated - public Mono addResource(McpServerFeatures.AsyncResourceRegistration resourceHandler) { - return this.addResource(resourceHandler.toSpecification()); - } - /** * Add a new resource handler at runtime. * @param resourceSpecification The resource handler to add @@ -688,6 +705,11 @@ public Mono addResource(McpServerFeatures.AsyncResourceSpecification resou }); } + @Override + public Mono addResource(McpServerFeatures.AsyncResourceRegistration resourceHandler) { + return this.addResource(resourceHandler.toSpecification()); + } + /** * Remove a resource handler at runtime. * @param resourceUri The URI of the resource handler to remove @@ -744,9 +766,9 @@ private McpServerSession.RequestHandler resourcesR new TypeReference() { }); var resourceUri = resourceRequest.uri(); - McpServerFeatures.AsyncResourceSpecification registration = this.resources.get(resourceUri); - if (registration != null) { - return registration.readHandler().apply(exchange, resourceRequest); + McpServerFeatures.AsyncResourceSpecification specification = this.resources.get(resourceUri); + if (specification != null) { + return specification.readHandler().apply(exchange, resourceRequest); } return Mono.error(new McpError("Resource not found: " + resourceUri)); }; @@ -756,18 +778,6 @@ private McpServerSession.RequestHandler resourcesR // Prompt Management // --------------------------------------- - /** - * Add a new prompt handler at runtime. - * @param promptRegistration The prompt handler to add - * @return Mono that completes when clients have been notified of the change - * @deprecated This method will be removed in 0.9.0. Use - * {@link #addPrompt(McpServerFeatures.AsyncPromptSpecification)}. - */ - @Deprecated - public Mono addPrompt(McpServerFeatures.AsyncPromptRegistration promptRegistration) { - return this.addPrompt(promptRegistration.toSpecification()); - } - /** * Add a new prompt handler at runtime. * @param promptSpecification The prompt handler to add @@ -775,7 +785,7 @@ public Mono addPrompt(McpServerFeatures.AsyncPromptRegistration promptRegi */ public Mono addPrompt(McpServerFeatures.AsyncPromptSpecification promptSpecification) { if (promptSpecification == null) { - return Mono.error(new McpError("Prompt registration must not be null")); + return Mono.error(new McpError("Prompt specification must not be null")); } if (this.serverCapabilities.prompts() == null) { return Mono.error(new McpError("Server must be configured with prompt capabilities")); @@ -801,6 +811,11 @@ public Mono addPrompt(McpServerFeatures.AsyncPromptSpecification promptSpe }); } + @Override + public Mono addPrompt(McpServerFeatures.AsyncPromptRegistration promptRegistration) { + return this.addPrompt(promptRegistration.toSpecification()); + } + /** * Remove a prompt handler at runtime. * @param promptName The name of the prompt handler to remove @@ -1056,6 +1071,24 @@ private static final class LegacyAsyncServer extends McpAsyncServer { notificationHandlers); } + @Override + public Mono addTool(McpServerFeatures.AsyncToolSpecification toolSpecification) { + throw new IllegalArgumentException( + "McpAsyncServer configured with legacy " + "transport. Use McpServerTransportProvider instead."); + } + + @Override + public Mono addResource(McpServerFeatures.AsyncResourceSpecification resourceHandler) { + throw new IllegalArgumentException( + "McpAsyncServer configured with legacy " + "transport. Use McpServerTransportProvider instead."); + } + + @Override + public Mono addPrompt(McpServerFeatures.AsyncPromptSpecification promptSpecification) { + throw new IllegalArgumentException( + "McpAsyncServer configured with legacy " + "transport. Use McpServerTransportProvider instead."); + } + // --------------------------------------- // Lifecycle Management // --------------------------------------- @@ -1182,6 +1215,7 @@ private DefaultMcpSession.NotificationHandler asyncRootsListChangedNotificationH * @param toolRegistration The tool registration to add * @return Mono that completes when clients have been notified of the change */ + @Override public Mono addTool(McpServerFeatures.AsyncToolRegistration toolRegistration) { if (toolRegistration == null) { return Mono.error(new McpError("Tool registration must not be null")); @@ -1284,6 +1318,7 @@ private DefaultMcpSession.RequestHandler toolsCallRequestHandler * @param resourceHandler The resource handler to add * @return Mono that completes when clients have been notified of the change */ + @Override public Mono addResource(McpServerFeatures.AsyncResourceRegistration resourceHandler) { if (resourceHandler == null || resourceHandler.resource() == null) { return Mono.error(new McpError("Resource must not be null")); @@ -1379,6 +1414,7 @@ private DefaultMcpSession.RequestHandler resources * @param promptRegistration The prompt handler to add * @return Mono that completes when clients have been notified of the change */ + @Override public Mono addPrompt(McpServerFeatures.AsyncPromptRegistration promptRegistration) { if (promptRegistration == null) { return Mono.error(new McpError("Prompt registration must not be null")); diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java index b214848e..bba5b059 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java @@ -87,11 +87,22 @@ public McpSchema.ListRootsResult listRoots(String cursor) { /** * Add a new tool handler. * @param toolHandler The tool handler to add + * @deprecated This method will be removed in 0.9.0. Use + * {@link #addTool(McpServerFeatures.SyncToolSpecification)}. */ + @Deprecated public void addTool(McpServerFeatures.SyncToolRegistration toolHandler) { this.asyncServer.addTool(McpServerFeatures.AsyncToolRegistration.fromSync(toolHandler)).block(); } + /** + * Add a new tool handler. + * @param toolHandler The tool handler to add + */ + public void addTool(McpServerFeatures.SyncToolSpecification toolHandler) { + this.asyncServer.addTool(McpServerFeatures.AsyncToolSpecification.fromSync(toolHandler)).block(); + } + /** * Remove a tool handler. * @param toolName The name of the tool handler to remove @@ -103,11 +114,22 @@ public void removeTool(String toolName) { /** * Add a new resource handler. * @param resourceHandler The resource handler to add + * @deprecated This method will be removed in 0.9.0. Use + * {@link #addResource(McpServerFeatures.SyncResourceSpecification)}. */ + @Deprecated public void addResource(McpServerFeatures.SyncResourceRegistration resourceHandler) { this.asyncServer.addResource(McpServerFeatures.AsyncResourceRegistration.fromSync(resourceHandler)).block(); } + /** + * Add a new resource handler. + * @param resourceHandler The resource handler to add + */ + public void addResource(McpServerFeatures.SyncResourceSpecification resourceHandler) { + this.asyncServer.addResource(McpServerFeatures.AsyncResourceSpecification.fromSync(resourceHandler)).block(); + } + /** * Remove a resource handler. * @param resourceUri The URI of the resource handler to remove @@ -119,11 +141,22 @@ public void removeResource(String resourceUri) { /** * Add a new prompt handler. * @param promptRegistration The prompt registration to add + * @deprecated This method will be removed in 0.9.0. Use + * {@link #addPrompt(McpServerFeatures.SyncPromptSpecification)}. */ + @Deprecated public void addPrompt(McpServerFeatures.SyncPromptRegistration promptRegistration) { this.asyncServer.addPrompt(McpServerFeatures.AsyncPromptRegistration.fromSync(promptRegistration)).block(); } + /** + * Add a new prompt handler. + * @param promptSpecification The prompt specification to add + */ + public void addPrompt(McpServerFeatures.SyncPromptSpecification promptSpecification) { + this.asyncServer.addPrompt(McpServerFeatures.AsyncPromptSpecification.fromSync(promptSpecification)).block(); + } + /** * Remove a prompt handler. * @param promptName The name of the prompt handler to remove From 76c802f38750daed86f252494d24a3f327816e61 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Sun, 16 Mar 2025 17:21:00 +0100 Subject: [PATCH 10/20] fix: update sync tool/resource/prompt handlers to use McpSyncServerExchange This commit changes the parameter type in sync handler interfaces from McpAsyncServerExchange to McpSyncServerExchange for better API consistency. It also adds proper wrapping of async exchanges in sync exchanges when bridging between async and sync contexts in the adapter methods. Signed-off-by: Christian Tzolov --- .../io/modelcontextprotocol/server/McpServer.java | 2 +- .../server/McpServerFeatures.java | 15 +++++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java index 7c4eb6dc..1b78d5a7 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java @@ -668,7 +668,7 @@ public SyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabil * @throws IllegalArgumentException if tool or handler is null */ public SyncSpecification tool(McpSchema.Tool tool, - BiFunction, McpSchema.CallToolResult> handler) { + BiFunction, McpSchema.CallToolResult> handler) { Assert.notNull(tool, "Tool must not be null"); Assert.notNull(handler, "Handler must not be null"); diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java index d3c9ea63..2ba12f2a 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java @@ -220,7 +220,8 @@ static AsyncToolSpecification fromSync(SyncToolSpecification tool) { return null; } return new AsyncToolSpecification(tool.tool(), - (exchange, map) -> Mono.fromCallable(() -> tool.call().apply(exchange, map)) + (exchange, map) -> Mono + .fromCallable(() -> tool.call().apply(new McpSyncServerExchange(exchange), map)) .subscribeOn(Schedulers.boundedElastic())); } } @@ -259,7 +260,8 @@ static AsyncResourceSpecification fromSync(SyncResourceSpecification resource) { return null; } return new AsyncResourceSpecification(resource.resource(), - (exchange, req) -> Mono.fromCallable(() -> resource.readHandler().apply(exchange, req)) + (exchange, req) -> Mono + .fromCallable(() -> resource.readHandler().apply(new McpSyncServerExchange(exchange), req)) .subscribeOn(Schedulers.boundedElastic())); } } @@ -301,7 +303,8 @@ static AsyncPromptSpecification fromSync(SyncPromptSpecification prompt) { return null; } return new AsyncPromptSpecification(prompt.prompt(), - (exchange, req) -> Mono.fromCallable(() -> prompt.promptHandler().apply(exchange, req)) + (exchange, req) -> Mono + .fromCallable(() -> prompt.promptHandler().apply(new McpSyncServerExchange(exchange), req)) .subscribeOn(Schedulers.boundedElastic())); } } @@ -340,7 +343,7 @@ static AsyncPromptSpecification fromSync(SyncPromptSpecification prompt) { * returning results */ public record SyncToolSpecification(McpSchema.Tool tool, - BiFunction, McpSchema.CallToolResult> call) { + BiFunction, McpSchema.CallToolResult> call) { } /** @@ -369,7 +372,7 @@ public record SyncToolSpecification(McpSchema.Tool tool, * @param readHandler The function that handles resource read requests */ public record SyncResourceSpecification(McpSchema.Resource resource, - BiFunction readHandler) { + BiFunction readHandler) { } /** @@ -401,7 +404,7 @@ public record SyncResourceSpecification(McpSchema.Resource resource, * formatted templates */ public record SyncPromptSpecification(McpSchema.Prompt prompt, - BiFunction promptHandler) { + BiFunction promptHandler) { } // --------------------------------------- From 5015ca6b1034c7051701078a5196d20c292188f2 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Mon, 17 Mar 2025 12:01:52 +0100 Subject: [PATCH 11/20] refactor: Add StdioServerTransportProvider with reactive streams Add StdioServerTransportProvider implementation that uses reactive streams for both inbound and outbound message processing. - Using Flux for asynchronous message handling - Implementing separate inbound and outbound processing pipelines - Improving error handling with proper propagation Signed-off-by: Christian Tzolov --- ...bFluxSseMcpAsyncServerDeprecatedTests.java | 55 +++ .../server/WebFluxSseMcpAsyncServerTests.java | 15 +- ...ebFluxSseMcpSyncServerDeprecatecTests.java | 55 +++ .../server/WebFluxSseMcpSyncServerTests.java | 16 +- .../WebMvcSseAsyncServerTransportTests.java | 2 +- .../WebMvcSseSyncServerTransportTests.java | 2 +- ...AbstractMcpAsyncServerDeprecatedTests.java | 465 +++++++++++++++++ .../server/AbstractMcpAsyncServerTests.java | 126 ++--- .../AbstractMcpSyncServerDeprecatedTests.java | 431 ++++++++++++++++ .../server/AbstractMcpSyncServerTests.java | 131 ++--- .../server/McpServer.java | 4 - .../StdioServerTransportProvider.java | 306 ++++++++++++ ...AbstractMcpAsyncServerDeprecatedTests.java | 466 ++++++++++++++++++ .../server/AbstractMcpAsyncServerTests.java | 127 ++--- .../AbstractMcpSyncServerDeprecatedTests.java | 433 ++++++++++++++++ .../server/AbstractMcpSyncServerTests.java | 131 ++--- .../server/ServletSseMcpAsyncServerTests.java | 2 +- .../server/ServletSseMcpSyncServerTests.java | 2 +- .../StdioMcpAsyncServerDeprecatedTests.java | 25 + .../server/StdioMcpAsyncServerTests.java | 7 +- .../StdioMcpSyncServerDeprecatedTests.java | 25 + .../server/StdioMcpSyncServerTests.java | 10 +- .../server/transport/BlockingInputStream.java | 69 --- .../StdioServerTransportProviderTests.java | 227 +++++++++ 24 files changed, 2784 insertions(+), 348 deletions(-) create mode 100644 mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerDeprecatedTests.java create mode 100644 mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerDeprecatecTests.java create mode 100644 mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerDeprecatedTests.java create mode 100644 mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerDeprecatedTests.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerDeprecatedTests.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerDeprecatedTests.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerDeprecatedTests.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerDeprecatedTests.java delete mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/transport/BlockingInputStream.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerDeprecatedTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerDeprecatedTests.java new file mode 100644 index 00000000..b460284e --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerDeprecatedTests.java @@ -0,0 +1,55 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransport; +import io.modelcontextprotocol.spec.ServerMcpTransport; +import org.junit.jupiter.api.Timeout; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.reactive.function.server.RouterFunctions; + +/** + * Tests for {@link McpAsyncServer} using {@link WebFluxSseServerTransport}. + * + * @author Christian Tzolov + */ +@Deprecated +@Timeout(15) // Giving extra time beyond the client timeout +class WebFluxSseMcpAsyncServerDeprecatedTests extends AbstractMcpAsyncServerDeprecatedTests { + + private static final int PORT = 8181; + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private DisposableServer httpServer; + + @Override + protected ServerMcpTransport createMcpTransport() { + var transport = new WebFluxSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); + + HttpHandler httpHandler = RouterFunctions.toHttpHandler(transport.getRouterFunction()); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + + return transport; + } + + @Override + protected void onStart() { + } + + @Override + protected void onClose() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java index 1ed0d99b..5fa787ab 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java @@ -5,8 +5,8 @@ package io.modelcontextprotocol.server; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.server.transport.WebFluxSseServerTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.Timeout; import reactor.netty.DisposableServer; import reactor.netty.http.server.HttpServer; @@ -16,7 +16,7 @@ import org.springframework.web.reactive.function.server.RouterFunctions; /** - * Tests for {@link McpAsyncServer} using {@link WebFluxSseServerTransport}. + * Tests for {@link McpAsyncServer} using {@link WebFluxSseServerTransportProvider}. * * @author Christian Tzolov */ @@ -30,14 +30,13 @@ class WebFluxSseMcpAsyncServerTests extends AbstractMcpAsyncServerTests { private DisposableServer httpServer; @Override - protected ServerMcpTransport createMcpTransport() { - var transport = new WebFluxSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); + protected McpServerTransportProvider createMcpTransportProvider() { + var transportProvider = new WebFluxSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT); - HttpHandler httpHandler = RouterFunctions.toHttpHandler(transport.getRouterFunction()); + HttpHandler httpHandler = RouterFunctions.toHttpHandler(transportProvider.getRouterFunction()); ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); - - return transport; + return transportProvider; } @Override diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerDeprecatecTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerDeprecatecTests.java new file mode 100644 index 00000000..be2bf6c7 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerDeprecatecTests.java @@ -0,0 +1,55 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransport; +import io.modelcontextprotocol.spec.ServerMcpTransport; +import org.junit.jupiter.api.Timeout; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.reactive.function.server.RouterFunctions; + +/** + * Tests for {@link McpSyncServer} using {@link WebFluxSseServerTransport}. + * + * @author Christian Tzolov + */ +@Deprecated +@Timeout(15) // Giving extra time beyond the client timeout +class WebFluxSseMcpSyncServerDeprecatecTests extends AbstractMcpSyncServerDeprecatedTests { + + private static final int PORT = 8182; + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private DisposableServer httpServer; + + private WebFluxSseServerTransport transport; + + @Override + protected ServerMcpTransport createMcpTransport() { + transport = new WebFluxSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); + return transport; + } + + @Override + protected void onStart() { + HttpHandler httpHandler = RouterFunctions.toHttpHandler(transport.getRouterFunction()); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + } + + @Override + protected void onClose() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java index 4db00dd4..d3672e3f 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java @@ -5,8 +5,8 @@ package io.modelcontextprotocol.server; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.server.transport.WebFluxSseServerTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.Timeout; import reactor.netty.DisposableServer; import reactor.netty.http.server.HttpServer; @@ -16,7 +16,7 @@ import org.springframework.web.reactive.function.server.RouterFunctions; /** - * Tests for {@link McpSyncServer} using {@link WebFluxSseServerTransport}. + * Tests for {@link McpSyncServer} using {@link WebFluxSseServerTransportProvider}. * * @author Christian Tzolov */ @@ -29,17 +29,17 @@ class WebFluxSseMcpSyncServerTests extends AbstractMcpSyncServerTests { private DisposableServer httpServer; - private WebFluxSseServerTransport transport; + private WebFluxSseServerTransportProvider transportProvider; @Override - protected ServerMcpTransport createMcpTransport() { - transport = new WebFluxSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); - return transport; + protected McpServerTransportProvider createMcpTransportProvider() { + transportProvider = new WebFluxSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT); + return transportProvider; } @Override protected void onStart() { - HttpHandler httpHandler = RouterFunctions.toHttpHandler(transport.getRouterFunction()); + HttpHandler httpHandler = RouterFunctions.toHttpHandler(transportProvider.getRouterFunction()); ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); } diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java index a819920c..00649f0f 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java @@ -21,7 +21,7 @@ import org.springframework.web.servlet.function.ServerResponse; @Timeout(15) -class WebMvcSseAsyncServerTransportTests extends AbstractMcpAsyncServerTests { +class WebMvcSseAsyncServerTransportTests extends AbstractMcpAsyncServerDeprecatedTests { private static final String MESSAGE_ENDPOINT = "/mcp/message"; diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java index 249b4dea..be843a87 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java @@ -21,7 +21,7 @@ import org.springframework.web.servlet.function.ServerResponse; @Timeout(15) -class WebMvcSseSyncServerTransportTests extends AbstractMcpSyncServerTests { +class WebMvcSseSyncServerTransportTests extends AbstractMcpSyncServerDeprecatedTests { private static final String MESSAGE_ENDPOINT = "/mcp/message"; diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerDeprecatedTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerDeprecatedTests.java new file mode 100644 index 00000000..005d78f2 --- /dev/null +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerDeprecatedTests.java @@ -0,0 +1,465 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.time.Duration; +import java.util.List; + +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptMessage; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.Resource; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpTransport; +import io.modelcontextprotocol.spec.ServerMcpTransport; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Test suite for the {@link McpAsyncServer} that can be used with different + * {@link McpTransport} implementations. + * + * @author Christian Tzolov + */ +@Deprecated +public abstract class AbstractMcpAsyncServerDeprecatedTests { + + private static final String TEST_TOOL_NAME = "test-tool"; + + private static final String TEST_RESOURCE_URI = "test://resource"; + + private static final String TEST_PROMPT_NAME = "test-prompt"; + + abstract protected ServerMcpTransport createMcpTransport(); + + protected void onStart() { + } + + protected void onClose() { + } + + @BeforeEach + void setUp() { + } + + @AfterEach + void tearDown() { + onClose(); + } + + // --------------------------------------- + // Server Lifecycle Tests + // --------------------------------------- + + @Test + void testConstructorWithInvalidArguments() { + assertThatThrownBy(() -> McpServer.async((ServerMcpTransport) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Transport must not be null"); + + assertThatThrownBy(() -> McpServer.async(createMcpTransport()).serverInfo((McpSchema.Implementation) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Server info must not be null"); + } + + @Test + void testGracefulShutdown() { + var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + StepVerifier.create(mcpAsyncServer.closeGracefully()).verifyComplete(); + } + + @Test + void testImmediateClose() { + var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + assertThatCode(() -> mcpAsyncServer.close()).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Tools Tests + // --------------------------------------- + String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + + @Test + void testAddTool() { + Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolRegistration(newTool, + args -> Mono.just(new CallToolResult(List.of(), false))))) + .verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddDuplicateTool() { + Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tool(duplicateTool, args -> Mono.just(new CallToolResult(List.of(), false))) + .build(); + + StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolRegistration(duplicateTool, + args -> Mono.just(new CallToolResult(List.of(), false))))) + .verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); + }); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testRemoveTool() { + Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tool(too, args -> Mono.just(new CallToolResult(List.of(), false))) + .build(); + + StepVerifier.create(mcpAsyncServer.removeTool(TEST_TOOL_NAME)).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentTool() { + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + StepVerifier.create(mcpAsyncServer.removeTool("nonexistent-tool")).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class).hasMessage("Tool with name 'nonexistent-tool' not found"); + }); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testNotifyToolsListChanged() { + Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tool(too, args -> Mono.just(new CallToolResult(List.of(), false))) + .build(); + + StepVerifier.create(mcpAsyncServer.notifyToolsListChanged()).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Resources Tests + // --------------------------------------- + + @Test + void testNotifyResourcesListChanged() { + var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + StepVerifier.create(mcpAsyncServer.notifyResourcesListChanged()).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddResource() { + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", + null); + McpServerFeatures.AsyncResourceRegistration registration = new McpServerFeatures.AsyncResourceRegistration( + resource, req -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier.create(mcpAsyncServer.addResource(registration)).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddResourceWithNullRegistration() { + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + StepVerifier.create(mcpAsyncServer.addResource((McpServerFeatures.AsyncResourceRegistration) null)) + .verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class).hasMessage("Resource must not be null"); + }); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddResourceWithoutCapability() { + // Create a server without resource capabilities + McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .build(); + + Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", + null); + McpServerFeatures.AsyncResourceRegistration registration = new McpServerFeatures.AsyncResourceRegistration( + resource, req -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier.create(serverWithoutResources.addResource(registration)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with resource capabilities"); + }); + } + + @Test + void testRemoveResourceWithoutCapability() { + // Create a server without resource capabilities + McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .build(); + + StepVerifier.create(serverWithoutResources.removeResource(TEST_RESOURCE_URI)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with resource capabilities"); + }); + } + + // --------------------------------------- + // Prompts Tests + // --------------------------------------- + + @Test + void testNotifyPromptsListChanged() { + var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + StepVerifier.create(mcpAsyncServer.notifyPromptsListChanged()).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddPromptWithNullRegistration() { + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(false).build()) + .build(); + + StepVerifier.create(mcpAsyncServer.addPrompt((McpServerFeatures.AsyncPromptRegistration) null)) + .verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class).hasMessage("Prompt registration must not be null"); + }); + } + + @Test + void testAddPromptWithoutCapability() { + // Create a server without prompt capabilities + McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .build(); + + Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); + McpServerFeatures.AsyncPromptRegistration registration = new McpServerFeatures.AsyncPromptRegistration(prompt, + req -> Mono.just(new GetPromptResult("Test prompt description", List + .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); + + StepVerifier.create(serverWithoutPrompts.addPrompt(registration)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with prompt capabilities"); + }); + } + + @Test + void testRemovePromptWithoutCapability() { + // Create a server without prompt capabilities + McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .build(); + + StepVerifier.create(serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with prompt capabilities"); + }); + } + + @Test + void testRemovePrompt() { + String TEST_PROMPT_NAME_TO_REMOVE = "TEST_PROMPT_NAME678"; + + Prompt prompt = new Prompt(TEST_PROMPT_NAME_TO_REMOVE, "Test Prompt", List.of()); + McpServerFeatures.AsyncPromptRegistration registration = new McpServerFeatures.AsyncPromptRegistration(prompt, + req -> Mono.just(new GetPromptResult("Test prompt description", List + .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); + + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(true).build()) + .prompts(registration) + .build(); + + StepVerifier.create(mcpAsyncServer.removePrompt(TEST_PROMPT_NAME_TO_REMOVE)).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentPrompt() { + var mcpAsyncServer2 = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(true).build()) + .build(); + + StepVerifier.create(mcpAsyncServer2.removePrompt("nonexistent-prompt")).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Prompt with name 'nonexistent-prompt' not found"); + }); + + assertThatCode(() -> mcpAsyncServer2.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + } + + // --------------------------------------- + // Roots Tests + // --------------------------------------- + + @Test + void testRootsChangeConsumers() { + // Test with single consumer + var rootsReceived = new McpSchema.Root[1]; + var consumerCalled = new boolean[1]; + + var singleConsumerServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .rootsChangeConsumers(List.of(roots -> Mono.fromRunnable(() -> { + consumerCalled[0] = true; + if (!roots.isEmpty()) { + rootsReceived[0] = roots.get(0); + } + }))) + .build(); + + assertThat(singleConsumerServer).isNotNull(); + assertThatCode(() -> singleConsumerServer.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + onClose(); + + // Test with multiple consumers + var consumer1Called = new boolean[1]; + var consumer2Called = new boolean[1]; + var rootsContent = new List[1]; + + var multipleConsumersServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .rootsChangeConsumers(List.of(roots -> Mono.fromRunnable(() -> { + consumer1Called[0] = true; + rootsContent[0] = roots; + }), roots -> Mono.fromRunnable(() -> consumer2Called[0] = true))) + .build(); + + assertThat(multipleConsumersServer).isNotNull(); + assertThatCode(() -> multipleConsumersServer.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + onClose(); + + // Test error handling + var errorHandlingServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .rootsChangeConsumers(List.of(roots -> { + throw new RuntimeException("Test error"); + })) + .build(); + + assertThat(errorHandlingServer).isNotNull(); + assertThatCode(() -> errorHandlingServer.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + onClose(); + + // Test without consumers + var noConsumersServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + assertThat(noConsumersServer).isNotNull(); + assertThatCode(() -> noConsumersServer.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + } + + // --------------------------------------- + // Logging Tests + // --------------------------------------- + + @Test + void testLoggingLevels() { + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().logging().build()) + .build(); + + // Test all logging levels + for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { + var notification = McpSchema.LoggingMessageNotification.builder() + .level(level) + .logger("test-logger") + .data("Test message with level " + level) + .build(); + + StepVerifier.create(mcpAsyncServer.loggingNotification(notification)).verifyComplete(); + } + } + + @Test + void testLoggingWithoutCapability() { + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().build()) // No logging capability + .build(); + + var notification = McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.INFO) + .logger("test-logger") + .data("Test log message") + .build(); + + StepVerifier.create(mcpAsyncServer.loggingNotification(notification)).verifyComplete(); + } + + @Test + void testLoggingWithNullNotification() { + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().logging().build()) + .build(); + + StepVerifier.create(mcpAsyncServer.loggingNotification(null)).verifyError(McpError.class); + } + +} diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java index 725a2167..7bcb9a8b 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java @@ -17,8 +17,7 @@ import io.modelcontextprotocol.spec.McpSchema.Resource; import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; import io.modelcontextprotocol.spec.McpSchema.Tool; -import io.modelcontextprotocol.spec.McpTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -31,10 +30,11 @@ /** * Test suite for the {@link McpAsyncServer} that can be used with different - * {@link McpTransport} implementations. + * {@link McpTransportProvider} implementations. * * @author Christian Tzolov */ +// KEEP IN SYNC with the class in mcp-test module public abstract class AbstractMcpAsyncServerTests { private static final String TEST_TOOL_NAME = "test-tool"; @@ -43,7 +43,7 @@ public abstract class AbstractMcpAsyncServerTests { private static final String TEST_PROMPT_NAME = "test-prompt"; - abstract protected ServerMcpTransport createMcpTransport(); + abstract protected McpServerTransportProvider createMcpTransportProvider(); protected void onStart() { } @@ -66,25 +66,26 @@ void tearDown() { @Test void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> McpServer.async((ServerMcpTransport) null)) + assertThatThrownBy(() -> McpServer.async((McpServerTransportProvider) null)) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Transport must not be null"); + .hasMessage("Transport provider must not be null"); - assertThatThrownBy(() -> McpServer.async(createMcpTransport()).serverInfo((McpSchema.Implementation) null)) + assertThatThrownBy( + () -> McpServer.async(createMcpTransportProvider()).serverInfo((McpSchema.Implementation) null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Server info must not be null"); } @Test void testGracefulShutdown() { - var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); StepVerifier.create(mcpAsyncServer.closeGracefully()).verifyComplete(); } @Test void testImmediateClose() { - var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpAsyncServer.close()).doesNotThrowAnyException(); } @@ -103,13 +104,13 @@ void testImmediateClose() { @Test void testAddTool() { Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransport()) + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); - StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolRegistration(newTool, - args -> Mono.just(new CallToolResult(List.of(), false))))) + StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(newTool, + (excnage, args) -> Mono.just(new CallToolResult(List.of(), false))))) .verifyComplete(); assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); @@ -119,14 +120,15 @@ void testAddTool() { void testAddDuplicateTool() { Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransport()) + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(duplicateTool, args -> Mono.just(new CallToolResult(List.of(), false))) + .tool(duplicateTool, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) .build(); - StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolRegistration(duplicateTool, - args -> Mono.just(new CallToolResult(List.of(), false))))) + StepVerifier + .create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(duplicateTool, + (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))))) .verifyErrorSatisfies(error -> { assertThat(error).isInstanceOf(McpError.class) .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); @@ -139,10 +141,10 @@ void testAddDuplicateTool() { void testRemoveTool() { Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransport()) + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(too, args -> Mono.just(new CallToolResult(List.of(), false))) + .tool(too, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) .build(); StepVerifier.create(mcpAsyncServer.removeTool(TEST_TOOL_NAME)).verifyComplete(); @@ -152,7 +154,7 @@ void testRemoveTool() { @Test void testRemoveNonexistentTool() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); @@ -168,10 +170,10 @@ void testRemoveNonexistentTool() { void testNotifyToolsListChanged() { Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransport()) + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(too, args -> Mono.just(new CallToolResult(List.of(), false))) + .tool(too, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) .build(); StepVerifier.create(mcpAsyncServer.notifyToolsListChanged()).verifyComplete(); @@ -185,7 +187,7 @@ void testNotifyToolsListChanged() { @Test void testNotifyResourcesListChanged() { - var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); StepVerifier.create(mcpAsyncServer.notifyResourcesListChanged()).verifyComplete(); @@ -194,29 +196,29 @@ void testNotifyResourcesListChanged() { @Test void testAddResource() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", null); - McpServerFeatures.AsyncResourceRegistration registration = new McpServerFeatures.AsyncResourceRegistration( - resource, req -> Mono.just(new ReadResourceResult(List.of()))); + McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification( + resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); - StepVerifier.create(mcpAsyncServer.addResource(registration)).verifyComplete(); + StepVerifier.create(mcpAsyncServer.addResource(specification)).verifyComplete(); assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); } @Test - void testAddResourceWithNullRegistration() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) + void testAddResourceWithNullSpecification() { + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); - StepVerifier.create(mcpAsyncServer.addResource((McpServerFeatures.AsyncResourceRegistration) null)) + StepVerifier.create(mcpAsyncServer.addResource((McpServerFeatures.AsyncResourceSpecification) null)) .verifyErrorSatisfies(error -> { assertThat(error).isInstanceOf(McpError.class).hasMessage("Resource must not be null"); }); @@ -227,16 +229,16 @@ void testAddResourceWithNullRegistration() { @Test void testAddResourceWithoutCapability() { // Create a server without resource capabilities - McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransport()) + McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .build(); Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", null); - McpServerFeatures.AsyncResourceRegistration registration = new McpServerFeatures.AsyncResourceRegistration( - resource, req -> Mono.just(new ReadResourceResult(List.of()))); + McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification( + resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); - StepVerifier.create(serverWithoutResources.addResource(registration)).verifyErrorSatisfies(error -> { + StepVerifier.create(serverWithoutResources.addResource(specification)).verifyErrorSatisfies(error -> { assertThat(error).isInstanceOf(McpError.class) .hasMessage("Server must be configured with resource capabilities"); }); @@ -245,7 +247,7 @@ void testAddResourceWithoutCapability() { @Test void testRemoveResourceWithoutCapability() { // Create a server without resource capabilities - McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransport()) + McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .build(); @@ -261,7 +263,7 @@ void testRemoveResourceWithoutCapability() { @Test void testNotifyPromptsListChanged() { - var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); StepVerifier.create(mcpAsyncServer.notifyPromptsListChanged()).verifyComplete(); @@ -269,31 +271,31 @@ void testNotifyPromptsListChanged() { } @Test - void testAddPromptWithNullRegistration() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) + void testAddPromptWithNullSpecification() { + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(false).build()) .build(); - StepVerifier.create(mcpAsyncServer.addPrompt((McpServerFeatures.AsyncPromptRegistration) null)) + StepVerifier.create(mcpAsyncServer.addPrompt((McpServerFeatures.AsyncPromptSpecification) null)) .verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class).hasMessage("Prompt registration must not be null"); + assertThat(error).isInstanceOf(McpError.class).hasMessage("Prompt specification must not be null"); }); } @Test void testAddPromptWithoutCapability() { // Create a server without prompt capabilities - McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransport()) + McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .build(); Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); - McpServerFeatures.AsyncPromptRegistration registration = new McpServerFeatures.AsyncPromptRegistration(prompt, - req -> Mono.just(new GetPromptResult("Test prompt description", List + McpServerFeatures.AsyncPromptSpecification specification = new McpServerFeatures.AsyncPromptSpecification( + prompt, (exchange, req) -> Mono.just(new GetPromptResult("Test prompt description", List .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); - StepVerifier.create(serverWithoutPrompts.addPrompt(registration)).verifyErrorSatisfies(error -> { + StepVerifier.create(serverWithoutPrompts.addPrompt(specification)).verifyErrorSatisfies(error -> { assertThat(error).isInstanceOf(McpError.class) .hasMessage("Server must be configured with prompt capabilities"); }); @@ -302,7 +304,7 @@ void testAddPromptWithoutCapability() { @Test void testRemovePromptWithoutCapability() { // Create a server without prompt capabilities - McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransport()) + McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .build(); @@ -317,14 +319,14 @@ void testRemovePrompt() { String TEST_PROMPT_NAME_TO_REMOVE = "TEST_PROMPT_NAME678"; Prompt prompt = new Prompt(TEST_PROMPT_NAME_TO_REMOVE, "Test Prompt", List.of()); - McpServerFeatures.AsyncPromptRegistration registration = new McpServerFeatures.AsyncPromptRegistration(prompt, - req -> Mono.just(new GetPromptResult("Test prompt description", List + McpServerFeatures.AsyncPromptSpecification specification = new McpServerFeatures.AsyncPromptSpecification( + prompt, (exchange, req) -> Mono.just(new GetPromptResult("Test prompt description", List .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); - var mcpAsyncServer = McpServer.async(createMcpTransport()) + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) - .prompts(registration) + .prompts(specification) .build(); StepVerifier.create(mcpAsyncServer.removePrompt(TEST_PROMPT_NAME_TO_REMOVE)).verifyComplete(); @@ -334,7 +336,7 @@ void testRemovePrompt() { @Test void testRemoveNonexistentPrompt() { - var mcpAsyncServer2 = McpServer.async(createMcpTransport()) + var mcpAsyncServer2 = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) .build(); @@ -353,14 +355,14 @@ void testRemoveNonexistentPrompt() { // --------------------------------------- @Test - void testRootsChangeConsumers() { + void testRootsChangeHandlers() { // Test with single consumer var rootsReceived = new McpSchema.Root[1]; var consumerCalled = new boolean[1]; - var singleConsumerServer = McpServer.async(createMcpTransport()) + var singleConsumerServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> Mono.fromRunnable(() -> { + .rootsChangeHandlers(List.of((exchange, roots) -> Mono.fromRunnable(() -> { consumerCalled[0] = true; if (!roots.isEmpty()) { rootsReceived[0] = roots.get(0); @@ -378,12 +380,12 @@ void testRootsChangeConsumers() { var consumer2Called = new boolean[1]; var rootsContent = new List[1]; - var multipleConsumersServer = McpServer.async(createMcpTransport()) + var multipleConsumersServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> Mono.fromRunnable(() -> { + .rootsChangeHandlers(List.of((exchange, roots) -> Mono.fromRunnable(() -> { consumer1Called[0] = true; rootsContent[0] = roots; - }), roots -> Mono.fromRunnable(() -> consumer2Called[0] = true))) + }), (exchange, roots) -> Mono.fromRunnable(() -> consumer2Called[0] = true))) .build(); assertThat(multipleConsumersServer).isNotNull(); @@ -392,9 +394,9 @@ void testRootsChangeConsumers() { onClose(); // Test error handling - var errorHandlingServer = McpServer.async(createMcpTransport()) + var errorHandlingServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> { + .rootsChangeHandlers(List.of((exchange, roots) -> { throw new RuntimeException("Test error"); })) .build(); @@ -405,7 +407,9 @@ void testRootsChangeConsumers() { onClose(); // Test without consumers - var noConsumersServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var noConsumersServer = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .build(); assertThat(noConsumersServer).isNotNull(); assertThatCode(() -> noConsumersServer.closeGracefully().block(Duration.ofSeconds(10))) @@ -418,7 +422,7 @@ void testRootsChangeConsumers() { @Test void testLoggingLevels() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().logging().build()) .build(); @@ -437,7 +441,7 @@ void testLoggingLevels() { @Test void testLoggingWithoutCapability() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().build()) // No logging capability .build(); @@ -453,7 +457,7 @@ void testLoggingWithoutCapability() { @Test void testLoggingWithNullNotification() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().logging().build()) .build(); diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerDeprecatedTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerDeprecatedTests.java new file mode 100644 index 00000000..c6625aca --- /dev/null +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerDeprecatedTests.java @@ -0,0 +1,431 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.util.List; + +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptMessage; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.Resource; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpTransport; +import io.modelcontextprotocol.spec.ServerMcpTransport; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Test suite for the {@link McpSyncServer} that can be used with different + * {@link McpTransport} implementations. + * + * @author Christian Tzolov + */ +public abstract class AbstractMcpSyncServerDeprecatedTests { + + private static final String TEST_TOOL_NAME = "test-tool"; + + private static final String TEST_RESOURCE_URI = "test://resource"; + + private static final String TEST_PROMPT_NAME = "test-prompt"; + + abstract protected ServerMcpTransport createMcpTransport(); + + protected void onStart() { + } + + protected void onClose() { + } + + @BeforeEach + void setUp() { + // onStart(); + } + + @AfterEach + void tearDown() { + onClose(); + } + + // --------------------------------------- + // Server Lifecycle Tests + // --------------------------------------- + + @Test + void testConstructorWithInvalidArguments() { + assertThatThrownBy(() -> McpServer.sync((ServerMcpTransport) null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Transport must not be null"); + + assertThatThrownBy(() -> McpServer.sync(createMcpTransport()).serverInfo(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Server info must not be null"); + } + + @Test + void testGracefulShutdown() { + var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testImmediateClose() { + var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + assertThatCode(() -> mcpSyncServer.close()).doesNotThrowAnyException(); + } + + @Test + void testGetAsyncServer() { + var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + assertThat(mcpSyncServer.getAsyncServer()).isNotNull(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Tools Tests + // --------------------------------------- + + String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + + @Test + void testAddTool() { + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); + assertThatCode(() -> mcpSyncServer + .addTool(new McpServerFeatures.SyncToolRegistration(newTool, args -> new CallToolResult(List.of(), false)))) + .doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testAddDuplicateTool() { + Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tool(duplicateTool, args -> new CallToolResult(List.of(), false)) + .build(); + + assertThatThrownBy(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolRegistration(duplicateTool, + args -> new CallToolResult(List.of(), false)))) + .isInstanceOf(McpError.class) + .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testRemoveTool() { + Tool tool = new McpSchema.Tool(TEST_TOOL_NAME, "Test tool", emptyJsonSchema); + + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tool(tool, args -> new CallToolResult(List.of(), false)) + .build(); + + assertThatCode(() -> mcpSyncServer.removeTool(TEST_TOOL_NAME)).doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentTool() { + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + assertThatThrownBy(() -> mcpSyncServer.removeTool("nonexistent-tool")).isInstanceOf(McpError.class) + .hasMessage("Tool with name 'nonexistent-tool' not found"); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testNotifyToolsListChanged() { + var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + assertThatCode(() -> mcpSyncServer.notifyToolsListChanged()).doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Resources Tests + // --------------------------------------- + + @Test + void testNotifyResourcesListChanged() { + var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + assertThatCode(() -> mcpSyncServer.notifyResourcesListChanged()).doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testAddResource() { + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", + null); + McpServerFeatures.SyncResourceRegistration registration = new McpServerFeatures.SyncResourceRegistration( + resource, req -> new ReadResourceResult(List.of())); + + assertThatCode(() -> mcpSyncServer.addResource(registration)).doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testAddResourceWithNullRegistration() { + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + assertThatThrownBy(() -> mcpSyncServer.addResource((McpServerFeatures.SyncResourceRegistration) null)) + .isInstanceOf(McpError.class) + .hasMessage("Resource must not be null"); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testAddResourceWithoutCapability() { + var serverWithoutResources = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", + null); + McpServerFeatures.SyncResourceRegistration registration = new McpServerFeatures.SyncResourceRegistration( + resource, req -> new ReadResourceResult(List.of())); + + assertThatThrownBy(() -> serverWithoutResources.addResource(registration)).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with resource capabilities"); + } + + @Test + void testRemoveResourceWithoutCapability() { + var serverWithoutResources = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + assertThatThrownBy(() -> serverWithoutResources.removeResource(TEST_RESOURCE_URI)).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with resource capabilities"); + } + + // --------------------------------------- + // Prompts Tests + // --------------------------------------- + + @Test + void testNotifyPromptsListChanged() { + var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + assertThatCode(() -> mcpSyncServer.notifyPromptsListChanged()).doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testAddPromptWithNullRegistration() { + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(false).build()) + .build(); + + assertThatThrownBy(() -> mcpSyncServer.addPrompt((McpServerFeatures.SyncPromptRegistration) null)) + .isInstanceOf(McpError.class) + .hasMessage("Prompt registration must not be null"); + } + + @Test + void testAddPromptWithoutCapability() { + var serverWithoutPrompts = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); + McpServerFeatures.SyncPromptRegistration registration = new McpServerFeatures.SyncPromptRegistration(prompt, + req -> new GetPromptResult("Test prompt description", List + .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); + + assertThatThrownBy(() -> serverWithoutPrompts.addPrompt(registration)).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with prompt capabilities"); + } + + @Test + void testRemovePromptWithoutCapability() { + var serverWithoutPrompts = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + assertThatThrownBy(() -> serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with prompt capabilities"); + } + + @Test + void testRemovePrompt() { + Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); + McpServerFeatures.SyncPromptRegistration registration = new McpServerFeatures.SyncPromptRegistration(prompt, + req -> new GetPromptResult("Test prompt description", List + .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); + + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(true).build()) + .prompts(registration) + .build(); + + assertThatCode(() -> mcpSyncServer.removePrompt(TEST_PROMPT_NAME)).doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentPrompt() { + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(true).build()) + .build(); + + assertThatThrownBy(() -> mcpSyncServer.removePrompt("nonexistent-prompt")).isInstanceOf(McpError.class) + .hasMessage("Prompt with name 'nonexistent-prompt' not found"); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Roots Tests + // --------------------------------------- + + @Test + void testRootsChangeConsumers() { + // Test with single consumer + var rootsReceived = new McpSchema.Root[1]; + var consumerCalled = new boolean[1]; + + var singleConsumerServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .rootsChangeConsumers(List.of(roots -> { + consumerCalled[0] = true; + if (!roots.isEmpty()) { + rootsReceived[0] = roots.get(0); + } + })) + .build(); + + assertThat(singleConsumerServer).isNotNull(); + assertThatCode(() -> singleConsumerServer.closeGracefully()).doesNotThrowAnyException(); + onClose(); + + // Test with multiple consumers + var consumer1Called = new boolean[1]; + var consumer2Called = new boolean[1]; + var rootsContent = new List[1]; + + var multipleConsumersServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .rootsChangeConsumers(List.of(roots -> { + consumer1Called[0] = true; + rootsContent[0] = roots; + }, roots -> consumer2Called[0] = true)) + .build(); + + assertThat(multipleConsumersServer).isNotNull(); + assertThatCode(() -> multipleConsumersServer.closeGracefully()).doesNotThrowAnyException(); + onClose(); + + // Test error handling + var errorHandlingServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .rootsChangeConsumers(List.of(roots -> { + throw new RuntimeException("Test error"); + })) + .build(); + + assertThat(errorHandlingServer).isNotNull(); + assertThatCode(() -> errorHandlingServer.closeGracefully()).doesNotThrowAnyException(); + onClose(); + + // Test without consumers + var noConsumersServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + assertThat(noConsumersServer).isNotNull(); + assertThatCode(() -> noConsumersServer.closeGracefully()).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Logging Tests + // --------------------------------------- + + @Test + void testLoggingLevels() { + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().logging().build()) + .build(); + + // Test all logging levels + for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { + var notification = McpSchema.LoggingMessageNotification.builder() + .level(level) + .logger("test-logger") + .data("Test message with level " + level) + .build(); + + assertThatCode(() -> mcpSyncServer.loggingNotification(notification)).doesNotThrowAnyException(); + } + } + + @Test + void testLoggingWithoutCapability() { + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().build()) // No logging capability + .build(); + + var notification = McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.INFO) + .logger("test-logger") + .data("Test log message") + .build(); + + assertThatCode(() -> mcpSyncServer.loggingNotification(notification)).doesNotThrowAnyException(); + } + + @Test + void testLoggingWithNullNotification() { + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().logging().build()) + .build(); + + assertThatThrownBy(() -> mcpSyncServer.loggingNotification(null)).isInstanceOf(McpError.class); + } + +} diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java index af147f9d..7846e053 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java @@ -16,8 +16,7 @@ import io.modelcontextprotocol.spec.McpSchema.Resource; import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; import io.modelcontextprotocol.spec.McpSchema.Tool; -import io.modelcontextprotocol.spec.McpTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -28,10 +27,11 @@ /** * Test suite for the {@link McpSyncServer} that can be used with different - * {@link McpTransport} implementations. + * {@link McpTransportProvider} implementations. * * @author Christian Tzolov */ +// KEEP IN SYNC with the class in mcp-test module public abstract class AbstractMcpSyncServerTests { private static final String TEST_TOOL_NAME = "test-tool"; @@ -40,7 +40,7 @@ public abstract class AbstractMcpSyncServerTests { private static final String TEST_PROMPT_NAME = "test-prompt"; - abstract protected ServerMcpTransport createMcpTransport(); + abstract protected McpServerTransportProvider createMcpTransportProvider(); protected void onStart() { } @@ -64,31 +64,32 @@ void tearDown() { @Test void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> McpServer.sync((ServerMcpTransport) null)).isInstanceOf(IllegalArgumentException.class) - .hasMessage("Transport must not be null"); + assertThatThrownBy(() -> McpServer.sync((McpServerTransportProvider) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Transport provider must not be null"); - assertThatThrownBy(() -> McpServer.sync(createMcpTransport()).serverInfo(null)) + assertThatThrownBy(() -> McpServer.sync(createMcpTransportProvider()).serverInfo(null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Server info must not be null"); } @Test void testGracefulShutdown() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); } @Test void testImmediateClose() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpSyncServer.close()).doesNotThrowAnyException(); } @Test void testGetAsyncServer() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); assertThat(mcpSyncServer.getAsyncServer()).isNotNull(); @@ -109,14 +110,14 @@ void testGetAsyncServer() { @Test void testAddTool() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); - assertThatCode(() -> mcpSyncServer - .addTool(new McpServerFeatures.SyncToolRegistration(newTool, args -> new CallToolResult(List.of(), false)))) + assertThatCode(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(newTool, + (exchange, args) -> new CallToolResult(List.of(), false)))) .doesNotThrowAnyException(); assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); @@ -126,14 +127,14 @@ void testAddTool() { void testAddDuplicateTool() { Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - var mcpSyncServer = McpServer.sync(createMcpTransport()) + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(duplicateTool, args -> new CallToolResult(List.of(), false)) + .tool(duplicateTool, (exchange, args) -> new CallToolResult(List.of(), false)) .build(); - assertThatThrownBy(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolRegistration(duplicateTool, - args -> new CallToolResult(List.of(), false)))) + assertThatThrownBy(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(duplicateTool, + (exchange, args) -> new CallToolResult(List.of(), false)))) .isInstanceOf(McpError.class) .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); @@ -144,10 +145,10 @@ void testAddDuplicateTool() { void testRemoveTool() { Tool tool = new McpSchema.Tool(TEST_TOOL_NAME, "Test tool", emptyJsonSchema); - var mcpSyncServer = McpServer.sync(createMcpTransport()) + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(tool, args -> new CallToolResult(List.of(), false)) + .tool(tool, (exchange, args) -> new CallToolResult(List.of(), false)) .build(); assertThatCode(() -> mcpSyncServer.removeTool(TEST_TOOL_NAME)).doesNotThrowAnyException(); @@ -157,7 +158,7 @@ void testRemoveTool() { @Test void testRemoveNonexistentTool() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); @@ -170,7 +171,7 @@ void testRemoveNonexistentTool() { @Test void testNotifyToolsListChanged() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpSyncServer.notifyToolsListChanged()).doesNotThrowAnyException(); @@ -183,7 +184,7 @@ void testNotifyToolsListChanged() { @Test void testNotifyResourcesListChanged() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpSyncServer.notifyResourcesListChanged()).doesNotThrowAnyException(); @@ -192,29 +193,29 @@ void testNotifyResourcesListChanged() { @Test void testAddResource() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", null); - McpServerFeatures.SyncResourceRegistration registration = new McpServerFeatures.SyncResourceRegistration( - resource, req -> new ReadResourceResult(List.of())); + McpServerFeatures.SyncResourceSpecification specificaiton = new McpServerFeatures.SyncResourceSpecification( + resource, (exchange, req) -> new ReadResourceResult(List.of())); - assertThatCode(() -> mcpSyncServer.addResource(registration)).doesNotThrowAnyException(); + assertThatCode(() -> mcpSyncServer.addResource(specificaiton)).doesNotThrowAnyException(); assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); } @Test - void testAddResourceWithNullRegistration() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) + void testAddResourceWithNullSpecifiation() { + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); - assertThatThrownBy(() -> mcpSyncServer.addResource((McpServerFeatures.SyncResourceRegistration) null)) + assertThatThrownBy(() -> mcpSyncServer.addResource((McpServerFeatures.SyncResourceSpecification) null)) .isInstanceOf(McpError.class) .hasMessage("Resource must not be null"); @@ -223,20 +224,24 @@ void testAddResourceWithNullRegistration() { @Test void testAddResourceWithoutCapability() { - var serverWithoutResources = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var serverWithoutResources = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .build(); Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", null); - McpServerFeatures.SyncResourceRegistration registration = new McpServerFeatures.SyncResourceRegistration( - resource, req -> new ReadResourceResult(List.of())); + McpServerFeatures.SyncResourceSpecification specification = new McpServerFeatures.SyncResourceSpecification( + resource, (exchange, req) -> new ReadResourceResult(List.of())); - assertThatThrownBy(() -> serverWithoutResources.addResource(registration)).isInstanceOf(McpError.class) + assertThatThrownBy(() -> serverWithoutResources.addResource(specification)).isInstanceOf(McpError.class) .hasMessage("Server must be configured with resource capabilities"); } @Test void testRemoveResourceWithoutCapability() { - var serverWithoutResources = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var serverWithoutResources = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .build(); assertThatThrownBy(() -> serverWithoutResources.removeResource(TEST_RESOURCE_URI)).isInstanceOf(McpError.class) .hasMessage("Server must be configured with resource capabilities"); @@ -248,7 +253,7 @@ void testRemoveResourceWithoutCapability() { @Test void testNotifyPromptsListChanged() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpSyncServer.notifyPromptsListChanged()).doesNotThrowAnyException(); @@ -256,33 +261,37 @@ void testNotifyPromptsListChanged() { } @Test - void testAddPromptWithNullRegistration() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) + void testAddPromptWithNullSpecification() { + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(false).build()) .build(); - assertThatThrownBy(() -> mcpSyncServer.addPrompt((McpServerFeatures.SyncPromptRegistration) null)) + assertThatThrownBy(() -> mcpSyncServer.addPrompt((McpServerFeatures.SyncPromptSpecification) null)) .isInstanceOf(McpError.class) - .hasMessage("Prompt registration must not be null"); + .hasMessage("Prompt specification must not be null"); } @Test void testAddPromptWithoutCapability() { - var serverWithoutPrompts = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var serverWithoutPrompts = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .build(); Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); - McpServerFeatures.SyncPromptRegistration registration = new McpServerFeatures.SyncPromptRegistration(prompt, - req -> new GetPromptResult("Test prompt description", List + McpServerFeatures.SyncPromptSpecification specificaiton = new McpServerFeatures.SyncPromptSpecification(prompt, + (exchange, req) -> new GetPromptResult("Test prompt description", List .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); - assertThatThrownBy(() -> serverWithoutPrompts.addPrompt(registration)).isInstanceOf(McpError.class) + assertThatThrownBy(() -> serverWithoutPrompts.addPrompt(specificaiton)).isInstanceOf(McpError.class) .hasMessage("Server must be configured with prompt capabilities"); } @Test void testRemovePromptWithoutCapability() { - var serverWithoutPrompts = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var serverWithoutPrompts = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .build(); assertThatThrownBy(() -> serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).isInstanceOf(McpError.class) .hasMessage("Server must be configured with prompt capabilities"); @@ -291,14 +300,14 @@ void testRemovePromptWithoutCapability() { @Test void testRemovePrompt() { Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); - McpServerFeatures.SyncPromptRegistration registration = new McpServerFeatures.SyncPromptRegistration(prompt, - req -> new GetPromptResult("Test prompt description", List + McpServerFeatures.SyncPromptSpecification specificaiton = new McpServerFeatures.SyncPromptSpecification(prompt, + (exchange, req) -> new GetPromptResult("Test prompt description", List .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); - var mcpSyncServer = McpServer.sync(createMcpTransport()) + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) - .prompts(registration) + .prompts(specificaiton) .build(); assertThatCode(() -> mcpSyncServer.removePrompt(TEST_PROMPT_NAME)).doesNotThrowAnyException(); @@ -308,7 +317,7 @@ void testRemovePrompt() { @Test void testRemoveNonexistentPrompt() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) .build(); @@ -324,14 +333,14 @@ void testRemoveNonexistentPrompt() { // --------------------------------------- @Test - void testRootsChangeConsumers() { + void testRootsChangeHandlers() { // Test with single consumer var rootsReceived = new McpSchema.Root[1]; var consumerCalled = new boolean[1]; - var singleConsumerServer = McpServer.sync(createMcpTransport()) + var singleConsumerServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> { + .rootsChangeHandlers(List.of((exchage, roots) -> { consumerCalled[0] = true; if (!roots.isEmpty()) { rootsReceived[0] = roots.get(0); @@ -348,12 +357,12 @@ void testRootsChangeConsumers() { var consumer2Called = new boolean[1]; var rootsContent = new List[1]; - var multipleConsumersServer = McpServer.sync(createMcpTransport()) + var multipleConsumersServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> { + .rootsChangeHandlers(List.of((exchange, roots) -> { consumer1Called[0] = true; rootsContent[0] = roots; - }, roots -> consumer2Called[0] = true)) + }, (exchange, roots) -> consumer2Called[0] = true)) .build(); assertThat(multipleConsumersServer).isNotNull(); @@ -361,9 +370,9 @@ void testRootsChangeConsumers() { onClose(); // Test error handling - var errorHandlingServer = McpServer.sync(createMcpTransport()) + var errorHandlingServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> { + .rootsChangeHandlers(List.of((exchange, roots) -> { throw new RuntimeException("Test error"); })) .build(); @@ -373,7 +382,7 @@ void testRootsChangeConsumers() { onClose(); // Test without consumers - var noConsumersServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var noConsumersServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); assertThat(noConsumersServer).isNotNull(); assertThatCode(() -> noConsumersServer.closeGracefully()).doesNotThrowAnyException(); @@ -385,7 +394,7 @@ void testRootsChangeConsumers() { @Test void testLoggingLevels() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().logging().build()) .build(); @@ -404,7 +413,7 @@ void testLoggingLevels() { @Test void testLoggingWithoutCapability() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().build()) // No logging capability .build(); @@ -420,7 +429,7 @@ void testLoggingWithoutCapability() { @Test void testLoggingWithNullNotification() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().logging().build()) .build(); diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java index 1b78d5a7..81a0ed44 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java @@ -879,8 +879,6 @@ public SyncSpecification prompts(McpServerFeatures.SyncPromptSpecification... pr * @param handler The handler to register. Must not be null. * @return This builder instance for method chaining * @throws IllegalArgumentException if consumer is null - * @deprecated This method will be removed in 0.9.0. Use - * {@link #rootsChangeHandler(BiConsumer)}. */ public SyncSpecification rootsChangeHandler(BiConsumer> handler) { Assert.notNull(handler, "Consumer must not be null"); @@ -895,8 +893,6 @@ public SyncSpecification rootsChangeHandler(BiConsumer>> handlers) { diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java new file mode 100644 index 00000000..6a7d2903 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java @@ -0,0 +1,306 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.io.Reader; +import java.nio.charset.StandardCharsets; +import java.util.Map; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpServerTransport; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.util.Assert; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; + +/** + * Implementation of the MCP Stdio transport provider for servers that communicates using + * standard input/output streams. Messages are exchanged as newline-delimited JSON-RPC + * messages over stdin/stdout, with errors and debug information sent to stderr. + * + * @author Christian Tzolov + */ +public class StdioServerTransportProvider implements McpServerTransportProvider { + + private static final Logger logger = LoggerFactory.getLogger(StdioServerTransportProvider.class); + + private final ObjectMapper objectMapper; + + private final InputStream inputStream; + + private final OutputStream outputStream; + + private McpServerSession session; + + private final AtomicBoolean isClosing = new AtomicBoolean(false); + + private final Sinks.One inboundReady = Sinks.one(); + + /** + * Creates a new StdioServerTransportProvider with a default ObjectMapper and System + * streams. + */ + public StdioServerTransportProvider() { + this(new ObjectMapper()); + } + + /** + * Creates a new StdioServerTransportProvider with the specified ObjectMapper and + * System streams. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + */ + public StdioServerTransportProvider(ObjectMapper objectMapper) { + this(objectMapper, System.in, System.out); + } + + /** + * Creates a new StdioServerTransportProvider with the specified ObjectMapper and + * streams. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * @param inputStream The input stream to read from + * @param outputStream The output stream to write to + */ + public StdioServerTransportProvider(ObjectMapper objectMapper, InputStream inputStream, OutputStream outputStream) { + Assert.notNull(objectMapper, "The ObjectMapper can not be null"); + Assert.notNull(inputStream, "The InputStream can not be null"); + Assert.notNull(outputStream, "The OutputStream can not be null"); + + this.objectMapper = objectMapper; + this.inputStream = inputStream; + this.outputStream = outputStream; + } + + @Override + public void setSessionFactory(McpServerSession.Factory sessionFactory) { + // Create a single session for the stdio connection + this.session = sessionFactory.create(new StdioMcpSessionTransport()); + } + + @Override + public Mono notifyClients(String method, Map params) { + if (this.session == null) { + return Mono.error(new McpError("No session to close")); + } + return this.session.sendNotification(method, params) + .doOnError(e -> logger.error("Failed to send notification: {}", e.getMessage())); + } + + @Override + public Mono closeGracefully() { + if (this.session == null) { + return Mono.empty(); + } + return this.session.closeGracefully(); + } + + /** + * Implementation of McpServerTransport for the stdio session. + */ + private class StdioMcpSessionTransport implements McpServerTransport { + + private final Sinks.Many inboundSink; + + private final Sinks.Many outboundSink; + + private final AtomicBoolean isStarted = new AtomicBoolean(false); + + /** Scheduler for handling inbound messages */ + private Scheduler inboundScheduler; + + /** Scheduler for handling outbound messages */ + private Scheduler outboundScheduler; + + private final Sinks.One outboundReady = Sinks.one(); + + public StdioMcpSessionTransport() { + + this.inboundSink = Sinks.many().unicast().onBackpressureBuffer(); + this.outboundSink = Sinks.many().unicast().onBackpressureBuffer(); + + // Use bounded schedulers for better resource management + this.inboundScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), + "stdio-inbound"); + this.outboundScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), + "stdio-outbound"); + + handleIncomingMessages(); + startInboundProcessing(); + startOutboundProcessing(); + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + + return Mono.zip(inboundReady.asMono(), outboundReady.asMono()).then(Mono.defer(() -> { + if (outboundSink.tryEmitNext(message).isSuccess()) { + return Mono.empty(); + } + else { + return Mono.error(new RuntimeException("Failed to enqueue message")); + } + })); + } + + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return objectMapper.convertValue(data, typeRef); + } + + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + isClosing.set(true); + logger.debug("Session transport closing gracefully"); + inboundSink.tryEmitComplete(); + }); + } + + @Override + public void close() { + isClosing.set(true); + logger.debug("Session transport closed"); + } + + private void handleIncomingMessages() { + this.inboundSink.asFlux().flatMap(message -> session.handle(message)).doOnTerminate(() -> { + // The outbound processing will dispose its scheduler upon completion + this.outboundSink.tryEmitComplete(); + this.inboundScheduler.dispose(); + }).subscribe(); + } + + /** + * Starts the inbound processing thread that reads JSON-RPC messages from stdin. + * Messages are deserialized and passed to the session for handling. + */ + private void startInboundProcessing() { + if (isStarted.compareAndSet(false, true)) { + this.inboundScheduler.schedule(() -> { + inboundReady.tryEmitValue(null); + BufferedReader reader = null; + try { + reader = new BufferedReader(new InputStreamReader(inputStream)); + while (!isClosing.get()) { + try { + String line = reader.readLine(); + if (line == null || isClosing.get()) { + break; + } + + logger.debug("Received JSON message: {}", line); + + try { + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, + line); + if (!this.inboundSink.tryEmitNext(message).isSuccess()) { + // logIfNotClosing("Failed to enqueue message"); + break; + } + + } + catch (Exception e) { + logIfNotClosing("Error processing inbound message", e); + break; + } + } + catch (IOException e) { + logIfNotClosing("Error reading from stdin", e); + break; + } + } + } + catch (Exception e) { + logIfNotClosing("Error in inbound processing", e); + } + finally { + isClosing.set(true); + if (session != null) { + session.close(); + } + inboundSink.tryEmitComplete(); + } + }); + } + } + + /** + * Starts the outbound processing thread that writes JSON-RPC messages to stdout. + * Messages are serialized to JSON and written with a newline delimiter. + */ + private void startOutboundProcessing() { + Function, Flux> outboundConsumer = messages -> messages // @formatter:off + .doOnSubscribe(subscription -> outboundReady.tryEmitValue(null)) + .publishOn(outboundScheduler) + .handle((message, sink) -> { + if (message != null && !isClosing.get()) { + try { + String jsonMessage = objectMapper.writeValueAsString(message); + // Escape any embedded newlines in the JSON message as per spec + jsonMessage = jsonMessage.replace("\r\n", "\\n").replace("\n", "\\n").replace("\r", "\\n"); + + synchronized (outputStream) { + outputStream.write(jsonMessage.getBytes(StandardCharsets.UTF_8)); + outputStream.write("\n".getBytes(StandardCharsets.UTF_8)); + outputStream.flush(); + } + sink.next(message); + } + catch (IOException e) { + if (!isClosing.get()) { + logger.error("Error writing message", e); + sink.error(new RuntimeException(e)); + } + else { + logger.debug("Stream closed during shutdown", e); + } + } + } + else if (isClosing.get()) { + sink.complete(); + } + }) + .doOnComplete(() -> { + isClosing.set(true); + outboundScheduler.dispose(); + }) + .doOnError(e -> { + if (!isClosing.get()) { + logger.error("Error in outbound processing", e); + isClosing.set(true); + outboundScheduler.dispose(); + } + }) + .map(msg -> (JSONRPCMessage) msg); + + outboundConsumer.apply(outboundSink.asFlux()).subscribe(); + } // @formatter:on + + private void logIfNotClosing(String message, Exception e) { + if (!isClosing.get()) { + logger.error(message, e); + } + } + + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerDeprecatedTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerDeprecatedTests.java new file mode 100644 index 00000000..b9a19de6 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerDeprecatedTests.java @@ -0,0 +1,466 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.time.Duration; +import java.util.List; + +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptMessage; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.Resource; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpTransport; +import io.modelcontextprotocol.spec.ServerMcpTransport; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Test suite for the {@link McpAsyncServer} that can be used with different + * {@link McpTransport} implementations. + * + * @author Christian Tzolov + */ +// KEEP IN SYNC with the class in mcp-test module +@Deprecated +public abstract class AbstractMcpAsyncServerDeprecatedTests { + + private static final String TEST_TOOL_NAME = "test-tool"; + + private static final String TEST_RESOURCE_URI = "test://resource"; + + private static final String TEST_PROMPT_NAME = "test-prompt"; + + abstract protected ServerMcpTransport createMcpTransport(); + + protected void onStart() { + } + + protected void onClose() { + } + + @BeforeEach + void setUp() { + } + + @AfterEach + void tearDown() { + onClose(); + } + + // --------------------------------------- + // Server Lifecycle Tests + // --------------------------------------- + + @Test + void testConstructorWithInvalidArguments() { + assertThatThrownBy(() -> McpServer.async((ServerMcpTransport) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Transport must not be null"); + + assertThatThrownBy(() -> McpServer.async(createMcpTransport()).serverInfo((McpSchema.Implementation) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Server info must not be null"); + } + + @Test + void testGracefulShutdown() { + var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + StepVerifier.create(mcpAsyncServer.closeGracefully()).verifyComplete(); + } + + @Test + void testImmediateClose() { + var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + assertThatCode(() -> mcpAsyncServer.close()).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Tools Tests + // --------------------------------------- + String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + + @Test + void testAddTool() { + Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolRegistration(newTool, + args -> Mono.just(new CallToolResult(List.of(), false))))) + .verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddDuplicateTool() { + Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tool(duplicateTool, args -> Mono.just(new CallToolResult(List.of(), false))) + .build(); + + StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolRegistration(duplicateTool, + args -> Mono.just(new CallToolResult(List.of(), false))))) + .verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); + }); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testRemoveTool() { + Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tool(too, args -> Mono.just(new CallToolResult(List.of(), false))) + .build(); + + StepVerifier.create(mcpAsyncServer.removeTool(TEST_TOOL_NAME)).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentTool() { + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + StepVerifier.create(mcpAsyncServer.removeTool("nonexistent-tool")).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class).hasMessage("Tool with name 'nonexistent-tool' not found"); + }); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testNotifyToolsListChanged() { + Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tool(too, args -> Mono.just(new CallToolResult(List.of(), false))) + .build(); + + StepVerifier.create(mcpAsyncServer.notifyToolsListChanged()).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Resources Tests + // --------------------------------------- + + @Test + void testNotifyResourcesListChanged() { + var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + StepVerifier.create(mcpAsyncServer.notifyResourcesListChanged()).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddResource() { + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", + null); + McpServerFeatures.AsyncResourceRegistration registration = new McpServerFeatures.AsyncResourceRegistration( + resource, req -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier.create(mcpAsyncServer.addResource(registration)).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddResourceWithNullRegistration() { + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + StepVerifier.create(mcpAsyncServer.addResource((McpServerFeatures.AsyncResourceRegistration) null)) + .verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class).hasMessage("Resource must not be null"); + }); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddResourceWithoutCapability() { + // Create a server without resource capabilities + McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .build(); + + Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", + null); + McpServerFeatures.AsyncResourceRegistration registration = new McpServerFeatures.AsyncResourceRegistration( + resource, req -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier.create(serverWithoutResources.addResource(registration)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with resource capabilities"); + }); + } + + @Test + void testRemoveResourceWithoutCapability() { + // Create a server without resource capabilities + McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .build(); + + StepVerifier.create(serverWithoutResources.removeResource(TEST_RESOURCE_URI)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with resource capabilities"); + }); + } + + // --------------------------------------- + // Prompts Tests + // --------------------------------------- + + @Test + void testNotifyPromptsListChanged() { + var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + StepVerifier.create(mcpAsyncServer.notifyPromptsListChanged()).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddPromptWithNullRegistration() { + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(false).build()) + .build(); + + StepVerifier.create(mcpAsyncServer.addPrompt((McpServerFeatures.AsyncPromptRegistration) null)) + .verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class).hasMessage("Prompt registration must not be null"); + }); + } + + @Test + void testAddPromptWithoutCapability() { + // Create a server without prompt capabilities + McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .build(); + + Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); + McpServerFeatures.AsyncPromptRegistration registration = new McpServerFeatures.AsyncPromptRegistration(prompt, + req -> Mono.just(new GetPromptResult("Test prompt description", List + .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); + + StepVerifier.create(serverWithoutPrompts.addPrompt(registration)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with prompt capabilities"); + }); + } + + @Test + void testRemovePromptWithoutCapability() { + // Create a server without prompt capabilities + McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .build(); + + StepVerifier.create(serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with prompt capabilities"); + }); + } + + @Test + void testRemovePrompt() { + String TEST_PROMPT_NAME_TO_REMOVE = "TEST_PROMPT_NAME678"; + + Prompt prompt = new Prompt(TEST_PROMPT_NAME_TO_REMOVE, "Test Prompt", List.of()); + McpServerFeatures.AsyncPromptRegistration registration = new McpServerFeatures.AsyncPromptRegistration(prompt, + req -> Mono.just(new GetPromptResult("Test prompt description", List + .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); + + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(true).build()) + .prompts(registration) + .build(); + + StepVerifier.create(mcpAsyncServer.removePrompt(TEST_PROMPT_NAME_TO_REMOVE)).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentPrompt() { + var mcpAsyncServer2 = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(true).build()) + .build(); + + StepVerifier.create(mcpAsyncServer2.removePrompt("nonexistent-prompt")).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Prompt with name 'nonexistent-prompt' not found"); + }); + + assertThatCode(() -> mcpAsyncServer2.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + } + + // --------------------------------------- + // Roots Tests + // --------------------------------------- + + @Test + void testRootsChangeConsumers() { + // Test with single consumer + var rootsReceived = new McpSchema.Root[1]; + var consumerCalled = new boolean[1]; + + var singleConsumerServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .rootsChangeConsumers(List.of(roots -> Mono.fromRunnable(() -> { + consumerCalled[0] = true; + if (!roots.isEmpty()) { + rootsReceived[0] = roots.get(0); + } + }))) + .build(); + + assertThat(singleConsumerServer).isNotNull(); + assertThatCode(() -> singleConsumerServer.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + onClose(); + + // Test with multiple consumers + var consumer1Called = new boolean[1]; + var consumer2Called = new boolean[1]; + var rootsContent = new List[1]; + + var multipleConsumersServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .rootsChangeConsumers(List.of(roots -> Mono.fromRunnable(() -> { + consumer1Called[0] = true; + rootsContent[0] = roots; + }), roots -> Mono.fromRunnable(() -> consumer2Called[0] = true))) + .build(); + + assertThat(multipleConsumersServer).isNotNull(); + assertThatCode(() -> multipleConsumersServer.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + onClose(); + + // Test error handling + var errorHandlingServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .rootsChangeConsumers(List.of(roots -> { + throw new RuntimeException("Test error"); + })) + .build(); + + assertThat(errorHandlingServer).isNotNull(); + assertThatCode(() -> errorHandlingServer.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + onClose(); + + // Test without consumers + var noConsumersServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + assertThat(noConsumersServer).isNotNull(); + assertThatCode(() -> noConsumersServer.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + } + + // --------------------------------------- + // Logging Tests + // --------------------------------------- + + @Test + void testLoggingLevels() { + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().logging().build()) + .build(); + + // Test all logging levels + for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { + var notification = McpSchema.LoggingMessageNotification.builder() + .level(level) + .logger("test-logger") + .data("Test message with level " + level) + .build(); + + StepVerifier.create(mcpAsyncServer.loggingNotification(notification)).verifyComplete(); + } + } + + @Test + void testLoggingWithoutCapability() { + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().build()) // No logging capability + .build(); + + var notification = McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.INFO) + .logger("test-logger") + .data("Test log message") + .build(); + + StepVerifier.create(mcpAsyncServer.loggingNotification(notification)).verifyComplete(); + } + + @Test + void testLoggingWithNullNotification() { + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().logging().build()) + .build(); + + StepVerifier.create(mcpAsyncServer.loggingNotification(null)).verifyError(McpError.class); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java index 568a655d..4b4fc434 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java @@ -17,12 +17,10 @@ import io.modelcontextprotocol.spec.McpSchema.Resource; import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; import io.modelcontextprotocol.spec.McpSchema.Tool; -import io.modelcontextprotocol.spec.McpTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; @@ -32,11 +30,10 @@ /** * Test suite for the {@link McpAsyncServer} that can be used with different - * {@link McpTransport} implementations. + * {@link McpTransportProvider} implementations. * * @author Christian Tzolov */ -// KEEP IN SYNC with the class in mcp-test module public abstract class AbstractMcpAsyncServerTests { private static final String TEST_TOOL_NAME = "test-tool"; @@ -45,7 +42,7 @@ public abstract class AbstractMcpAsyncServerTests { private static final String TEST_PROMPT_NAME = "test-prompt"; - abstract protected ServerMcpTransport createMcpTransport(); + abstract protected McpServerTransportProvider createMcpTransportProvider(); protected void onStart() { } @@ -68,25 +65,26 @@ void tearDown() { @Test void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> McpServer.async((ServerMcpTransport) null)) + assertThatThrownBy(() -> McpServer.async((McpServerTransportProvider) null)) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Transport must not be null"); + .hasMessage("Transport provider must not be null"); - assertThatThrownBy(() -> McpServer.async(createMcpTransport()).serverInfo((McpSchema.Implementation) null)) + assertThatThrownBy( + () -> McpServer.async(createMcpTransportProvider()).serverInfo((McpSchema.Implementation) null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Server info must not be null"); } @Test void testGracefulShutdown() { - var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); StepVerifier.create(mcpAsyncServer.closeGracefully()).verifyComplete(); } @Test void testImmediateClose() { - var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpAsyncServer.close()).doesNotThrowAnyException(); } @@ -105,13 +103,13 @@ void testImmediateClose() { @Test void testAddTool() { Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransport()) + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); - StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolRegistration(newTool, - args -> Mono.just(new CallToolResult(List.of(), false))))) + StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(newTool, + (excnage, args) -> Mono.just(new CallToolResult(List.of(), false))))) .verifyComplete(); assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); @@ -121,14 +119,15 @@ void testAddTool() { void testAddDuplicateTool() { Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransport()) + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(duplicateTool, args -> Mono.just(new CallToolResult(List.of(), false))) + .tool(duplicateTool, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) .build(); - StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolRegistration(duplicateTool, - args -> Mono.just(new CallToolResult(List.of(), false))))) + StepVerifier + .create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(duplicateTool, + (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))))) .verifyErrorSatisfies(error -> { assertThat(error).isInstanceOf(McpError.class) .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); @@ -141,10 +140,10 @@ void testAddDuplicateTool() { void testRemoveTool() { Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransport()) + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(too, args -> Mono.just(new CallToolResult(List.of(), false))) + .tool(too, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) .build(); StepVerifier.create(mcpAsyncServer.removeTool(TEST_TOOL_NAME)).verifyComplete(); @@ -154,7 +153,7 @@ void testRemoveTool() { @Test void testRemoveNonexistentTool() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); @@ -170,10 +169,10 @@ void testRemoveNonexistentTool() { void testNotifyToolsListChanged() { Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransport()) + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(too, args -> Mono.just(new CallToolResult(List.of(), false))) + .tool(too, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) .build(); StepVerifier.create(mcpAsyncServer.notifyToolsListChanged()).verifyComplete(); @@ -187,7 +186,7 @@ void testNotifyToolsListChanged() { @Test void testNotifyResourcesListChanged() { - var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); StepVerifier.create(mcpAsyncServer.notifyResourcesListChanged()).verifyComplete(); @@ -196,29 +195,29 @@ void testNotifyResourcesListChanged() { @Test void testAddResource() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", null); - McpServerFeatures.AsyncResourceRegistration registration = new McpServerFeatures.AsyncResourceRegistration( - resource, req -> Mono.just(new ReadResourceResult(List.of()))); + McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification( + resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); - StepVerifier.create(mcpAsyncServer.addResource(registration)).verifyComplete(); + StepVerifier.create(mcpAsyncServer.addResource(specification)).verifyComplete(); assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); } @Test - void testAddResourceWithNullRegistration() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) + void testAddResourceWithNullSpecification() { + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); - StepVerifier.create(mcpAsyncServer.addResource((McpServerFeatures.AsyncResourceRegistration) null)) + StepVerifier.create(mcpAsyncServer.addResource((McpServerFeatures.AsyncResourceSpecification) null)) .verifyErrorSatisfies(error -> { assertThat(error).isInstanceOf(McpError.class).hasMessage("Resource must not be null"); }); @@ -229,16 +228,16 @@ void testAddResourceWithNullRegistration() { @Test void testAddResourceWithoutCapability() { // Create a server without resource capabilities - McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransport()) + McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .build(); Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", null); - McpServerFeatures.AsyncResourceRegistration registration = new McpServerFeatures.AsyncResourceRegistration( - resource, req -> Mono.just(new ReadResourceResult(List.of()))); + McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification( + resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); - StepVerifier.create(serverWithoutResources.addResource(registration)).verifyErrorSatisfies(error -> { + StepVerifier.create(serverWithoutResources.addResource(specification)).verifyErrorSatisfies(error -> { assertThat(error).isInstanceOf(McpError.class) .hasMessage("Server must be configured with resource capabilities"); }); @@ -247,7 +246,7 @@ void testAddResourceWithoutCapability() { @Test void testRemoveResourceWithoutCapability() { // Create a server without resource capabilities - McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransport()) + McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .build(); @@ -263,7 +262,7 @@ void testRemoveResourceWithoutCapability() { @Test void testNotifyPromptsListChanged() { - var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); StepVerifier.create(mcpAsyncServer.notifyPromptsListChanged()).verifyComplete(); @@ -271,31 +270,31 @@ void testNotifyPromptsListChanged() { } @Test - void testAddPromptWithNullRegistration() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) + void testAddPromptWithNullSpecification() { + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(false).build()) .build(); - StepVerifier.create(mcpAsyncServer.addPrompt((McpServerFeatures.AsyncPromptRegistration) null)) + StepVerifier.create(mcpAsyncServer.addPrompt((McpServerFeatures.AsyncPromptSpecification) null)) .verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class).hasMessage("Prompt registration must not be null"); + assertThat(error).isInstanceOf(McpError.class).hasMessage("Prompt specification must not be null"); }); } @Test void testAddPromptWithoutCapability() { // Create a server without prompt capabilities - McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransport()) + McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .build(); Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); - McpServerFeatures.AsyncPromptRegistration registration = new McpServerFeatures.AsyncPromptRegistration(prompt, - req -> Mono.just(new GetPromptResult("Test prompt description", List + McpServerFeatures.AsyncPromptSpecification specification = new McpServerFeatures.AsyncPromptSpecification( + prompt, (exchange, req) -> Mono.just(new GetPromptResult("Test prompt description", List .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); - StepVerifier.create(serverWithoutPrompts.addPrompt(registration)).verifyErrorSatisfies(error -> { + StepVerifier.create(serverWithoutPrompts.addPrompt(specification)).verifyErrorSatisfies(error -> { assertThat(error).isInstanceOf(McpError.class) .hasMessage("Server must be configured with prompt capabilities"); }); @@ -304,7 +303,7 @@ void testAddPromptWithoutCapability() { @Test void testRemovePromptWithoutCapability() { // Create a server without prompt capabilities - McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransport()) + McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .build(); @@ -319,14 +318,14 @@ void testRemovePrompt() { String TEST_PROMPT_NAME_TO_REMOVE = "TEST_PROMPT_NAME678"; Prompt prompt = new Prompt(TEST_PROMPT_NAME_TO_REMOVE, "Test Prompt", List.of()); - McpServerFeatures.AsyncPromptRegistration registration = new McpServerFeatures.AsyncPromptRegistration(prompt, - req -> Mono.just(new GetPromptResult("Test prompt description", List + McpServerFeatures.AsyncPromptSpecification specification = new McpServerFeatures.AsyncPromptSpecification( + prompt, (exchange, req) -> Mono.just(new GetPromptResult("Test prompt description", List .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); - var mcpAsyncServer = McpServer.async(createMcpTransport()) + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) - .prompts(registration) + .prompts(specification) .build(); StepVerifier.create(mcpAsyncServer.removePrompt(TEST_PROMPT_NAME_TO_REMOVE)).verifyComplete(); @@ -336,7 +335,7 @@ void testRemovePrompt() { @Test void testRemoveNonexistentPrompt() { - var mcpAsyncServer2 = McpServer.async(createMcpTransport()) + var mcpAsyncServer2 = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) .build(); @@ -355,14 +354,14 @@ void testRemoveNonexistentPrompt() { // --------------------------------------- @Test - void testRootsChangeConsumers() { + void testRootsChangeHandlers() { // Test with single consumer var rootsReceived = new McpSchema.Root[1]; var consumerCalled = new boolean[1]; - var singleConsumerServer = McpServer.async(createMcpTransport()) + var singleConsumerServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> Mono.fromRunnable(() -> { + .rootsChangeHandlers(List.of((exchange, roots) -> Mono.fromRunnable(() -> { consumerCalled[0] = true; if (!roots.isEmpty()) { rootsReceived[0] = roots.get(0); @@ -380,12 +379,12 @@ void testRootsChangeConsumers() { var consumer2Called = new boolean[1]; var rootsContent = new List[1]; - var multipleConsumersServer = McpServer.async(createMcpTransport()) + var multipleConsumersServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> Mono.fromRunnable(() -> { + .rootsChangeHandlers(List.of((exchange, roots) -> Mono.fromRunnable(() -> { consumer1Called[0] = true; rootsContent[0] = roots; - }), roots -> Mono.fromRunnable(() -> consumer2Called[0] = true))) + }), (exchange, roots) -> Mono.fromRunnable(() -> consumer2Called[0] = true))) .build(); assertThat(multipleConsumersServer).isNotNull(); @@ -394,9 +393,9 @@ void testRootsChangeConsumers() { onClose(); // Test error handling - var errorHandlingServer = McpServer.async(createMcpTransport()) + var errorHandlingServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> { + .rootsChangeHandlers(List.of((exchange, roots) -> { throw new RuntimeException("Test error"); })) .build(); @@ -407,7 +406,9 @@ void testRootsChangeConsumers() { onClose(); // Test without consumers - var noConsumersServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var noConsumersServer = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .build(); assertThat(noConsumersServer).isNotNull(); assertThatCode(() -> noConsumersServer.closeGracefully().block(Duration.ofSeconds(10))) @@ -420,7 +421,7 @@ void testRootsChangeConsumers() { @Test void testLoggingLevels() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().logging().build()) .build(); @@ -439,7 +440,7 @@ void testLoggingLevels() { @Test void testLoggingWithoutCapability() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().build()) // No logging capability .build(); @@ -455,7 +456,7 @@ void testLoggingWithoutCapability() { @Test void testLoggingWithNullNotification() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().logging().build()) .build(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerDeprecatedTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerDeprecatedTests.java new file mode 100644 index 00000000..16bc2d6e --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerDeprecatedTests.java @@ -0,0 +1,433 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.util.List; + +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptMessage; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.Resource; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpTransport; +import io.modelcontextprotocol.spec.ServerMcpTransport; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Test suite for the {@link McpSyncServer} that can be used with different + * {@link McpTransport} implementations. + * + * @author Christian Tzolov + */ +// KEEP IN SYNC with the class in mcp-test module +@Deprecated +public abstract class AbstractMcpSyncServerDeprecatedTests { + + private static final String TEST_TOOL_NAME = "test-tool"; + + private static final String TEST_RESOURCE_URI = "test://resource"; + + private static final String TEST_PROMPT_NAME = "test-prompt"; + + abstract protected ServerMcpTransport createMcpTransport(); + + protected void onStart() { + } + + protected void onClose() { + } + + @BeforeEach + void setUp() { + // onStart(); + } + + @AfterEach + void tearDown() { + onClose(); + } + + // --------------------------------------- + // Server Lifecycle Tests + // --------------------------------------- + + @Test + void testConstructorWithInvalidArguments() { + assertThatThrownBy(() -> McpServer.sync((ServerMcpTransport) null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Transport must not be null"); + + assertThatThrownBy(() -> McpServer.sync(createMcpTransport()).serverInfo(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Server info must not be null"); + } + + @Test + void testGracefulShutdown() { + var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testImmediateClose() { + var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + assertThatCode(() -> mcpSyncServer.close()).doesNotThrowAnyException(); + } + + @Test + void testGetAsyncServer() { + var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + assertThat(mcpSyncServer.getAsyncServer()).isNotNull(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Tools Tests + // --------------------------------------- + + String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + + @Test + void testAddTool() { + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); + assertThatCode(() -> mcpSyncServer + .addTool(new McpServerFeatures.SyncToolRegistration(newTool, args -> new CallToolResult(List.of(), false)))) + .doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testAddDuplicateTool() { + Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tool(duplicateTool, args -> new CallToolResult(List.of(), false)) + .build(); + + assertThatThrownBy(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolRegistration(duplicateTool, + args -> new CallToolResult(List.of(), false)))) + .isInstanceOf(McpError.class) + .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testRemoveTool() { + Tool tool = new McpSchema.Tool(TEST_TOOL_NAME, "Test tool", emptyJsonSchema); + + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tool(tool, args -> new CallToolResult(List.of(), false)) + .build(); + + assertThatCode(() -> mcpSyncServer.removeTool(TEST_TOOL_NAME)).doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentTool() { + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + assertThatThrownBy(() -> mcpSyncServer.removeTool("nonexistent-tool")).isInstanceOf(McpError.class) + .hasMessage("Tool with name 'nonexistent-tool' not found"); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testNotifyToolsListChanged() { + var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + assertThatCode(() -> mcpSyncServer.notifyToolsListChanged()).doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Resources Tests + // --------------------------------------- + + @Test + void testNotifyResourcesListChanged() { + var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + assertThatCode(() -> mcpSyncServer.notifyResourcesListChanged()).doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testAddResource() { + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", + null); + McpServerFeatures.SyncResourceRegistration registration = new McpServerFeatures.SyncResourceRegistration( + resource, req -> new ReadResourceResult(List.of())); + + assertThatCode(() -> mcpSyncServer.addResource(registration)).doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testAddResourceWithNullRegistration() { + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + assertThatThrownBy(() -> mcpSyncServer.addResource((McpServerFeatures.SyncResourceRegistration) null)) + .isInstanceOf(McpError.class) + .hasMessage("Resource must not be null"); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testAddResourceWithoutCapability() { + var serverWithoutResources = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", + null); + McpServerFeatures.SyncResourceRegistration registration = new McpServerFeatures.SyncResourceRegistration( + resource, req -> new ReadResourceResult(List.of())); + + assertThatThrownBy(() -> serverWithoutResources.addResource(registration)).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with resource capabilities"); + } + + @Test + void testRemoveResourceWithoutCapability() { + var serverWithoutResources = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + assertThatThrownBy(() -> serverWithoutResources.removeResource(TEST_RESOURCE_URI)).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with resource capabilities"); + } + + // --------------------------------------- + // Prompts Tests + // --------------------------------------- + + @Test + void testNotifyPromptsListChanged() { + var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + assertThatCode(() -> mcpSyncServer.notifyPromptsListChanged()).doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testAddPromptWithNullRegistration() { + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(false).build()) + .build(); + + assertThatThrownBy(() -> mcpSyncServer.addPrompt((McpServerFeatures.SyncPromptRegistration) null)) + .isInstanceOf(McpError.class) + .hasMessage("Prompt registration must not be null"); + } + + @Test + void testAddPromptWithoutCapability() { + var serverWithoutPrompts = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); + McpServerFeatures.SyncPromptRegistration registration = new McpServerFeatures.SyncPromptRegistration(prompt, + req -> new GetPromptResult("Test prompt description", List + .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); + + assertThatThrownBy(() -> serverWithoutPrompts.addPrompt(registration)).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with prompt capabilities"); + } + + @Test + void testRemovePromptWithoutCapability() { + var serverWithoutPrompts = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + assertThatThrownBy(() -> serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with prompt capabilities"); + } + + @Test + void testRemovePrompt() { + Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); + McpServerFeatures.SyncPromptRegistration registration = new McpServerFeatures.SyncPromptRegistration(prompt, + req -> new GetPromptResult("Test prompt description", List + .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); + + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(true).build()) + .prompts(registration) + .build(); + + assertThatCode(() -> mcpSyncServer.removePrompt(TEST_PROMPT_NAME)).doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentPrompt() { + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(true).build()) + .build(); + + assertThatThrownBy(() -> mcpSyncServer.removePrompt("nonexistent-prompt")).isInstanceOf(McpError.class) + .hasMessage("Prompt with name 'nonexistent-prompt' not found"); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Roots Tests + // --------------------------------------- + + @Test + void testRootsChangeConsumers() { + // Test with single consumer + var rootsReceived = new McpSchema.Root[1]; + var consumerCalled = new boolean[1]; + + var singleConsumerServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .rootsChangeConsumers(List.of(roots -> { + consumerCalled[0] = true; + if (!roots.isEmpty()) { + rootsReceived[0] = roots.get(0); + } + })) + .build(); + + assertThat(singleConsumerServer).isNotNull(); + assertThatCode(() -> singleConsumerServer.closeGracefully()).doesNotThrowAnyException(); + onClose(); + + // Test with multiple consumers + var consumer1Called = new boolean[1]; + var consumer2Called = new boolean[1]; + var rootsContent = new List[1]; + + var multipleConsumersServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .rootsChangeConsumers(List.of(roots -> { + consumer1Called[0] = true; + rootsContent[0] = roots; + }, roots -> consumer2Called[0] = true)) + .build(); + + assertThat(multipleConsumersServer).isNotNull(); + assertThatCode(() -> multipleConsumersServer.closeGracefully()).doesNotThrowAnyException(); + onClose(); + + // Test error handling + var errorHandlingServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .rootsChangeConsumers(List.of(roots -> { + throw new RuntimeException("Test error"); + })) + .build(); + + assertThat(errorHandlingServer).isNotNull(); + assertThatCode(() -> errorHandlingServer.closeGracefully()).doesNotThrowAnyException(); + onClose(); + + // Test without consumers + var noConsumersServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + assertThat(noConsumersServer).isNotNull(); + assertThatCode(() -> noConsumersServer.closeGracefully()).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Logging Tests + // --------------------------------------- + + @Test + void testLoggingLevels() { + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().logging().build()) + .build(); + + // Test all logging levels + for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { + var notification = McpSchema.LoggingMessageNotification.builder() + .level(level) + .logger("test-logger") + .data("Test message with level " + level) + .build(); + + assertThatCode(() -> mcpSyncServer.loggingNotification(notification)).doesNotThrowAnyException(); + } + } + + @Test + void testLoggingWithoutCapability() { + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().build()) // No logging capability + .build(); + + var notification = McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.INFO) + .logger("test-logger") + .data("Test log message") + .build(); + + assertThatCode(() -> mcpSyncServer.loggingNotification(notification)).doesNotThrowAnyException(); + } + + @Test + void testLoggingWithNullNotification() { + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().logging().build()) + .build(); + + assertThatThrownBy(() -> mcpSyncServer.loggingNotification(null)).isInstanceOf(McpError.class); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java index d76cf8e5..17feb36e 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java @@ -16,8 +16,7 @@ import io.modelcontextprotocol.spec.McpSchema.Resource; import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; import io.modelcontextprotocol.spec.McpSchema.Tool; -import io.modelcontextprotocol.spec.McpTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -28,11 +27,10 @@ /** * Test suite for the {@link McpSyncServer} that can be used with different - * {@link McpTransport} implementations. + * {@link McpTransportProvider} implementations. * * @author Christian Tzolov */ -// KEEP IN SYNC with the class in mcp-test module public abstract class AbstractMcpSyncServerTests { private static final String TEST_TOOL_NAME = "test-tool"; @@ -41,7 +39,7 @@ public abstract class AbstractMcpSyncServerTests { private static final String TEST_PROMPT_NAME = "test-prompt"; - abstract protected ServerMcpTransport createMcpTransport(); + abstract protected McpServerTransportProvider createMcpTransportProvider(); protected void onStart() { } @@ -65,31 +63,32 @@ void tearDown() { @Test void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> McpServer.sync((ServerMcpTransport) null)).isInstanceOf(IllegalArgumentException.class) - .hasMessage("Transport must not be null"); + assertThatThrownBy(() -> McpServer.sync((McpServerTransportProvider) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Transport provider must not be null"); - assertThatThrownBy(() -> McpServer.sync(createMcpTransport()).serverInfo(null)) + assertThatThrownBy(() -> McpServer.sync(createMcpTransportProvider()).serverInfo(null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Server info must not be null"); } @Test void testGracefulShutdown() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); } @Test void testImmediateClose() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpSyncServer.close()).doesNotThrowAnyException(); } @Test void testGetAsyncServer() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); assertThat(mcpSyncServer.getAsyncServer()).isNotNull(); @@ -110,14 +109,14 @@ void testGetAsyncServer() { @Test void testAddTool() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); - assertThatCode(() -> mcpSyncServer - .addTool(new McpServerFeatures.SyncToolRegistration(newTool, args -> new CallToolResult(List.of(), false)))) + assertThatCode(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(newTool, + (exchange, args) -> new CallToolResult(List.of(), false)))) .doesNotThrowAnyException(); assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); @@ -127,14 +126,14 @@ void testAddTool() { void testAddDuplicateTool() { Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - var mcpSyncServer = McpServer.sync(createMcpTransport()) + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(duplicateTool, args -> new CallToolResult(List.of(), false)) + .tool(duplicateTool, (exchange, args) -> new CallToolResult(List.of(), false)) .build(); - assertThatThrownBy(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolRegistration(duplicateTool, - args -> new CallToolResult(List.of(), false)))) + assertThatThrownBy(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(duplicateTool, + (exchange, args) -> new CallToolResult(List.of(), false)))) .isInstanceOf(McpError.class) .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); @@ -145,10 +144,10 @@ void testAddDuplicateTool() { void testRemoveTool() { Tool tool = new McpSchema.Tool(TEST_TOOL_NAME, "Test tool", emptyJsonSchema); - var mcpSyncServer = McpServer.sync(createMcpTransport()) + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(tool, args -> new CallToolResult(List.of(), false)) + .tool(tool, (exchange, args) -> new CallToolResult(List.of(), false)) .build(); assertThatCode(() -> mcpSyncServer.removeTool(TEST_TOOL_NAME)).doesNotThrowAnyException(); @@ -158,7 +157,7 @@ void testRemoveTool() { @Test void testRemoveNonexistentTool() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); @@ -171,7 +170,7 @@ void testRemoveNonexistentTool() { @Test void testNotifyToolsListChanged() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpSyncServer.notifyToolsListChanged()).doesNotThrowAnyException(); @@ -184,7 +183,7 @@ void testNotifyToolsListChanged() { @Test void testNotifyResourcesListChanged() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpSyncServer.notifyResourcesListChanged()).doesNotThrowAnyException(); @@ -193,29 +192,29 @@ void testNotifyResourcesListChanged() { @Test void testAddResource() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", null); - McpServerFeatures.SyncResourceRegistration registration = new McpServerFeatures.SyncResourceRegistration( - resource, req -> new ReadResourceResult(List.of())); + McpServerFeatures.SyncResourceSpecification specificaiton = new McpServerFeatures.SyncResourceSpecification( + resource, (exchange, req) -> new ReadResourceResult(List.of())); - assertThatCode(() -> mcpSyncServer.addResource(registration)).doesNotThrowAnyException(); + assertThatCode(() -> mcpSyncServer.addResource(specificaiton)).doesNotThrowAnyException(); assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); } @Test - void testAddResourceWithNullRegistration() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) + void testAddResourceWithNullSpecifiation() { + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); - assertThatThrownBy(() -> mcpSyncServer.addResource((McpServerFeatures.SyncResourceRegistration) null)) + assertThatThrownBy(() -> mcpSyncServer.addResource((McpServerFeatures.SyncResourceSpecification) null)) .isInstanceOf(McpError.class) .hasMessage("Resource must not be null"); @@ -224,20 +223,24 @@ void testAddResourceWithNullRegistration() { @Test void testAddResourceWithoutCapability() { - var serverWithoutResources = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var serverWithoutResources = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .build(); Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", null); - McpServerFeatures.SyncResourceRegistration registration = new McpServerFeatures.SyncResourceRegistration( - resource, req -> new ReadResourceResult(List.of())); + McpServerFeatures.SyncResourceSpecification specification = new McpServerFeatures.SyncResourceSpecification( + resource, (exchange, req) -> new ReadResourceResult(List.of())); - assertThatThrownBy(() -> serverWithoutResources.addResource(registration)).isInstanceOf(McpError.class) + assertThatThrownBy(() -> serverWithoutResources.addResource(specification)).isInstanceOf(McpError.class) .hasMessage("Server must be configured with resource capabilities"); } @Test void testRemoveResourceWithoutCapability() { - var serverWithoutResources = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var serverWithoutResources = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .build(); assertThatThrownBy(() -> serverWithoutResources.removeResource(TEST_RESOURCE_URI)).isInstanceOf(McpError.class) .hasMessage("Server must be configured with resource capabilities"); @@ -249,7 +252,7 @@ void testRemoveResourceWithoutCapability() { @Test void testNotifyPromptsListChanged() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpSyncServer.notifyPromptsListChanged()).doesNotThrowAnyException(); @@ -257,33 +260,37 @@ void testNotifyPromptsListChanged() { } @Test - void testAddPromptWithNullRegistration() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) + void testAddPromptWithNullSpecification() { + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(false).build()) .build(); - assertThatThrownBy(() -> mcpSyncServer.addPrompt((McpServerFeatures.SyncPromptRegistration) null)) + assertThatThrownBy(() -> mcpSyncServer.addPrompt((McpServerFeatures.SyncPromptSpecification) null)) .isInstanceOf(McpError.class) - .hasMessage("Prompt registration must not be null"); + .hasMessage("Prompt specification must not be null"); } @Test void testAddPromptWithoutCapability() { - var serverWithoutPrompts = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var serverWithoutPrompts = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .build(); Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); - McpServerFeatures.SyncPromptRegistration registration = new McpServerFeatures.SyncPromptRegistration(prompt, - req -> new GetPromptResult("Test prompt description", List + McpServerFeatures.SyncPromptSpecification specificaiton = new McpServerFeatures.SyncPromptSpecification(prompt, + (exchange, req) -> new GetPromptResult("Test prompt description", List .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); - assertThatThrownBy(() -> serverWithoutPrompts.addPrompt(registration)).isInstanceOf(McpError.class) + assertThatThrownBy(() -> serverWithoutPrompts.addPrompt(specificaiton)).isInstanceOf(McpError.class) .hasMessage("Server must be configured with prompt capabilities"); } @Test void testRemovePromptWithoutCapability() { - var serverWithoutPrompts = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var serverWithoutPrompts = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .build(); assertThatThrownBy(() -> serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).isInstanceOf(McpError.class) .hasMessage("Server must be configured with prompt capabilities"); @@ -292,14 +299,14 @@ void testRemovePromptWithoutCapability() { @Test void testRemovePrompt() { Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); - McpServerFeatures.SyncPromptRegistration registration = new McpServerFeatures.SyncPromptRegistration(prompt, - req -> new GetPromptResult("Test prompt description", List + McpServerFeatures.SyncPromptSpecification specificaiton = new McpServerFeatures.SyncPromptSpecification(prompt, + (exchange, req) -> new GetPromptResult("Test prompt description", List .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); - var mcpSyncServer = McpServer.sync(createMcpTransport()) + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) - .prompts(registration) + .prompts(specificaiton) .build(); assertThatCode(() -> mcpSyncServer.removePrompt(TEST_PROMPT_NAME)).doesNotThrowAnyException(); @@ -309,7 +316,7 @@ void testRemovePrompt() { @Test void testRemoveNonexistentPrompt() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) .build(); @@ -325,14 +332,14 @@ void testRemoveNonexistentPrompt() { // --------------------------------------- @Test - void testRootsChangeConsumers() { + void testRootsChangeHandlers() { // Test with single consumer var rootsReceived = new McpSchema.Root[1]; var consumerCalled = new boolean[1]; - var singleConsumerServer = McpServer.sync(createMcpTransport()) + var singleConsumerServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> { + .rootsChangeHandlers(List.of((exchage, roots) -> { consumerCalled[0] = true; if (!roots.isEmpty()) { rootsReceived[0] = roots.get(0); @@ -349,12 +356,12 @@ void testRootsChangeConsumers() { var consumer2Called = new boolean[1]; var rootsContent = new List[1]; - var multipleConsumersServer = McpServer.sync(createMcpTransport()) + var multipleConsumersServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> { + .rootsChangeHandlers(List.of((exchange, roots) -> { consumer1Called[0] = true; rootsContent[0] = roots; - }, roots -> consumer2Called[0] = true)) + }, (exchange, roots) -> consumer2Called[0] = true)) .build(); assertThat(multipleConsumersServer).isNotNull(); @@ -362,9 +369,9 @@ void testRootsChangeConsumers() { onClose(); // Test error handling - var errorHandlingServer = McpServer.sync(createMcpTransport()) + var errorHandlingServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> { + .rootsChangeHandlers(List.of((exchange, roots) -> { throw new RuntimeException("Test error"); })) .build(); @@ -374,7 +381,7 @@ void testRootsChangeConsumers() { onClose(); // Test without consumers - var noConsumersServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var noConsumersServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); assertThat(noConsumersServer).isNotNull(); assertThatCode(() -> noConsumersServer.closeGracefully()).doesNotThrowAnyException(); @@ -386,7 +393,7 @@ void testRootsChangeConsumers() { @Test void testLoggingLevels() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().logging().build()) .build(); @@ -405,7 +412,7 @@ void testLoggingLevels() { @Test void testLoggingWithoutCapability() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().build()) // No logging capability .build(); @@ -421,7 +428,7 @@ void testLoggingWithoutCapability() { @Test void testLoggingWithNullNotification() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().logging().build()) .build(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java index 715f636d..7f7357a6 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java @@ -15,7 +15,7 @@ * @author Christian Tzolov */ @Timeout(15) // Giving extra time beyond the client timeout -class ServletSseMcpAsyncServerTests extends AbstractMcpAsyncServerTests { +class ServletSseMcpAsyncServerTests extends AbstractMcpAsyncServerDeprecatedTests { @Override protected ServerMcpTransport createMcpTransport() { diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java index 208de7f7..4507256c 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java @@ -15,7 +15,7 @@ * @author Christian Tzolov */ @Timeout(15) // Giving extra time beyond the client timeout -class ServletSseMcpSyncServerTests extends AbstractMcpSyncServerTests { +class ServletSseMcpSyncServerTests extends AbstractMcpSyncServerDeprecatedTests { @Override protected ServerMcpTransport createMcpTransport() { diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerDeprecatedTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerDeprecatedTests.java new file mode 100644 index 00000000..db95db07 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerDeprecatedTests.java @@ -0,0 +1,25 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.server.transport.StdioServerTransport; +import io.modelcontextprotocol.spec.ServerMcpTransport; +import org.junit.jupiter.api.Timeout; + +/** + * Tests for {@link McpAsyncServer} using {@link StdioServerTransport}. + * + * @author Christian Tzolov + */ +@Deprecated +@Timeout(15) // Giving extra time beyond the client timeout +class StdioMcpAsyncServerDeprecatedTests extends AbstractMcpAsyncServerDeprecatedTests { + + @Override + protected ServerMcpTransport createMcpTransport() { + return new StdioServerTransport(); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java index e933d638..27ff53c9 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java @@ -5,7 +5,8 @@ package io.modelcontextprotocol.server; import io.modelcontextprotocol.server.transport.StdioServerTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.server.transport.StdioServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.Timeout; /** @@ -17,8 +18,8 @@ class StdioMcpAsyncServerTests extends AbstractMcpAsyncServerTests { @Override - protected ServerMcpTransport createMcpTransport() { - return new StdioServerTransport(); + protected McpServerTransportProvider createMcpTransportProvider() { + return new StdioServerTransportProvider(); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerDeprecatedTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerDeprecatedTests.java new file mode 100644 index 00000000..149f7281 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerDeprecatedTests.java @@ -0,0 +1,25 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.server.transport.StdioServerTransport; +import io.modelcontextprotocol.spec.ServerMcpTransport; +import org.junit.jupiter.api.Timeout; + +/** + * Tests for {@link McpSyncServer} using {@link StdioServerTransport}. + * + * @author Christian Tzolov + */ +@Deprecated +@Timeout(15) // Giving extra time beyond the client timeout +class StdioMcpSyncServerDeprecatedTests extends AbstractMcpSyncServerDeprecatedTests { + + @Override + protected ServerMcpTransport createMcpTransport() { + return new StdioServerTransport(); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerTests.java index d9350417..a71c3849 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerTests.java @@ -4,12 +4,12 @@ package io.modelcontextprotocol.server; -import io.modelcontextprotocol.server.transport.StdioServerTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.server.transport.StdioServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.Timeout; /** - * Tests for {@link McpSyncServer} using {@link StdioServerTransport}. + * Tests for {@link McpSyncServer} using {@link StdioServerTransportProvider}. * * @author Christian Tzolov */ @@ -17,8 +17,8 @@ class StdioMcpSyncServerTests extends AbstractMcpSyncServerTests { @Override - protected ServerMcpTransport createMcpTransport() { - return new StdioServerTransport(); + protected McpServerTransportProvider createMcpTransportProvider() { + return new StdioServerTransportProvider(); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/BlockingInputStream.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/BlockingInputStream.java deleted file mode 100644 index 0ab72a99..00000000 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/BlockingInputStream.java +++ /dev/null @@ -1,69 +0,0 @@ -/* -* Copyright 2024 - 2024 the original author or authors. -*/ -package io.modelcontextprotocol.server.transport; - -import java.io.IOException; -import java.io.InputStream; -import java.util.concurrent.BlockingQueue; -import java.util.concurrent.LinkedBlockingQueue; - -public class BlockingInputStream extends InputStream { - - private final BlockingQueue queue = new LinkedBlockingQueue<>(); - - private volatile boolean completed = false; - - private volatile boolean closed = false; - - @Override - public int read() throws IOException { - if (closed) { - throw new IOException("Stream is closed"); - } - - try { - Integer value = queue.poll(); - if (value == null) { - if (completed) { - return -1; - } - value = queue.take(); // Blocks until data is available - if (value == null && completed) { - return -1; - } - } - return value; - } - catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new IOException("Read interrupted", e); - } - } - - public void write(int b) { - if (!closed && !completed) { - queue.offer(b); - } - } - - public void write(byte[] data) { - if (!closed && !completed) { - for (byte b : data) { - queue.offer((int) b & 0xFF); - } - } - } - - public void complete() { - this.completed = true; - } - - @Override - public void close() { - this.closed = true; - this.completed = true; - this.queue.clear(); - } - -} \ No newline at end of file diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java new file mode 100644 index 00000000..14987b5a --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java @@ -0,0 +1,227 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.InputStream; +import java.io.PrintStream; +import java.nio.charset.StandardCharsets; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpServerTransport; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link StdioServerTransportProvider}. + * + * @author Christian Tzolov + */ +@Disabled +class StdioServerTransportProviderTests { + + private final PrintStream originalOut = System.out; + + private final PrintStream originalErr = System.err; + + private ByteArrayOutputStream testErr; + + private PrintStream testOutPrintStream; + + private StdioServerTransportProvider transportProvider; + + private ObjectMapper objectMapper; + + private McpServerSession.Factory sessionFactory; + + private McpServerSession mockSession; + + @BeforeEach + void setUp() { + testErr = new ByteArrayOutputStream(); + + testOutPrintStream = new PrintStream(testErr, true); + System.setOut(testOutPrintStream); + System.setErr(testOutPrintStream); + + objectMapper = new ObjectMapper(); + + // Create mocks for session factory and session + mockSession = mock(McpServerSession.class); + sessionFactory = mock(McpServerSession.Factory.class); + + // Configure mock behavior + when(sessionFactory.create(any(McpServerTransport.class))).thenReturn(mockSession); + when(mockSession.closeGracefully()).thenReturn(Mono.empty()); + when(mockSession.sendNotification(any(), any())).thenReturn(Mono.empty()); + + transportProvider = new StdioServerTransportProvider(objectMapper, System.in, testOutPrintStream); + } + + @AfterEach + void tearDown() { + if (transportProvider != null) { + transportProvider.closeGracefully().block(); + } + if (testOutPrintStream != null) { + testOutPrintStream.close(); + } + System.setOut(originalOut); + System.setErr(originalErr); + } + + @Test + void shouldCreateSessionWhenSessionFactoryIsSet() { + // Set session factory + transportProvider.setSessionFactory(sessionFactory); + + // Verify session was created with a transport + assertThat(testErr.toString()).doesNotContain("Error"); + } + + @Test + void shouldHandleIncomingMessages() throws Exception { + + String jsonMessage = "{\"jsonrpc\":\"2.0\",\"method\":\"test\",\"params\":{},\"id\":1}\n"; + InputStream stream = new ByteArrayInputStream(jsonMessage.getBytes(StandardCharsets.UTF_8)); + + transportProvider = new StdioServerTransportProvider(objectMapper, stream, System.out); + // Set up a real session to capture the message + AtomicReference capturedMessage = new AtomicReference<>(); + CountDownLatch messageLatch = new CountDownLatch(1); + + McpServerSession.Factory realSessionFactory = transport -> { + McpServerSession session = mock(McpServerSession.class); + when(session.handle(any())).thenAnswer(invocation -> { + capturedMessage.set(invocation.getArgument(0)); + messageLatch.countDown(); + return Mono.empty(); + }); + when(session.closeGracefully()).thenReturn(Mono.empty()); + return session; + }; + + // Set session factory + transportProvider.setSessionFactory(realSessionFactory); + + // Wait for the message to be processed using the latch + StepVerifier.create(Mono.fromCallable(() -> messageLatch.await(100, TimeUnit.SECONDS)).flatMap(success -> { + if (!success) { + return Mono.error(new AssertionError("Timeout waiting for message processing")); + } + return Mono.just(capturedMessage.get()); + })).assertNext(message -> { + assertThat(message).isNotNull(); + assertThat(message).isInstanceOf(McpSchema.JSONRPCRequest.class); + McpSchema.JSONRPCRequest request = (McpSchema.JSONRPCRequest) message; + assertThat(request.method()).isEqualTo("test"); + assertThat(request.id()).isEqualTo(1); + }).verifyComplete(); + } + + @Test + void shouldNotifyClients() { + // Set session factory + transportProvider.setSessionFactory(sessionFactory); + + // Send notification + String method = "testNotification"; + Map params = Map.of("key", "value"); + + StepVerifier.create(transportProvider.notifyClients(method, params)).verifyComplete(); + + // Error log should be empty + assertThat(testErr.toString()).doesNotContain("Error"); + } + + @Test + void shouldCloseGracefully() { + // Set session factory + transportProvider.setSessionFactory(sessionFactory); + + // Close gracefully + StepVerifier.create(transportProvider.closeGracefully()).verifyComplete(); + + // Error log should be empty + assertThat(testErr.toString()).doesNotContain("Error"); + } + + @Test + void shouldHandleMultipleCloseGracefullyCalls() { + // Set session factory + transportProvider.setSessionFactory(sessionFactory); + + // Close gracefully multiple times + StepVerifier + .create(transportProvider.closeGracefully() + .then(transportProvider.closeGracefully()) + .then(transportProvider.closeGracefully())) + .verifyComplete(); + + // Error log should be empty + assertThat(testErr.toString()).doesNotContain("Error"); + } + + @Test + void shouldHandleNotificationBeforeSessionFactoryIsSet() { + + transportProvider = new StdioServerTransportProvider(objectMapper); + // Send notification before setting session factory + StepVerifier.create(transportProvider.notifyClients("testNotification", Map.of("key", "value"))) + .verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class); + }); + } + + @Test + void shouldHandleInvalidJsonMessage() throws Exception { + + // Write an invalid JSON message to the input stream + String jsonMessage = "{invalid json}\n"; + InputStream stream = new ByteArrayInputStream(jsonMessage.getBytes(StandardCharsets.UTF_8)); + + transportProvider = new StdioServerTransportProvider(objectMapper, stream, testOutPrintStream); + + // Set up a session factory + transportProvider.setSessionFactory(sessionFactory); + + // Use StepVerifier with a timeout to wait for the error to be processed + StepVerifier + .create(Mono.delay(java.time.Duration.ofMillis(500)).then(Mono.fromCallable(() -> testErr.toString()))) + .assertNext(errorOutput -> assertThat(errorOutput).contains("Error processing inbound message")) + .verifyComplete(); + } + + @Test + void shouldHandleSessionClose() throws Exception { + // Set session factory + transportProvider.setSessionFactory(sessionFactory); + + // Close the transport provider + transportProvider.close(); + + // Verify session was closed + verify(mockSession).closeGracefully(); + } + +} From 95a02f8fb95d4cd2aae5efb6381443d06f81568f Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Mon, 17 Mar 2025 18:51:35 +0100 Subject: [PATCH 12/20] feat: add WebMvcSseServerTransportProvider implementation This commit introduces WebMvcSseServerTransportProvider as a replacement for the now-deprecated WebMvcSseServerTransport. The new provider-based implementation offers improved session management and better alignment with the MCP specification. Additional changes: - Deprecate WebMvcSseServerTransport and StdioServerTransport - Add corresponding deprecated test classes to maintain backward compatibility - Increase test timeouts from 300ms to 1000ms/2000ms for more reliable testing - Minor rename of ClientMcpTransport to McpClientTransport for naming consistency Signed-off-by: Christian Tzolov --- .../transport/WebFluxSseClientTransport.java | 4 +- .../client/WebFluxSseMcpAsyncClientTests.java | 4 +- .../client/WebFluxSseMcpSyncClientTests.java | 4 +- .../transport/WebMvcSseServerTransport.java | 5 +- .../WebMvcSseServerTransportProvider.java | 399 ++++++++++++++ ...seAsyncServerTransportDeprecatedTests.java | 118 ++++ .../WebMvcSseAsyncServerTransportTests.java | 26 +- .../WebMvcSseIntegrationDeprecatedTests.java | 508 ++++++++++++++++++ .../server/WebMvcSseIntegrationTests.java | 195 ++++--- ...SseSyncServerTransportDeprecatedTests.java | 118 ++++ .../WebMvcSseSyncServerTransportTests.java | 25 +- .../client/AbstractMcpAsyncClientTests.java | 4 +- .../client/AbstractMcpSyncClientTests.java | 2 +- .../HttpClientSseClientTransport.java | 5 +- .../transport/StdioClientTransport.java | 4 +- .../transport/StdioServerTransport.java | 3 + .../spec/ClientMcpTransport.java | 4 - .../spec/McpClientTransport.java | 5 +- .../client/AbstractMcpAsyncClientTests.java | 4 +- .../client/AbstractMcpSyncClientTests.java | 2 +- .../client/HttpSseMcpAsyncClientTests.java | 4 +- .../client/HttpSseMcpSyncClientTests.java | 4 +- .../client/StdioMcpAsyncClientTests.java | 4 +- .../client/StdioMcpSyncClientTests.java | 4 +- 24 files changed, 1313 insertions(+), 142 deletions(-) create mode 100644 mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java create mode 100644 mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportDeprecatedTests.java create mode 100644 mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationDeprecatedTests.java create mode 100644 mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportDeprecatedTests.java diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java index 8ea65fd7..b0dfa89c 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java @@ -9,7 +9,7 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; @@ -58,7 +58,7 @@ * "https://spec.modelcontextprotocol.io/specification/basic/transports/#http-with-sse">MCP * HTTP with SSE Transport Specification */ -public class WebFluxSseClientTransport implements ClientMcpTransport { +public class WebFluxSseClientTransport implements McpClientTransport { private static final Logger logger = LoggerFactory.getLogger(WebFluxSseClientTransport.class); diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java index 0dccb27a..2dd587d4 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java @@ -7,7 +7,7 @@ import java.time.Duration; import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import org.junit.jupiter.api.Timeout; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; @@ -32,7 +32,7 @@ class WebFluxSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests { .waitingFor(Wait.forHttp("/").forStatusCode(404)); @Override - protected ClientMcpTransport createMcpTransport() { + protected McpClientTransport createMcpTransport() { return new WebFluxSseClientTransport(WebClient.builder().baseUrl(host)); } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java index f5cab7b7..72b390dd 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java @@ -7,7 +7,7 @@ import java.time.Duration; import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import org.junit.jupiter.api.Timeout; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; @@ -32,7 +32,7 @@ class WebFluxSseMcpSyncClientTests extends AbstractMcpSyncClientTests { .waitingFor(Wait.forHttp("/").forStatusCode(404)); @Override - protected ClientMcpTransport createMcpTransport() { + protected McpClientTransport createMcpTransport() { return new WebFluxSseClientTransport(WebClient.builder().baseUrl(host)); } diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransport.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransport.java index 00928ec7..23193d10 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransport.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransport.java @@ -33,6 +33,9 @@ * a bridge between synchronous WebMVC operations and reactive programming patterns to * maintain compatibility with the reactive transport interface. * + * @deprecated This class will be removed in 0.9.0. Use + * {@link WebMvcSseServerTransportProvider}. + * *

    * Key features: *

      @@ -57,12 +60,12 @@ * This implementation uses {@link ConcurrentHashMap} to safely manage multiple client * sessions in a thread-safe manner. Each client session is assigned a unique ID and * maintains its own SSE connection. - * * @author Christian Tzolov * @author Alexandros Pappas * @see ServerMcpTransport * @see RouterFunction */ +@Deprecated public class WebMvcSseServerTransport implements ServerMcpTransport { private static final Logger logger = LoggerFactory.getLogger(WebMvcSseServerTransport.class); diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java new file mode 100644 index 00000000..65416b25 --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java @@ -0,0 +1,399 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.io.IOException; +import java.time.Duration; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerTransport; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.util.Assert; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.http.HttpStatus; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.RouterFunctions; +import org.springframework.web.servlet.function.ServerRequest; +import org.springframework.web.servlet.function.ServerResponse; +import org.springframework.web.servlet.function.ServerResponse.SseBuilder; + +/** + * Server-side implementation of the Model Context Protocol (MCP) transport layer using + * HTTP with Server-Sent Events (SSE) through Spring WebMVC. This implementation provides + * a bridge between synchronous WebMVC operations and reactive programming patterns to + * maintain compatibility with the reactive transport interface. + * + *

      + * Key features: + *

        + *
      • Implements bidirectional communication using HTTP POST for client-to-server + * messages and SSE for server-to-client messages
      • + *
      • Manages client sessions with unique IDs for reliable message delivery
      • + *
      • Supports graceful shutdown with proper session cleanup
      • + *
      • Provides JSON-RPC message handling through configured endpoints
      • + *
      • Includes built-in error handling and logging
      • + *
      + * + *

      + * The transport operates on two main endpoints: + *

        + *
      • {@code /sse} - The SSE endpoint where clients establish their event stream + * connection
      • + *
      • A configurable message endpoint where clients send their JSON-RPC messages via HTTP + * POST
      • + *
      + * + *

      + * This implementation uses {@link ConcurrentHashMap} to safely manage multiple client + * sessions in a thread-safe manner. Each client session is assigned a unique ID and + * maintains its own SSE connection. + * + * @author Christian Tzolov + * @author Alexandros Pappas + * @see McpServerTransportProvider + * @see RouterFunction + */ +public class WebMvcSseServerTransportProvider implements McpServerTransportProvider { + + private static final Logger logger = LoggerFactory.getLogger(WebMvcSseServerTransportProvider.class); + + /** + * Event type for JSON-RPC messages sent through the SSE connection. + */ + public static final String MESSAGE_EVENT_TYPE = "message"; + + /** + * Event type for sending the message endpoint URI to clients. + */ + public static final String ENDPOINT_EVENT_TYPE = "endpoint"; + + /** + * Default SSE endpoint path as specified by the MCP transport specification. + */ + public static final String DEFAULT_SSE_ENDPOINT = "/sse"; + + private final ObjectMapper objectMapper; + + private final String messageEndpoint; + + private final String sseEndpoint; + + private final RouterFunction routerFunction; + + private McpServerSession.Factory sessionFactory; + + /** + * Map of active client sessions, keyed by session ID. + */ + private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + + /** + * Flag indicating if the transport is shutting down. + */ + private volatile boolean isClosing = false; + + /** + * Constructs a new WebMvcSseServerTransportProvider instance. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of messages. + * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC + * messages via HTTP POST. This endpoint will be communicated to clients through the + * SSE connection's initial endpoint event. + * @param sseEndpoint The endpoint URI where clients establish their SSE connections. + * @throws IllegalArgumentException if any parameter is null + */ + public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + Assert.notNull(messageEndpoint, "Message endpoint must not be null"); + Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); + + this.objectMapper = objectMapper; + this.messageEndpoint = messageEndpoint; + this.sseEndpoint = sseEndpoint; + this.routerFunction = RouterFunctions.route() + .GET(this.sseEndpoint, this::handleSseConnection) + .POST(this.messageEndpoint, this::handleMessage) + .build(); + } + + /** + * Constructs a new WebMvcSseServerTransportProvider instance with the default SSE + * endpoint. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of messages. + * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC + * messages via HTTP POST. This endpoint will be communicated to clients through the + * SSE connection's initial endpoint event. + * @throws IllegalArgumentException if either objectMapper or messageEndpoint is null + */ + public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) { + this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); + } + + @Override + public void setSessionFactory(McpServerSession.Factory sessionFactory) { + this.sessionFactory = sessionFactory; + } + + /** + * Broadcasts a notification to all connected clients through their SSE connections. + * The message is serialized to JSON and sent as an SSE event with type "message". If + * any errors occur during sending to a particular client, they are logged but don't + * prevent sending to other clients. + * @param method The method name for the notification + * @param params The parameters for the notification + * @return A Mono that completes when the broadcast attempt is finished + */ + @Override + public Mono notifyClients(String method, Map params) { + if (sessions.isEmpty()) { + logger.debug("No active sessions to broadcast message to"); + return Mono.empty(); + } + + logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); + + return Flux.fromIterable(sessions.values()) + .flatMap(session -> session.sendNotification(method, params) + .doOnError( + e -> logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage())) + .onErrorComplete()) + .then(); + } + + /** + * Initiates a graceful shutdown of the transport. This method: + *

        + *
      • Sets the closing flag to prevent new connections
      • + *
      • Closes all active SSE connections
      • + *
      • Removes all session records
      • + *
      + * @return A Mono that completes when all cleanup operations are finished + */ + @Override + public Mono closeGracefully() { + return Flux.fromIterable(sessions.values()).doFirst(() -> { + this.isClosing = true; + logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size()); + }) + .flatMap(McpServerSession::closeGracefully) + .then() + .doOnSuccess(v -> logger.debug("Graceful shutdown completed")); + } + + /** + * Returns the RouterFunction that defines the HTTP endpoints for this transport. The + * router function handles two endpoints: + *
        + *
      • GET /sse - For establishing SSE connections
      • + *
      • POST [messageEndpoint] - For receiving JSON-RPC messages from clients
      • + *
      + * @return The configured RouterFunction for handling HTTP requests + */ + public RouterFunction getRouterFunction() { + return this.routerFunction; + } + + /** + * Handles new SSE connection requests from clients by creating a new session and + * establishing an SSE connection. This method: + *
        + *
      • Generates a unique session ID
      • + *
      • Creates a new session with a WebMvcMcpSessionTransport
      • + *
      • Sends an initial endpoint event to inform the client where to send + * messages
      • + *
      • Maintains the session in the sessions map
      • + *
      + * @param request The incoming server request + * @return A ServerResponse configured for SSE communication, or an error response if + * the server is shutting down or the connection fails + */ + private ServerResponse handleSseConnection(ServerRequest request) { + if (this.isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); + } + + String sessionId = UUID.randomUUID().toString(); + logger.debug("Creating new SSE connection for session: {}", sessionId); + + // Send initial endpoint event + try { + return ServerResponse.sse(sseBuilder -> { + sseBuilder.onComplete(() -> { + logger.debug("SSE connection completed for session: {}", sessionId); + sessions.remove(sessionId); + }); + sseBuilder.onTimeout(() -> { + logger.debug("SSE connection timed out for session: {}", sessionId); + sessions.remove(sessionId); + }); + + WebMvcMcpSessionTransport sessionTransport = new WebMvcMcpSessionTransport(sessionId, sseBuilder); + McpServerSession session = sessionFactory.create(sessionTransport); + this.sessions.put(sessionId, session); + + try { + sseBuilder.id(sessionId) + .event(ENDPOINT_EVENT_TYPE) + .data(messageEndpoint + "?sessionId=" + sessionId); + } + catch (Exception e) { + logger.error("Failed to send initial endpoint event: {}", e.getMessage()); + sseBuilder.error(e); + } + }, Duration.ZERO); + } + catch (Exception e) { + logger.error("Failed to send initial endpoint event to session {}: {}", sessionId, e.getMessage()); + sessions.remove(sessionId); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).build(); + } + } + + /** + * Handles incoming JSON-RPC messages from clients. This method: + *
        + *
      • Deserializes the request body into a JSON-RPC message
      • + *
      • Processes the message through the session's handle method
      • + *
      • Returns appropriate HTTP responses based on the processing result
      • + *
      + * @param request The incoming server request containing the JSON-RPC message + * @return A ServerResponse indicating success (200 OK) or appropriate error status + * with error details in case of failures + */ + private ServerResponse handleMessage(ServerRequest request) { + if (this.isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); + } + + if (!request.param("sessionId").isPresent()) { + return ServerResponse.badRequest().body(new McpError("Session ID missing in message endpoint")); + } + + String sessionId = request.param("sessionId").get(); + McpServerSession session = sessions.get(sessionId); + + if (session == null) { + return ServerResponse.status(HttpStatus.NOT_FOUND).body(new McpError("Session not found: " + sessionId)); + } + + try { + String body = request.body(String.class); + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); + + // Process the message through the session's handle method + session.handle(message).block(); // Block for WebMVC compatibility + + return ServerResponse.ok().build(); + } + catch (IllegalArgumentException | IOException e) { + logger.error("Failed to deserialize message: {}", e.getMessage()); + return ServerResponse.badRequest().body(new McpError("Invalid message format")); + } + catch (Exception e) { + logger.error("Error handling message: {}", e.getMessage()); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).body(new McpError(e.getMessage())); + } + } + + /** + * Implementation of McpServerTransport for WebMVC SSE sessions. This class handles + * the transport-level communication for a specific client session. + */ + private class WebMvcMcpSessionTransport implements McpServerTransport { + + private final String sessionId; + + private final SseBuilder sseBuilder; + + /** + * Creates a new session transport with the specified ID and SSE builder. + * @param sessionId The unique identifier for this session + * @param sseBuilder The SSE builder for sending server events to the client + */ + WebMvcMcpSessionTransport(String sessionId, SseBuilder sseBuilder) { + this.sessionId = sessionId; + this.sseBuilder = sseBuilder; + logger.debug("Session transport {} initialized with SSE builder", sessionId); + } + + /** + * Sends a JSON-RPC message to the client through the SSE connection. + * @param message The JSON-RPC message to send + * @return A Mono that completes when the message has been sent + */ + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + return Mono.fromRunnable(() -> { + try { + String jsonText = objectMapper.writeValueAsString(message); + sseBuilder.id(sessionId).event(MESSAGE_EVENT_TYPE).data(jsonText); + logger.debug("Message sent to session {}", sessionId); + } + catch (Exception e) { + logger.error("Failed to send message to session {}: {}", sessionId, e.getMessage()); + sseBuilder.error(e); + } + }); + } + + /** + * Converts data from one type to another using the configured ObjectMapper. + * @param data The source data object to convert + * @param typeRef The target type reference + * @return The converted object of type T + * @param The target type + */ + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return objectMapper.convertValue(data, typeRef); + } + + /** + * Initiates a graceful shutdown of the transport. + * @return A Mono that completes when the shutdown is complete + */ + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + logger.debug("Closing session transport: {}", sessionId); + try { + sseBuilder.complete(); + logger.debug("Successfully completed SSE builder for session {}", sessionId); + } + catch (Exception e) { + logger.warn("Failed to complete SSE builder for session {}: {}", sessionId, e.getMessage()); + } + }); + } + + /** + * Closes the transport immediately. + */ + @Override + public void close() { + try { + sseBuilder.complete(); + logger.debug("Successfully completed SSE builder for session {}", sessionId); + } + catch (Exception e) { + logger.warn("Failed to complete SSE builder for session {}: {}", sessionId, e.getMessage()); + } + } + + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportDeprecatedTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportDeprecatedTests.java new file mode 100644 index 00000000..c3f0e322 --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportDeprecatedTests.java @@ -0,0 +1,118 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.transport.WebMvcSseServerTransport; +import io.modelcontextprotocol.spec.ServerMcpTransport; +import org.apache.catalina.Context; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.Timeout; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; +import org.springframework.web.servlet.DispatcherServlet; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.ServerResponse; + +@Deprecated +@Timeout(15) +class WebMvcSseAsyncServerTransportDeprecatedTests extends AbstractMcpAsyncServerDeprecatedTests { + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private static final int PORT = 8181; + + private Tomcat tomcat; + + private WebMvcSseServerTransport transport; + + @Configuration + @EnableWebMvc + static class TestConfig { + + @Bean + public WebMvcSseServerTransport webMvcSseServerTransport() { + return new WebMvcSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); + } + + @Bean + public RouterFunction routerFunction(WebMvcSseServerTransport transport) { + return transport.getRouterFunction(); + } + + } + + private AnnotationConfigWebApplicationContext appContext; + + @Override + protected ServerMcpTransport createMcpTransport() { + // Set up Tomcat first + tomcat = new Tomcat(); + tomcat.setPort(PORT); + + // Set Tomcat base directory to java.io.tmpdir to avoid permission issues + String baseDir = System.getProperty("java.io.tmpdir"); + tomcat.setBaseDir(baseDir); + + // Use the same directory for document base + Context context = tomcat.addContext("", baseDir); + + // Create and configure Spring WebMvc context + appContext = new AnnotationConfigWebApplicationContext(); + appContext.register(TestConfig.class); + appContext.setServletContext(context.getServletContext()); + appContext.refresh(); + + // Get the transport from Spring context + transport = appContext.getBean(WebMvcSseServerTransport.class); + + // Create DispatcherServlet with our Spring context + DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext); + // dispatcherServlet.setThrowExceptionIfNoHandlerFound(true); + + // Add servlet to Tomcat and get the wrapper + var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet); + wrapper.setLoadOnStartup(1); + context.addServletMappingDecoded("/*", "dispatcherServlet"); + + try { + tomcat.start(); + tomcat.getConnector(); // Create and start the connector + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + return transport; + } + + @Override + protected void onStart() { + } + + @Override + protected void onClose() { + if (transport != null) { + transport.closeGracefully().block(); + } + if (appContext != null) { + appContext.close(); + } + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java index 00649f0f..08d5de67 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java @@ -5,8 +5,8 @@ package io.modelcontextprotocol.server; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.server.transport.WebMvcSseServerTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.apache.catalina.Context; import org.apache.catalina.LifecycleException; import org.apache.catalina.startup.Tomcat; @@ -21,7 +21,7 @@ import org.springframework.web.servlet.function.ServerResponse; @Timeout(15) -class WebMvcSseAsyncServerTransportTests extends AbstractMcpAsyncServerDeprecatedTests { +class WebMvcSseAsyncServerTransportTests extends AbstractMcpAsyncServerTests { private static final String MESSAGE_ENDPOINT = "/mcp/message"; @@ -29,20 +29,20 @@ class WebMvcSseAsyncServerTransportTests extends AbstractMcpAsyncServerDeprecate private Tomcat tomcat; - private WebMvcSseServerTransport transport; + private McpServerTransportProvider transportProvider; @Configuration @EnableWebMvc static class TestConfig { @Bean - public WebMvcSseServerTransport webMvcSseServerTransport() { - return new WebMvcSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); + public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() { + return new WebMvcSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT); } @Bean - public RouterFunction routerFunction(WebMvcSseServerTransport transport) { - return transport.getRouterFunction(); + public RouterFunction routerFunction(WebMvcSseServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); } } @@ -50,7 +50,7 @@ public RouterFunction routerFunction(WebMvcSseServerTransport tr private AnnotationConfigWebApplicationContext appContext; @Override - protected ServerMcpTransport createMcpTransport() { + protected McpServerTransportProvider createMcpTransportProvider() { // Set up Tomcat first tomcat = new Tomcat(); tomcat.setPort(PORT); @@ -69,7 +69,7 @@ protected ServerMcpTransport createMcpTransport() { appContext.refresh(); // Get the transport from Spring context - transport = appContext.getBean(WebMvcSseServerTransport.class); + transportProvider = appContext.getBean(WebMvcSseServerTransportProvider.class); // Create DispatcherServlet with our Spring context DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext); @@ -88,7 +88,7 @@ protected ServerMcpTransport createMcpTransport() { throw new RuntimeException("Failed to start Tomcat", e); } - return transport; + return transportProvider; } @Override @@ -97,8 +97,8 @@ protected void onStart() { @Override protected void onClose() { - if (transport != null) { - transport.closeGracefully().block(); + if (transportProvider != null) { + transportProvider.closeGracefully().block(); } if (appContext != null) { appContext.close(); diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationDeprecatedTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationDeprecatedTests.java new file mode 100644 index 00000000..f2b593d8 --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationDeprecatedTests.java @@ -0,0 +1,508 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ +package io.modelcontextprotocol.server; + +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.server.transport.WebMvcSseServerTransport; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.Role; +import io.modelcontextprotocol.spec.McpSchema.Root; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import org.apache.catalina.Context; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import reactor.test.StepVerifier; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.client.RestClient; +import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; +import org.springframework.web.servlet.DispatcherServlet; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.ServerResponse; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.awaitility.Awaitility.await; + +@Deprecated +public class WebMvcSseIntegrationDeprecatedTests { + + private static final int PORT = 8183; + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private WebMvcSseServerTransport mcpServerTransport; + + McpClient.SyncSpec clientBuilder; + + @Configuration + @EnableWebMvc + static class TestConfig { + + @Bean + public WebMvcSseServerTransport webMvcSseServerTransport() { + return new WebMvcSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); + } + + @Bean + public RouterFunction routerFunction(WebMvcSseServerTransport transport) { + return transport.getRouterFunction(); + } + + } + + private Tomcat tomcat; + + private AnnotationConfigWebApplicationContext appContext; + + @BeforeEach + public void before() { + + // Set up Tomcat first + tomcat = new Tomcat(); + tomcat.setPort(PORT); + + // Set Tomcat base directory to java.io.tmpdir to avoid permission issues + String baseDir = System.getProperty("java.io.tmpdir"); + tomcat.setBaseDir(baseDir); + + // Use the same directory for document base + Context context = tomcat.addContext("", baseDir); + + // Create and configure Spring WebMvc context + appContext = new AnnotationConfigWebApplicationContext(); + appContext.register(TestConfig.class); + appContext.setServletContext(context.getServletContext()); + appContext.refresh(); + + // Get the transport from Spring context + mcpServerTransport = appContext.getBean(WebMvcSseServerTransport.class); + + // Create DispatcherServlet with our Spring context + DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext); + // dispatcherServlet.setThrowExceptionIfNoHandlerFound(true); + + // Add servlet to Tomcat and get the wrapper + var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet); + wrapper.setLoadOnStartup(1); + wrapper.setAsyncSupported(true); + context.addServletMappingDecoded("/*", "dispatcherServlet"); + + try { + // Configure and start the connector with async support + var connector = tomcat.getConnector(); + connector.setAsyncTimeout(3000); // 3 seconds timeout for async requests + tomcat.start(); + assertThat(tomcat.getServer().getState() == LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + this.clientBuilder = McpClient.sync(new HttpClientSseClientTransport("http://localhost:" + PORT)); + } + + @AfterEach + public void after() { + if (mcpServerTransport != null) { + mcpServerTransport.closeGracefully().block(); + } + if (appContext != null) { + appContext.close(); + } + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + // --------------------------------------- + // Sampling Tests + // --------------------------------------- + @Test + void testCreateMessageWithoutInitialization() { + var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); + + var messages = List + .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))); + var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); + + var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, + McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); + + StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized. Call the initialize method first!"); + }); + } + + @Test + void testCreateMessageWithoutSamplingCapabilities() { + + var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); + + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")).build(); + + InitializeResult initResult = client.initialize(); + assertThat(initResult).isNotNull(); + + var messages = List + .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))); + var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); + + var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, + McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); + + StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with sampling capabilities"); + }); + } + + @Test + void testCreateMessageSuccess() throws InterruptedException { + + var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build(); + + InitializeResult initResult = client.initialize(); + assertThat(initResult).isNotNull(); + + var messages = List + .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))); + var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); + + var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, + McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); + + StepVerifier.create(mcpAsyncServer.createMessage(request)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }).verifyComplete(); + } + + // --------------------------------------- + // Roots Tests + // --------------------------------------- + @Test + void testRootsSuccess() { + List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); + + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransport) + .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate)) + .build(); + + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(rootsRef.get()).isNull(); + + assertThat(mcpServer.listRoots().roots()).containsAll(roots); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + + // Remove a root + mcpClient.removeRoot(roots.get(0).uri()); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); + }); + + // Add a new root + var root3 = new Root("uri3://", "root3"); + mcpClient.addRoot(root3); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); + }); + + mcpClient.close(); + mcpServer.close(); + } + + @Test + void testRootsWithoutCapability() { + var mcpServer = McpServer.sync(mcpServerTransport).rootsChangeConsumer(rootsUpdate -> { + }).build(); + + // Create client without roots capability + // No roots capability + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Attempt to list roots should fail + assertThatThrownBy(() -> mcpServer.listRoots().roots()).isInstanceOf(McpError.class) + .hasMessage("Roots not supported"); + + mcpClient.close(); + mcpServer.close(); + } + + @Test + void testRootsWithEmptyRootsList() { + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransport) + .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate)) + .build(); + + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(List.of()) // Empty roots list + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); + + mcpClient.close(); + mcpServer.close(); + } + + @Test + void testRootsWithMultipleConsumers() { + List roots = List.of(new Root("uri1://", "root1")); + + AtomicReference> rootsRef1 = new AtomicReference<>(); + AtomicReference> rootsRef2 = new AtomicReference<>(); + + var mcpServer = McpServer.sync(mcpServerTransport) + .rootsChangeConsumer(rootsUpdate -> rootsRef1.set(rootsUpdate)) + .rootsChangeConsumer(rootsUpdate -> rootsRef2.set(rootsUpdate)) + .build(); + + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef1.get()).containsAll(roots); + assertThat(rootsRef2.get()).containsAll(roots); + }); + + mcpClient.close(); + mcpServer.close(); + } + + @Test + void testRootsServerCloseWithActiveSubscription() { + List roots = List.of(new Root("uri1://", "root1")); + + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransport) + .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate)) + .build(); + + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + + // Close server while subscription is active + mcpServer.close(); + + // Verify client can handle server closure gracefully + mcpClient.close(); + } + + // --------------------------------------- + // Tools Tests + // --------------------------------------- + + String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + + @Test + void testToolCallSuccess() { + + var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + McpServerFeatures.SyncToolRegistration tool1 = new McpServerFeatures.SyncToolRegistration( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), request -> { + // perform a blocking call to a remote service + String response = RestClient.create() + .get() + .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .retrieve() + .body(String.class); + assertThat(response).isNotBlank(); + return callResponse; + }); + + var mcpServer = McpServer.sync(mcpServerTransport) + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + var mcpClient = clientBuilder.build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.close(); + mcpServer.close(); + } + + @Test + void testToolListChangeHandlingSuccess() { + + var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + McpServerFeatures.SyncToolRegistration tool1 = new McpServerFeatures.SyncToolRegistration( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), request -> { + // perform a blocking call to a remote service + String response = RestClient.create() + .get() + .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .retrieve() + .body(String.class); + assertThat(response).isNotBlank(); + return callResponse; + }); + + var mcpServer = McpServer.sync(mcpServerTransport) + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { + // perform a blocking call to a remote service + String response = RestClient.create() + .get() + .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .retrieve() + .body(String.class); + assertThat(response).isNotBlank(); + rootsRef.set(toolsUpdate); + }).build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(rootsRef.get()).isNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + mcpServer.notifyToolsListChanged(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(tool1.tool())); + }); + + // Remove a tool + mcpServer.removeTool("tool1"); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); + + // Add a new tool + McpServerFeatures.SyncToolRegistration tool2 = new McpServerFeatures.SyncToolRegistration( + new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), request -> callResponse); + + mcpServer.addTool(tool2); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(tool2.tool())); + }); + + mcpClient.close(); + mcpServer.close(); + } + + @Test + void testInitialize() { + + var mcpServer = McpServer.sync(mcpServerTransport).build(); + + var mcpClient = clientBuilder.build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.close(); + mcpServer.close(); + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java index 62f69637..7ba9ccc1 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java @@ -12,7 +12,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.client.McpClient; import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; -import io.modelcontextprotocol.server.transport.WebMvcSseServerTransport; +import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; @@ -31,6 +31,7 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; import reactor.test.StepVerifier; import org.springframework.context.annotation.Bean; @@ -43,8 +44,8 @@ import org.springframework.web.servlet.function.ServerResponse; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.awaitility.Awaitility.await; +import static org.mockito.Mockito.mock; public class WebMvcSseIntegrationTests { @@ -52,7 +53,7 @@ public class WebMvcSseIntegrationTests { private static final String MESSAGE_ENDPOINT = "/mcp/message"; - private WebMvcSseServerTransport mcpServerTransport; + private WebMvcSseServerTransportProvider mcpServerTransportProvider; McpClient.SyncSpec clientBuilder; @@ -61,13 +62,13 @@ public class WebMvcSseIntegrationTests { static class TestConfig { @Bean - public WebMvcSseServerTransport webMvcSseServerTransport() { - return new WebMvcSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); + public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() { + return new WebMvcSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT); } @Bean - public RouterFunction routerFunction(WebMvcSseServerTransport transport) { - return transport.getRouterFunction(); + public RouterFunction routerFunction(WebMvcSseServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); } } @@ -97,7 +98,7 @@ public void before() { appContext.refresh(); // Get the transport from Spring context - mcpServerTransport = appContext.getBean(WebMvcSseServerTransport.class); + mcpServerTransportProvider = appContext.getBean(WebMvcSseServerTransportProvider.class); // Create DispatcherServlet with our Spring context DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext); @@ -125,8 +126,8 @@ public void before() { @AfterEach public void after() { - if (mcpServerTransport != null) { - mcpServerTransport.closeGracefully().block(); + if (mcpServerTransportProvider != null) { + mcpServerTransportProvider.closeGracefully().block(); } if (appContext != null) { appContext.close(); @@ -146,49 +147,36 @@ public void after() { // Sampling Tests // --------------------------------------- @Test - void testCreateMessageWithoutInitialization() { - var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); - - var messages = List - .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))); - var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); + void testCreateMessageWithoutSamplingCapabilities() { - var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, - McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized. Call the initialize method first!"); - }); - } + exchange.createMessage(mock(McpSchema.CreateMessageRequest.class)).block(); - @Test - void testCreateMessageWithoutSamplingCapabilities() { + return Mono.just(mock(CallToolResult.class)); + }); - var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); + McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); + // Create client without sampling capabilities var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")).build(); - InitializeResult initResult = client.initialize(); - assertThat(initResult).isNotNull(); - - var messages = List - .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))); - var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); + assertThat(client.initialize()).isNotNull(); - var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, - McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); - - StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) .hasMessage("Client must be configured with sampling capabilities"); - }); + } } @Test void testCreateMessageSuccess() throws InterruptedException { - var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); + // Client Function samplingHandler = request -> { assertThat(request.messages()).hasSize(1); @@ -198,29 +186,54 @@ void testCreateMessageSuccess() throws InterruptedException { CreateMessageResult.StopReason.STOP_SEQUENCE); }; - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) .capabilities(ClientCapabilities.builder().sampling().build()) .sampling(samplingHandler) .build(); - InitializeResult initResult = client.initialize(); + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var messages = List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message"))); + var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); + + var craeteMessageRequest = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, + McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), + Map.of()); + + StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); assertThat(initResult).isNotNull(); - var messages = List - .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))); - var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); - - var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, - McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); - - StepVerifier.create(mcpAsyncServer.createMessage(request)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }).verifyComplete(); + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.close(); + mcpServer.close(); } // --------------------------------------- @@ -231,8 +244,8 @@ void testRootsSuccess() { List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); AtomicReference> rootsRef = new AtomicReference<>(); - var mcpServer = McpServer.sync(mcpServerTransport) - .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate)) + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) .build(); var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) @@ -244,8 +257,6 @@ void testRootsSuccess() { assertThat(rootsRef.get()).isNull(); - assertThat(mcpServer.listRoots().roots()).containsAll(roots); - mcpClient.rootsListChangedNotification(); await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { @@ -273,29 +284,42 @@ void testRootsSuccess() { @Test void testRootsWithoutCapability() { - var mcpServer = McpServer.sync(mcpServerTransport).rootsChangeConsumer(rootsUpdate -> { - }).build(); + + McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + exchange.listRoots(); // try to list roots + + return mock(CallToolResult.class); + }); + + var mcpServer = McpServer.sync(mcpServerTransportProvider).rootsChangeHandler((exchange, rootsUpdate) -> { + }).tools(tool).build(); // Create client without roots capability // No roots capability var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build(); - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + assertThat(mcpClient.initialize()).isNotNull(); // Attempt to list roots should fail - assertThatThrownBy(() -> mcpServer.listRoots().roots()).isInstanceOf(McpError.class) - .hasMessage("Roots not supported"); + try { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); + } mcpClient.close(); mcpServer.close(); } @Test - void testRootsWithEmptyRootsList() { + void testRootsNotifciationWithEmptyRootsList() { AtomicReference> rootsRef = new AtomicReference<>(); - var mcpServer = McpServer.sync(mcpServerTransport) - .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate)) + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) .build(); var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) @@ -316,23 +340,22 @@ void testRootsWithEmptyRootsList() { } @Test - void testRootsWithMultipleConsumers() { + void testRootsWithMultipleHandlers() { List roots = List.of(new Root("uri1://", "root1")); AtomicReference> rootsRef1 = new AtomicReference<>(); AtomicReference> rootsRef2 = new AtomicReference<>(); - var mcpServer = McpServer.sync(mcpServerTransport) - .rootsChangeConsumer(rootsUpdate -> rootsRef1.set(rootsUpdate)) - .rootsChangeConsumer(rootsUpdate -> rootsRef2.set(rootsUpdate)) + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef1.set(rootsUpdate)) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef2.set(rootsUpdate)) .build(); var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) .roots(roots) .build(); - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + assertThat(mcpClient.initialize()).isNotNull(); mcpClient.rootsListChangedNotification(); @@ -350,8 +373,8 @@ void testRootsServerCloseWithActiveSubscription() { List roots = List.of(new Root("uri1://", "root1")); AtomicReference> rootsRef = new AtomicReference<>(); - var mcpServer = McpServer.sync(mcpServerTransport) - .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate)) + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) .build(); var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) @@ -390,8 +413,8 @@ void testRootsServerCloseWithActiveSubscription() { void testToolCallSuccess() { var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - McpServerFeatures.SyncToolRegistration tool1 = new McpServerFeatures.SyncToolRegistration( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), request -> { + McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { // perform a blocking call to a remote service String response = RestClient.create() .get() @@ -402,7 +425,7 @@ void testToolCallSuccess() { return callResponse; }); - var mcpServer = McpServer.sync(mcpServerTransport) + var mcpServer = McpServer.sync(mcpServerTransportProvider) .capabilities(ServerCapabilities.builder().tools(true).build()) .tools(tool1) .build(); @@ -427,8 +450,8 @@ void testToolCallSuccess() { void testToolListChangeHandlingSuccess() { var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - McpServerFeatures.SyncToolRegistration tool1 = new McpServerFeatures.SyncToolRegistration( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), request -> { + McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { // perform a blocking call to a remote service String response = RestClient.create() .get() @@ -439,7 +462,7 @@ void testToolListChangeHandlingSuccess() { return callResponse; }); - var mcpServer = McpServer.sync(mcpServerTransport) + var mcpServer = McpServer.sync(mcpServerTransportProvider) .capabilities(ServerCapabilities.builder().tools(true).build()) .tools(tool1) .build(); @@ -477,8 +500,8 @@ void testToolListChangeHandlingSuccess() { }); // Add a new tool - McpServerFeatures.SyncToolRegistration tool2 = new McpServerFeatures.SyncToolRegistration( - new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), request -> callResponse); + McpServerFeatures.SyncToolSpecification tool2 = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), (exchange, request) -> callResponse); mcpServer.addTool(tool2); @@ -493,7 +516,7 @@ void testToolListChangeHandlingSuccess() { @Test void testInitialize() { - var mcpServer = McpServer.sync(mcpServerTransport).build(); + var mcpServer = McpServer.sync(mcpServerTransportProvider).build(); var mcpClient = clientBuilder.build(); diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportDeprecatedTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportDeprecatedTests.java new file mode 100644 index 00000000..8656665e --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportDeprecatedTests.java @@ -0,0 +1,118 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.transport.WebMvcSseServerTransport; +import io.modelcontextprotocol.spec.ServerMcpTransport; +import org.apache.catalina.Context; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.Timeout; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; +import org.springframework.web.servlet.DispatcherServlet; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.ServerResponse; + +@Deprecated +@Timeout(15) +class WebMvcSseSyncServerTransportDeprecatedTests extends AbstractMcpSyncServerDeprecatedTests { + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private static final int PORT = 8181; + + private Tomcat tomcat; + + private WebMvcSseServerTransport transport; + + @Configuration + @EnableWebMvc + static class TestConfig { + + @Bean + public WebMvcSseServerTransport webMvcSseServerTransport() { + return new WebMvcSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); + } + + @Bean + public RouterFunction routerFunction(WebMvcSseServerTransport transport) { + return transport.getRouterFunction(); + } + + } + + private AnnotationConfigWebApplicationContext appContext; + + @Override + protected ServerMcpTransport createMcpTransport() { + // Set up Tomcat first + tomcat = new Tomcat(); + tomcat.setPort(PORT); + + // Set Tomcat base directory to java.io.tmpdir to avoid permission issues + String baseDir = System.getProperty("java.io.tmpdir"); + tomcat.setBaseDir(baseDir); + + // Use the same directory for document base + Context context = tomcat.addContext("", baseDir); + + // Create and configure Spring WebMvc context + appContext = new AnnotationConfigWebApplicationContext(); + appContext.register(TestConfig.class); + appContext.setServletContext(context.getServletContext()); + appContext.refresh(); + + // Get the transport from Spring context + transport = appContext.getBean(WebMvcSseServerTransport.class); + + // Create DispatcherServlet with our Spring context + DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext); + // dispatcherServlet.setThrowExceptionIfNoHandlerFound(true); + + // Add servlet to Tomcat and get the wrapper + var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet); + wrapper.setLoadOnStartup(1); + context.addServletMappingDecoded("/*", "dispatcherServlet"); + + try { + tomcat.start(); + tomcat.getConnector(); // Create and start the connector + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + return transport; + } + + @Override + protected void onStart() { + } + + @Override + protected void onClose() { + if (transport != null) { + transport.closeGracefully().block(); + } + if (appContext != null) { + appContext.close(); + } + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java index be843a87..b85bed37 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java @@ -5,8 +5,7 @@ package io.modelcontextprotocol.server; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.server.transport.WebMvcSseServerTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider; import org.apache.catalina.Context; import org.apache.catalina.LifecycleException; import org.apache.catalina.startup.Tomcat; @@ -21,7 +20,7 @@ import org.springframework.web.servlet.function.ServerResponse; @Timeout(15) -class WebMvcSseSyncServerTransportTests extends AbstractMcpSyncServerDeprecatedTests { +class WebMvcSseSyncServerTransportTests extends AbstractMcpSyncServerTests { private static final String MESSAGE_ENDPOINT = "/mcp/message"; @@ -29,20 +28,20 @@ class WebMvcSseSyncServerTransportTests extends AbstractMcpSyncServerDeprecatedT private Tomcat tomcat; - private WebMvcSseServerTransport transport; + private WebMvcSseServerTransportProvider transportProvider; @Configuration @EnableWebMvc static class TestConfig { @Bean - public WebMvcSseServerTransport webMvcSseServerTransport() { - return new WebMvcSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); + public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() { + return new WebMvcSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT); } @Bean - public RouterFunction routerFunction(WebMvcSseServerTransport transport) { - return transport.getRouterFunction(); + public RouterFunction routerFunction(WebMvcSseServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); } } @@ -50,7 +49,7 @@ public RouterFunction routerFunction(WebMvcSseServerTransport tr private AnnotationConfigWebApplicationContext appContext; @Override - protected ServerMcpTransport createMcpTransport() { + protected WebMvcSseServerTransportProvider createMcpTransportProvider() { // Set up Tomcat first tomcat = new Tomcat(); tomcat.setPort(PORT); @@ -69,7 +68,7 @@ protected ServerMcpTransport createMcpTransport() { appContext.refresh(); // Get the transport from Spring context - transport = appContext.getBean(WebMvcSseServerTransport.class); + transportProvider = appContext.getBean(WebMvcSseServerTransportProvider.class); // Create DispatcherServlet with our Spring context DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext); @@ -88,7 +87,7 @@ protected ServerMcpTransport createMcpTransport() { throw new RuntimeException("Failed to start Tomcat", e); } - return transport; + return transportProvider; } @Override @@ -97,8 +96,8 @@ protected void onStart() { @Override protected void onClose() { - if (transport != null) { - transport.closeGracefully().block(); + if (transportProvider != null) { + transportProvider.closeGracefully().block(); } if (appContext != null) { appContext.close(); diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index 033139ad..ed29cf06 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -12,7 +12,7 @@ import java.util.function.Consumer; import java.util.function.Function; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; @@ -49,7 +49,7 @@ public abstract class AbstractMcpAsyncClientTests { private static final String ECHO_TEST_MESSAGE = "Hello MCP Spring AI!"; - abstract protected ClientMcpTransport createMcpTransport(); + abstract protected McpClientTransport createMcpTransport(); protected void onStart() { } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java index 032f8684..3c17c45e 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -11,7 +11,7 @@ import java.util.function.Consumer; import java.util.function.Function; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java index 35da5197..088990ca 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -7,6 +7,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.client.transport.FlowSseClient.SseEvent; import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; @@ -52,9 +53,9 @@ * * @author Christian Tzolov * @see io.modelcontextprotocol.spec.McpTransport - * @see io.modelcontextprotocol.spec.ClientMcpTransport + * @see io.modelcontextprotocol.spec.McpClientTransport */ -public class HttpClientSseClientTransport implements ClientMcpTransport { +public class HttpClientSseClientTransport implements McpClientTransport { private static final Logger logger = LoggerFactory.getLogger(HttpClientSseClientTransport.class); diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java index d35db3f8..8fdc0479 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java @@ -18,7 +18,7 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; import io.modelcontextprotocol.util.Assert; @@ -38,7 +38,7 @@ * @author Christian Tzolov * @author Dariusz Jędrzejczyk */ -public class StdioClientTransport implements ClientMcpTransport { +public class StdioClientTransport implements McpClientTransport { private static final Logger logger = LoggerFactory.getLogger(StdioClientTransport.class); diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransport.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransport.java index e375cd10..14129c52 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransport.java @@ -33,7 +33,10 @@ * over stdin/stdout, with errors and debug information sent to stderr. * * @author Christian Tzolov + * @deprecated Use + * {@link io.modelcontextprotocol.server.transport.StdioServerTransportProvider} instead. */ +@Deprecated public class StdioServerTransport implements ServerMcpTransport { private static final Logger logger = LoggerFactory.getLogger(StdioServerTransport.class); diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/ClientMcpTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/ClientMcpTransport.java index 702f01d6..8464b6ae 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/ClientMcpTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/ClientMcpTransport.java @@ -3,10 +3,6 @@ */ package io.modelcontextprotocol.spec; -import java.util.function.Function; - -import reactor.core.publisher.Mono; - /** * Marker interface for the client-side MCP transport. * diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java index fa90e96f..63aa1dbf 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java @@ -1,10 +1,13 @@ +/* +* Copyright 2024 - 2024 the original author or authors. +*/ package io.modelcontextprotocol.spec; import java.util.function.Function; import reactor.core.publisher.Mono; -public interface McpClientTransport extends McpTransport { +public interface McpClientTransport extends ClientMcpTransport { @Override Mono connect(Function, Mono> handler); diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index 72038854..f5c90c16 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -13,7 +13,7 @@ import java.util.function.Function; import java.util.function.Supplier; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; @@ -51,7 +51,7 @@ public abstract class AbstractMcpAsyncClientTests { private static final String ECHO_TEST_MESSAGE = "Hello MCP Spring AI!"; - abstract protected ClientMcpTransport createMcpTransport(); + abstract protected McpClientTransport createMcpTransport(); protected void onStart() { } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java index 1c042bf2..43600db7 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -11,7 +11,7 @@ import java.util.function.Consumer; import java.util.function.Function; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java index ac0fef24..c2201533 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java @@ -7,7 +7,7 @@ import java.time.Duration; import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import org.junit.jupiter.api.Timeout; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; @@ -30,7 +30,7 @@ class HttpSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests { .waitingFor(Wait.forHttp("/").forStatusCode(404)); @Override - protected ClientMcpTransport createMcpTransport() { + protected McpClientTransport createMcpTransport() { return new HttpClientSseClientTransport(host); } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java index 8772e620..8b638fba 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java @@ -7,7 +7,7 @@ import java.time.Duration; import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import org.junit.jupiter.api.Timeout; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; @@ -30,7 +30,7 @@ class HttpSseMcpSyncClientTests extends AbstractMcpSyncClientTests { .waitingFor(Wait.forHttp("/").forStatusCode(404)); @Override - protected ClientMcpTransport createMcpTransport() { + protected McpClientTransport createMcpTransport() { return new HttpClientSseClientTransport(host); } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java index c285e2c6..95230942 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java @@ -8,7 +8,7 @@ import io.modelcontextprotocol.client.transport.ServerParameters; import io.modelcontextprotocol.client.transport.StdioClientTransport; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import org.junit.jupiter.api.Timeout; /** @@ -21,7 +21,7 @@ class StdioMcpAsyncClientTests extends AbstractMcpAsyncClientTests { @Override - protected ClientMcpTransport createMcpTransport() { + protected McpClientTransport createMcpTransport() { ServerParameters stdioParams = ServerParameters.builder("npx") .args("-y", "@modelcontextprotocol/server-everything", "dir") .build(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java index ebf10b9a..8f7ec15b 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java @@ -11,7 +11,7 @@ import io.modelcontextprotocol.client.transport.ServerParameters; import io.modelcontextprotocol.client.transport.StdioClientTransport; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; import reactor.core.publisher.Sinks; @@ -29,7 +29,7 @@ class StdioMcpSyncClientTests extends AbstractMcpSyncClientTests { @Override - protected ClientMcpTransport createMcpTransport() { + protected McpClientTransport createMcpTransport() { ServerParameters stdioParams = ServerParameters.builder("npx") .args("-y", "@modelcontextprotocol/server-everything", "dir") .build(); From 892f12f8de29b768aacd4d5443fd7442d6562f64 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Tue, 18 Mar 2025 08:54:53 +0100 Subject: [PATCH 13/20] refactor(mcp): replace HttpServletSseServerTransport with provider-based implementation Introduces HttpServletSseServerTransportProvider as a replacement for HttpServletSseServerTransport, following the provider pattern used by other transport implementations. The new implementation offers the same functionality but with a more consistent API aligned with the McpServerTransportProvider interface. - Mark HttpServletSseServerTransport as @Deprecated (to be removed in 0.9.0) - Add new HttpServletSseServerTransportProvider implementation - Update test classes to use the new provider-based implementation - Add separate test classes for deprecated implementation Signed-off-by: Christian Tzolov --- .../HttpServletSseServerTransport.java | 5 +- ...HttpServletSseServerTransportProvider.java | 432 +++++++++++++++ ...rvletSseMcpAsyncServerDeprecatedTests.java | 26 + .../server/ServletSseMcpAsyncServerTests.java | 12 +- ...ervletSseMcpSyncServerDeprecatedTests.java | 26 + .../server/ServletSseMcpSyncServerTests.java | 12 +- ...rverTransportProviderIntegrationTests.java | 493 ++++++++++++++++++ 7 files changed, 993 insertions(+), 13 deletions(-) create mode 100644 mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerDeprecatedTests.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerDeprecatedTests.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransport.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransport.java index 98b8ea58..fa5dcf1c 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransport.java @@ -32,6 +32,9 @@ * specification. This implementation provides similar functionality to * WebFluxSseServerTransport but uses the traditional Servlet API instead of WebFlux. * + * @deprecated This class will be removed in 0.9.0. Use + * {@link HttpServletSseServerTransportProvider}. + * *

      * The transport handles two types of endpoints: *

        @@ -48,7 +51,6 @@ *
      • Graceful shutdown support
      • *
      • Error handling and response formatting
      • *
      - * * @author Christian Tzolov * @author Alexandros Pappas * @see ServerMcpTransport @@ -56,6 +58,7 @@ */ @WebServlet(asyncSupported = true) +@Deprecated public class HttpServletSseServerTransport extends HttpServlet implements ServerMcpTransport { /** Logger for this class */ diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java new file mode 100644 index 00000000..152462b1 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java @@ -0,0 +1,432 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ +package io.modelcontextprotocol.server.transport; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.PrintWriter; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpServerTransport; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import jakarta.servlet.AsyncContext; +import jakarta.servlet.ServletException; +import jakarta.servlet.annotation.WebServlet; +import jakarta.servlet.http.HttpServlet; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** + * A Servlet-based implementation of the MCP HTTP with Server-Sent Events (SSE) transport + * specification. This implementation provides similar functionality to + * WebFluxSseServerTransportProvider but uses the traditional Servlet API instead of + * WebFlux. + * + *

      + * The transport handles two types of endpoints: + *

        + *
      • SSE endpoint (/sse) - Establishes a long-lived connection for server-to-client + * events
      • + *
      • Message endpoint (configurable) - Handles client-to-server message requests
      • + *
      + * + *

      + * Features: + *

        + *
      • Asynchronous message handling using Servlet 6.0 async support
      • + *
      • Session management for multiple client connections
      • + *
      • Graceful shutdown support
      • + *
      • Error handling and response formatting
      • + *
      + * + * @author Christian Tzolov + * @author Alexandros Pappas + * @see McpServerTransportProvider + * @see HttpServlet + */ + +@WebServlet(asyncSupported = true) +public class HttpServletSseServerTransportProvider extends HttpServlet implements McpServerTransportProvider { + + /** Logger for this class */ + private static final Logger logger = LoggerFactory.getLogger(HttpServletSseServerTransportProvider.class); + + public static final String UTF_8 = "UTF-8"; + + public static final String APPLICATION_JSON = "application/json"; + + public static final String FAILED_TO_SEND_ERROR_RESPONSE = "Failed to send error response: {}"; + + /** Default endpoint path for SSE connections */ + public static final String DEFAULT_SSE_ENDPOINT = "/sse"; + + /** Event type for regular messages */ + public static final String MESSAGE_EVENT_TYPE = "message"; + + /** Event type for endpoint information */ + public static final String ENDPOINT_EVENT_TYPE = "endpoint"; + + /** JSON object mapper for serialization/deserialization */ + private final ObjectMapper objectMapper; + + /** The endpoint path for handling client messages */ + private final String messageEndpoint; + + /** The endpoint path for handling SSE connections */ + private final String sseEndpoint; + + /** Map of active client sessions, keyed by session ID */ + private final Map sessions = new ConcurrentHashMap<>(); + + /** Flag indicating if the transport is in the process of shutting down */ + private final AtomicBoolean isClosing = new AtomicBoolean(false); + + /** Session factory for creating new sessions */ + private McpServerSession.Factory sessionFactory; + + /** + * Creates a new HttpServletSseServerTransportProvider instance with a custom SSE + * endpoint. + * @param objectMapper The JSON object mapper to use for message + * serialization/deserialization + * @param messageEndpoint The endpoint path where clients will send their messages + * @param sseEndpoint The endpoint path where clients will establish SSE connections + */ + public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, + String sseEndpoint) { + this.objectMapper = objectMapper; + this.messageEndpoint = messageEndpoint; + this.sseEndpoint = sseEndpoint; + } + + /** + * Creates a new HttpServletSseServerTransportProvider instance with the default SSE + * endpoint. + * @param objectMapper The JSON object mapper to use for message + * serialization/deserialization + * @param messageEndpoint The endpoint path where clients will send their messages + */ + public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) { + this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); + } + + /** + * Sets the session factory for creating new sessions. + * @param sessionFactory The session factory to use + */ + @Override + public void setSessionFactory(McpServerSession.Factory sessionFactory) { + this.sessionFactory = sessionFactory; + } + + /** + * Broadcasts a notification to all connected clients. + * @param method The method name for the notification + * @param params The parameters for the notification + * @return A Mono that completes when the broadcast attempt is finished + */ + @Override + public Mono notifyClients(String method, Map params) { + if (sessions.isEmpty()) { + logger.debug("No active sessions to broadcast message to"); + return Mono.empty(); + } + + logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); + + return Flux.fromIterable(sessions.values()) + .flatMap(session -> session.sendNotification(method, params) + .doOnError( + e -> logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage())) + .onErrorComplete()) + .then(); + } + + /** + * Handles GET requests to establish SSE connections. + *

      + * This method sets up a new SSE connection when a client connects to the SSE + * endpoint. It configures the response headers for SSE, creates a new session, and + * sends the initial endpoint information to the client. + * @param request The HTTP servlet request + * @param response The HTTP servlet response + * @throws ServletException If a servlet-specific error occurs + * @throws IOException If an I/O error occurs + */ + @Override + protected void doGet(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + String pathInfo = request.getPathInfo(); + if (!sseEndpoint.equals(pathInfo)) { + response.sendError(HttpServletResponse.SC_NOT_FOUND); + return; + } + + if (isClosing.get()) { + response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down"); + return; + } + + response.setContentType("text/event-stream"); + response.setCharacterEncoding(UTF_8); + response.setHeader("Cache-Control", "no-cache"); + response.setHeader("Connection", "keep-alive"); + response.setHeader("Access-Control-Allow-Origin", "*"); + + String sessionId = UUID.randomUUID().toString(); + AsyncContext asyncContext = request.startAsync(); + asyncContext.setTimeout(0); + + PrintWriter writer = response.getWriter(); + + // Create a new session transport + HttpServletMcpSessionTransport sessionTransport = new HttpServletMcpSessionTransport(sessionId, asyncContext, + writer); + + // Create a new session using the session factory + McpServerSession session = sessionFactory.create(sessionTransport); + this.sessions.put(sessionId, session); + + // Send initial endpoint event + this.sendEvent(writer, ENDPOINT_EVENT_TYPE, messageEndpoint + "?sessionId=" + sessionId); + } + + /** + * Handles POST requests for client messages. + *

      + * This method processes incoming messages from clients, routes them through the + * session handler, and sends back the appropriate response. It handles error cases + * and formats error responses according to the MCP specification. + * @param request The HTTP servlet request + * @param response The HTTP servlet response + * @throws ServletException If a servlet-specific error occurs + * @throws IOException If an I/O error occurs + */ + @Override + protected void doPost(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + if (isClosing.get()) { + response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down"); + return; + } + + String pathInfo = request.getPathInfo(); + if (!messageEndpoint.equals(pathInfo)) { + response.sendError(HttpServletResponse.SC_NOT_FOUND); + return; + } + + // Get the session ID from the request parameter + String sessionId = request.getParameter("sessionId"); + if (sessionId == null) { + response.setContentType(APPLICATION_JSON); + response.setCharacterEncoding(UTF_8); + response.setStatus(HttpServletResponse.SC_BAD_REQUEST); + String jsonError = objectMapper.writeValueAsString(new McpError("Session ID missing in message endpoint")); + PrintWriter writer = response.getWriter(); + writer.write(jsonError); + writer.flush(); + return; + } + + // Get the session from the sessions map + McpServerSession session = sessions.get(sessionId); + if (session == null) { + response.setContentType(APPLICATION_JSON); + response.setCharacterEncoding(UTF_8); + response.setStatus(HttpServletResponse.SC_NOT_FOUND); + String jsonError = objectMapper.writeValueAsString(new McpError("Session not found: " + sessionId)); + PrintWriter writer = response.getWriter(); + writer.write(jsonError); + writer.flush(); + return; + } + + try { + BufferedReader reader = request.getReader(); + StringBuilder body = new StringBuilder(); + String line; + while ((line = reader.readLine()) != null) { + body.append(line); + } + + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body.toString()); + + // Process the message through the session's handle method + session.handle(message).block(); // Block for Servlet compatibility + + response.setStatus(HttpServletResponse.SC_OK); + } + catch (Exception e) { + logger.error("Error processing message: {}", e.getMessage()); + try { + McpError mcpError = new McpError(e.getMessage()); + response.setContentType(APPLICATION_JSON); + response.setCharacterEncoding(UTF_8); + response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR); + String jsonError = objectMapper.writeValueAsString(mcpError); + PrintWriter writer = response.getWriter(); + writer.write(jsonError); + writer.flush(); + } + catch (IOException ex) { + logger.error(FAILED_TO_SEND_ERROR_RESPONSE, ex.getMessage()); + response.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "Error processing message"); + } + } + } + + /** + * Initiates a graceful shutdown of the transport. + *

      + * This method marks the transport as closing and closes all active client sessions. + * New connection attempts will be rejected during shutdown. + * @return A Mono that completes when all sessions have been closed + */ + @Override + public Mono closeGracefully() { + isClosing.set(true); + logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size()); + + return Flux.fromIterable(sessions.values()).flatMap(McpServerSession::closeGracefully).then(); + } + + /** + * Sends an SSE event to a client. + * @param writer The writer to send the event through + * @param eventType The type of event (message or endpoint) + * @param data The event data + * @throws IOException If an error occurs while writing the event + */ + private void sendEvent(PrintWriter writer, String eventType, String data) throws IOException { + writer.write("event: " + eventType + "\n"); + writer.write("data: " + data + "\n\n"); + writer.flush(); + + if (writer.checkError()) { + throw new IOException("Client disconnected"); + } + } + + /** + * Cleans up resources when the servlet is being destroyed. + *

      + * This method ensures a graceful shutdown by closing all client connections before + * calling the parent's destroy method. + */ + @Override + public void destroy() { + closeGracefully().block(); + super.destroy(); + } + + /** + * Implementation of McpServerTransport for HttpServlet SSE sessions. This class + * handles the transport-level communication for a specific client session. + */ + private class HttpServletMcpSessionTransport implements McpServerTransport { + + private final String sessionId; + + private final AsyncContext asyncContext; + + private final PrintWriter writer; + + /** + * Creates a new session transport with the specified ID and SSE writer. + * @param sessionId The unique identifier for this session + * @param asyncContext The async context for the session + * @param writer The writer for sending server events to the client + */ + HttpServletMcpSessionTransport(String sessionId, AsyncContext asyncContext, PrintWriter writer) { + this.sessionId = sessionId; + this.asyncContext = asyncContext; + this.writer = writer; + logger.debug("Session transport {} initialized with SSE writer", sessionId); + } + + /** + * Sends a JSON-RPC message to the client through the SSE connection. + * @param message The JSON-RPC message to send + * @return A Mono that completes when the message has been sent + */ + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + return Mono.fromRunnable(() -> { + try { + String jsonText = objectMapper.writeValueAsString(message); + sendEvent(writer, MESSAGE_EVENT_TYPE, jsonText); + logger.debug("Message sent to session {}", sessionId); + } + catch (Exception e) { + logger.error("Failed to send message to session {}: {}", sessionId, e.getMessage()); + sessions.remove(sessionId); + asyncContext.complete(); + } + }); + } + + /** + * Converts data from one type to another using the configured ObjectMapper. + * @param data The source data object to convert + * @param typeRef The target type reference + * @return The converted object of type T + * @param The target type + */ + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return objectMapper.convertValue(data, typeRef); + } + + /** + * Initiates a graceful shutdown of the transport. + * @return A Mono that completes when the shutdown is complete + */ + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + logger.debug("Closing session transport: {}", sessionId); + try { + sessions.remove(sessionId); + asyncContext.complete(); + logger.debug("Successfully completed async context for session {}", sessionId); + } + catch (Exception e) { + logger.warn("Failed to complete async context for session {}: {}", sessionId, e.getMessage()); + } + }); + } + + /** + * Closes the transport immediately. + */ + @Override + public void close() { + try { + sessions.remove(sessionId); + asyncContext.complete(); + logger.debug("Successfully completed async context for session {}", sessionId); + } + catch (Exception e) { + logger.warn("Failed to complete async context for session {}: {}", sessionId, e.getMessage()); + } + } + + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerDeprecatedTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerDeprecatedTests.java new file mode 100644 index 00000000..2c80d45c --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerDeprecatedTests.java @@ -0,0 +1,26 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.transport.HttpServletSseServerTransport; +import io.modelcontextprotocol.spec.ServerMcpTransport; +import org.junit.jupiter.api.Timeout; + +/** + * Tests for {@link McpAsyncServer} using {@link HttpServletSseServerTransport}. + * + * @author Christian Tzolov + */ +@Deprecated +@Timeout(15) // Giving extra time beyond the client timeout +class ServletSseMcpAsyncServerDeprecatedTests extends AbstractMcpAsyncServerDeprecatedTests { + + @Override + protected ServerMcpTransport createMcpTransport() { + return new HttpServletSseServerTransport(new ObjectMapper(), "/mcp/message"); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java index 7f7357a6..9de186b4 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java @@ -5,21 +5,21 @@ package io.modelcontextprotocol.server; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.server.transport.HttpServletSseServerTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.server.transport.HttpServletSseServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.Timeout; /** - * Tests for {@link McpAsyncServer} using {@link HttpServletSseServerTransport}. + * Tests for {@link McpAsyncServer} using {@link HttpServletSseServerTransportProvider}. * * @author Christian Tzolov */ @Timeout(15) // Giving extra time beyond the client timeout -class ServletSseMcpAsyncServerTests extends AbstractMcpAsyncServerDeprecatedTests { +class ServletSseMcpAsyncServerTests extends AbstractMcpAsyncServerTests { @Override - protected ServerMcpTransport createMcpTransport() { - return new HttpServletSseServerTransport(new ObjectMapper(), "/mcp/message"); + protected McpServerTransportProvider createMcpTransportProvider() { + return new HttpServletSseServerTransportProvider(new ObjectMapper(), "/mcp/message"); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerDeprecatedTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerDeprecatedTests.java new file mode 100644 index 00000000..8cdd08c5 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerDeprecatedTests.java @@ -0,0 +1,26 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.transport.HttpServletSseServerTransport; +import io.modelcontextprotocol.spec.ServerMcpTransport; +import org.junit.jupiter.api.Timeout; + +/** + * Tests for {@link McpSyncServer} using {@link HttpServletSseServerTransport}. + * + * @author Christian Tzolov + */ +@Deprecated +@Timeout(15) // Giving extra time beyond the client timeout +class ServletSseMcpSyncServerDeprecatedTests extends AbstractMcpSyncServerDeprecatedTests { + + @Override + protected ServerMcpTransport createMcpTransport() { + return new HttpServletSseServerTransport(new ObjectMapper(), "/mcp/message"); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java index 4507256c..60dc53a4 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java @@ -5,21 +5,21 @@ package io.modelcontextprotocol.server; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.server.transport.HttpServletSseServerTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.server.transport.HttpServletSseServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.Timeout; /** - * Tests for {@link McpSyncServer} using {@link HttpServletSseServerTransport}. + * Tests for {@link McpSyncServer} using {@link HttpServletSseServerTransportProvider}. * * @author Christian Tzolov */ @Timeout(15) // Giving extra time beyond the client timeout -class ServletSseMcpSyncServerTests extends AbstractMcpSyncServerDeprecatedTests { +class ServletSseMcpSyncServerTests extends AbstractMcpSyncServerTests { @Override - protected ServerMcpTransport createMcpTransport() { - return new HttpServletSseServerTransport(new ObjectMapper(), "/mcp/message"); + protected McpServerTransportProvider createMcpTransportProvider() { + return new HttpServletSseServerTransportProvider(new ObjectMapper(), "/mcp/message"); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java new file mode 100644 index 00000000..290141bb --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java @@ -0,0 +1,493 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ +package io.modelcontextprotocol.server.transport; + +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.Role; +import io.modelcontextprotocol.spec.McpSchema.Root; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import org.apache.catalina.Context; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import org.springframework.web.client.RestClient; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.awaitility.Awaitility.await; +import static org.mockito.Mockito.mock; + +public class HttpServletSseServerTransportProviderIntegrationTests { + + private static final int PORT = 8185; + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private HttpServletSseServerTransportProvider mcpServerTransportProvider; + + McpClient.SyncSpec clientBuilder; + + private Tomcat tomcat; + + @BeforeEach + public void before() { + tomcat = new Tomcat(); + tomcat.setPort(PORT); + + String baseDir = System.getProperty("java.io.tmpdir"); + tomcat.setBaseDir(baseDir); + + Context context = tomcat.addContext("", baseDir); + + // Create and configure the transport provider + mcpServerTransportProvider = new HttpServletSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT); + + // Add transport servlet to Tomcat + org.apache.catalina.Wrapper wrapper = context.createWrapper(); + wrapper.setName("mcpServlet"); + wrapper.setServlet(mcpServerTransportProvider); + wrapper.setLoadOnStartup(1); + wrapper.setAsyncSupported(true); + context.addChild(wrapper); + context.addServletMappingDecoded("/*", "mcpServlet"); + + try { + var connector = tomcat.getConnector(); + connector.setAsyncTimeout(3000); + tomcat.start(); + assertThat(tomcat.getServer().getState() == LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + this.clientBuilder = McpClient.sync(new HttpClientSseClientTransport("http://localhost:" + PORT)); + } + + @AfterEach + public void after() { + if (mcpServerTransportProvider != null) { + mcpServerTransportProvider.closeGracefully().block(); + } + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + // --------------------------------------- + // Sampling Tests + // --------------------------------------- + @Test + @Disabled + void testCreateMessageWithoutSamplingCapabilities() { + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + exchange.createMessage(mock(McpSchema.CreateMessageRequest.class)).block(); + + return Mono.just(mock(CallToolResult.class)); + }); + + McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); + + // Create client without sampling capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")).build(); + + assertThat(client.initialize()).isNotNull(); + + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with sampling capabilities"); + } + } + + @Test + void testCreateMessageSuccess() throws InterruptedException { + + // Client + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var messages = List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message"))); + var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); + + var craeteMessageRequest = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, + McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), + Map.of()); + + StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.close(); + mcpServer.close(); + } + + // --------------------------------------- + // Roots Tests + // --------------------------------------- + @Test + void testRootsSuccess() { + List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); + + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) + .build(); + + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(rootsRef.get()).isNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + + // Remove a root + mcpClient.removeRoot(roots.get(0).uri()); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); + }); + + // Add a new root + var root3 = new Root("uri3://", "root3"); + mcpClient.addRoot(root3); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); + }); + + mcpClient.close(); + mcpServer.close(); + } + + @Test + void testRootsWithoutCapability() { + + McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + exchange.listRoots(); // try to list roots + + return mock(CallToolResult.class); + }); + + var mcpServer = McpServer.sync(mcpServerTransportProvider).rootsChangeHandler((exchange, rootsUpdate) -> { + }).tools(tool).build(); + + // Create client without roots capability + // No roots capability + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build(); + + assertThat(mcpClient.initialize()).isNotNull(); + + // Attempt to list roots should fail + try { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); + } + + mcpClient.close(); + mcpServer.close(); + } + + @Test + void testRootsNotifciationWithEmptyRootsList() { + AtomicReference> rootsRef = new AtomicReference<>(); + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) + .build(); + + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(List.of()) // Empty roots list + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); + + mcpClient.close(); + mcpServer.close(); + } + + @Test + void testRootsWithMultipleHandlers() { + List roots = List.of(new Root("uri1://", "root1")); + + AtomicReference> rootsRef1 = new AtomicReference<>(); + AtomicReference> rootsRef2 = new AtomicReference<>(); + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef1.set(rootsUpdate)) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef2.set(rootsUpdate)) + .build(); + + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build(); + + assertThat(mcpClient.initialize()).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef1.get()).containsAll(roots); + assertThat(rootsRef2.get()).containsAll(roots); + }); + + mcpClient.close(); + mcpServer.close(); + } + + @Test + void testRootsServerCloseWithActiveSubscription() { + List roots = List.of(new Root("uri1://", "root1")); + + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) + .build(); + + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + + // Close server while subscription is active + mcpServer.close(); + + // Verify client can handle server closure gracefully + mcpClient.close(); + } + + // --------------------------------------- + // Tools Tests + // --------------------------------------- + + String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + + @Test + void testToolCallSuccess() { + + var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + // perform a blocking call to a remote service + String response = RestClient.create() + .get() + .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .retrieve() + .body(String.class); + assertThat(response).isNotBlank(); + return callResponse; + }); + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + var mcpClient = clientBuilder.build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.close(); + mcpServer.close(); + } + + @Test + void testToolListChangeHandlingSuccess() { + + var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + // perform a blocking call to a remote service + String response = RestClient.create() + .get() + .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .retrieve() + .body(String.class); + assertThat(response).isNotBlank(); + return callResponse; + }); + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { + // perform a blocking call to a remote service + String response = RestClient.create() + .get() + .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .retrieve() + .body(String.class); + assertThat(response).isNotBlank(); + rootsRef.set(toolsUpdate); + }).build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(rootsRef.get()).isNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + mcpServer.notifyToolsListChanged(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(tool1.tool())); + }); + + // Remove a tool + mcpServer.removeTool("tool1"); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); + + // Add a new tool + McpServerFeatures.SyncToolSpecification tool2 = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), (exchange, request) -> callResponse); + + mcpServer.addTool(tool2); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(tool2.tool())); + }); + + mcpClient.close(); + mcpServer.close(); + } + + @Test + void testInitialize() { + + var mcpServer = McpServer.sync(mcpServerTransportProvider).build(); + + var mcpClient = clientBuilder.build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.close(); + mcpServer.close(); + } + +} From cbf2b54fbcaf5ceca5abd527490684b7f0772ab0 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Tue, 18 Mar 2025 11:01:35 +0100 Subject: [PATCH 14/20] Revert the McpClientTransport breaking change back to ClientMcpTransport Signed-off-by: Christian Tzolov --- .../transport/WebFluxSseClientTransport.java | 4 +-- .../client/WebFluxSseMcpAsyncClientTests.java | 4 +-- .../client/WebFluxSseMcpSyncClientTests.java | 4 +-- .../client/AbstractMcpAsyncClientTests.java | 4 +-- .../client/AbstractMcpSyncClientTests.java | 2 +- .../HttpClientSseClientTransport.java | 29 +++++++++---------- .../transport/StdioClientTransport.java | 4 +-- .../client/AbstractMcpAsyncClientTests.java | 4 +-- .../client/AbstractMcpSyncClientTests.java | 2 +- .../client/HttpSseMcpAsyncClientTests.java | 4 +-- .../client/HttpSseMcpSyncClientTests.java | 4 +-- .../client/StdioMcpAsyncClientTests.java | 4 +-- .../client/StdioMcpSyncClientTests.java | 4 +-- 13 files changed, 36 insertions(+), 37 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java index b0dfa89c..8ea65fd7 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java @@ -9,7 +9,7 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.ClientMcpTransport; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; @@ -58,7 +58,7 @@ * "https://spec.modelcontextprotocol.io/specification/basic/transports/#http-with-sse">MCP * HTTP with SSE Transport Specification */ -public class WebFluxSseClientTransport implements McpClientTransport { +public class WebFluxSseClientTransport implements ClientMcpTransport { private static final Logger logger = LoggerFactory.getLogger(WebFluxSseClientTransport.class); diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java index 2dd587d4..0dccb27a 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java @@ -7,7 +7,7 @@ import java.time.Duration; import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; -import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.ClientMcpTransport; import org.junit.jupiter.api.Timeout; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; @@ -32,7 +32,7 @@ class WebFluxSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests { .waitingFor(Wait.forHttp("/").forStatusCode(404)); @Override - protected McpClientTransport createMcpTransport() { + protected ClientMcpTransport createMcpTransport() { return new WebFluxSseClientTransport(WebClient.builder().baseUrl(host)); } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java index 72b390dd..f5cab7b7 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java @@ -7,7 +7,7 @@ import java.time.Duration; import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; -import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.ClientMcpTransport; import org.junit.jupiter.api.Timeout; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; @@ -32,7 +32,7 @@ class WebFluxSseMcpSyncClientTests extends AbstractMcpSyncClientTests { .waitingFor(Wait.forHttp("/").forStatusCode(404)); @Override - protected McpClientTransport createMcpTransport() { + protected ClientMcpTransport createMcpTransport() { return new WebFluxSseClientTransport(WebClient.builder().baseUrl(host)); } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index ed29cf06..033139ad 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -12,7 +12,7 @@ import java.util.function.Consumer; import java.util.function.Function; -import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.ClientMcpTransport; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; @@ -49,7 +49,7 @@ public abstract class AbstractMcpAsyncClientTests { private static final String ECHO_TEST_MESSAGE = "Hello MCP Spring AI!"; - abstract protected McpClientTransport createMcpTransport(); + abstract protected ClientMcpTransport createMcpTransport(); protected void onStart() { } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java index 3c17c45e..032f8684 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -11,7 +11,7 @@ import java.util.function.Consumer; import java.util.function.Function; -import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.ClientMcpTransport; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java index 088990ca..1e7df31a 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -3,19 +3,6 @@ */ package io.modelcontextprotocol.client.transport; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.client.transport.FlowSseClient.SseEvent; -import io.modelcontextprotocol.spec.ClientMcpTransport; -import io.modelcontextprotocol.spec.McpClientTransport; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; -import io.modelcontextprotocol.util.Assert; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import reactor.core.publisher.Mono; - import java.io.IOException; import java.net.URI; import java.net.http.HttpClient; @@ -28,6 +15,18 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.transport.FlowSseClient.SseEvent; +import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; +import io.modelcontextprotocol.util.Assert; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Mono; + /** * Server-Sent Events (SSE) implementation of the * {@link io.modelcontextprotocol.spec.McpTransport} that follows the MCP HTTP with SSE @@ -53,9 +52,9 @@ * * @author Christian Tzolov * @see io.modelcontextprotocol.spec.McpTransport - * @see io.modelcontextprotocol.spec.McpClientTransport + * @see io.modelcontextprotocol.spec.ClientMcpTransport */ -public class HttpClientSseClientTransport implements McpClientTransport { +public class HttpClientSseClientTransport implements ClientMcpTransport { private static final Logger logger = LoggerFactory.getLogger(HttpClientSseClientTransport.class); diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java index 8fdc0479..d35db3f8 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java @@ -18,7 +18,7 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.ClientMcpTransport; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; import io.modelcontextprotocol.util.Assert; @@ -38,7 +38,7 @@ * @author Christian Tzolov * @author Dariusz Jędrzejczyk */ -public class StdioClientTransport implements McpClientTransport { +public class StdioClientTransport implements ClientMcpTransport { private static final Logger logger = LoggerFactory.getLogger(StdioClientTransport.class); diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index f5c90c16..72038854 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -13,7 +13,7 @@ import java.util.function.Function; import java.util.function.Supplier; -import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.ClientMcpTransport; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; @@ -51,7 +51,7 @@ public abstract class AbstractMcpAsyncClientTests { private static final String ECHO_TEST_MESSAGE = "Hello MCP Spring AI!"; - abstract protected McpClientTransport createMcpTransport(); + abstract protected ClientMcpTransport createMcpTransport(); protected void onStart() { } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java index 43600db7..1c042bf2 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -11,7 +11,7 @@ import java.util.function.Consumer; import java.util.function.Function; -import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.ClientMcpTransport; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java index c2201533..ac0fef24 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java @@ -7,7 +7,7 @@ import java.time.Duration; import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; -import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.ClientMcpTransport; import org.junit.jupiter.api.Timeout; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; @@ -30,7 +30,7 @@ class HttpSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests { .waitingFor(Wait.forHttp("/").forStatusCode(404)); @Override - protected McpClientTransport createMcpTransport() { + protected ClientMcpTransport createMcpTransport() { return new HttpClientSseClientTransport(host); } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java index 8b638fba..8772e620 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java @@ -7,7 +7,7 @@ import java.time.Duration; import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; -import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.ClientMcpTransport; import org.junit.jupiter.api.Timeout; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; @@ -30,7 +30,7 @@ class HttpSseMcpSyncClientTests extends AbstractMcpSyncClientTests { .waitingFor(Wait.forHttp("/").forStatusCode(404)); @Override - protected McpClientTransport createMcpTransport() { + protected ClientMcpTransport createMcpTransport() { return new HttpClientSseClientTransport(host); } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java index 95230942..c285e2c6 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java @@ -8,7 +8,7 @@ import io.modelcontextprotocol.client.transport.ServerParameters; import io.modelcontextprotocol.client.transport.StdioClientTransport; -import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.ClientMcpTransport; import org.junit.jupiter.api.Timeout; /** @@ -21,7 +21,7 @@ class StdioMcpAsyncClientTests extends AbstractMcpAsyncClientTests { @Override - protected McpClientTransport createMcpTransport() { + protected ClientMcpTransport createMcpTransport() { ServerParameters stdioParams = ServerParameters.builder("npx") .args("-y", "@modelcontextprotocol/server-everything", "dir") .build(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java index 8f7ec15b..ebf10b9a 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java @@ -11,7 +11,7 @@ import io.modelcontextprotocol.client.transport.ServerParameters; import io.modelcontextprotocol.client.transport.StdioClientTransport; -import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.ClientMcpTransport; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; import reactor.core.publisher.Sinks; @@ -29,7 +29,7 @@ class StdioMcpSyncClientTests extends AbstractMcpSyncClientTests { @Override - protected McpClientTransport createMcpTransport() { + protected ClientMcpTransport createMcpTransport() { ServerParameters stdioParams = ServerParameters.builder("npx") .args("-y", "@modelcontextprotocol/server-everything", "dir") .build(); From 474641e08fd91bc17e3c5b560e0c5533907e29c1 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Tue, 18 Mar 2025 12:11:17 +0100 Subject: [PATCH 15/20] feat(server): Make toSpecification methods public in McpServerFeatures Changed visibility of toSpecification() methods from package-private to public in all registration classes: - AsyncToolRegistration - AsyncResourceRegistration - AsyncPromptRegistration - SyncToolRegistration - SyncResourceRegistration - SyncPromptRegistration Signed-off-by: Christian Tzolov --- .../server/McpServerFeatures.java | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java index 2ba12f2a..c8334bb4 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java @@ -459,7 +459,7 @@ static AsyncToolRegistration fromSync(SyncToolRegistration tool) { map -> Mono.fromCallable(() -> tool.call().apply(map)).subscribeOn(Schedulers.boundedElastic())); } - AsyncToolSpecification toSpecification() { + public AsyncToolSpecification toSpecification() { return new AsyncToolSpecification(tool(), (exchange, map) -> call.apply(map)); } } @@ -505,7 +505,7 @@ static AsyncResourceRegistration fromSync(SyncResourceRegistration resource) { .subscribeOn(Schedulers.boundedElastic())); } - AsyncResourceSpecification toSpecification() { + public AsyncResourceSpecification toSpecification() { return new AsyncResourceSpecification(resource(), (exchange, request) -> readHandler.apply(request)); } } @@ -554,7 +554,7 @@ static AsyncPromptRegistration fromSync(SyncPromptRegistration prompt) { .subscribeOn(Schedulers.boundedElastic())); } - AsyncPromptSpecification toSpecification() { + public AsyncPromptSpecification toSpecification() { return new AsyncPromptSpecification(prompt(), (exchange, request) -> promptHandler.apply(request)); } } @@ -597,7 +597,7 @@ AsyncPromptSpecification toSpecification() { @Deprecated public record SyncToolRegistration(McpSchema.Tool tool, Function, McpSchema.CallToolResult> call) { - SyncToolSpecification toSpecification() { + public SyncToolSpecification toSpecification() { return new SyncToolSpecification(tool, (exchange, map) -> call.apply(map)); } } @@ -632,7 +632,7 @@ SyncToolSpecification toSpecification() { @Deprecated public record SyncResourceRegistration(McpSchema.Resource resource, Function readHandler) { - SyncResourceSpecification toSpecification() { + public SyncResourceSpecification toSpecification() { return new SyncResourceSpecification(resource, (exchange, request) -> readHandler.apply(request)); } } @@ -670,7 +670,7 @@ SyncResourceSpecification toSpecification() { @Deprecated public record SyncPromptRegistration(McpSchema.Prompt prompt, Function promptHandler) { - SyncPromptSpecification toSpecification() { + public SyncPromptSpecification toSpecification() { return new SyncPromptSpecification(prompt, (exchange, request) -> promptHandler.apply(request)); } } From 6da69b45f8b1f4374c6804cd4e2e32827ff9c3f7 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Wed, 19 Mar 2025 15:51:40 +0100 Subject: [PATCH 16/20] refactor(webflux): refactor WebFluxSseIntegrationTests Signed-off-by: Christian Tzolov --- .../WebFluxSseIntegrationTests.java | 292 +++++++++--------- 1 file changed, 150 insertions(+), 142 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 43253338..57bcd191 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -16,7 +16,6 @@ import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; import io.modelcontextprotocol.server.McpServer; import io.modelcontextprotocol.server.McpServerFeatures; -import io.modelcontextprotocol.server.transport.WebFluxSseServerTransport; import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; @@ -31,9 +30,9 @@ import io.modelcontextprotocol.spec.McpSchema.Tool; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +import reactor.core.publisher.Mono; import reactor.netty.DisposableServer; import reactor.netty.http.server.HttpServer; import reactor.test.StepVerifier; @@ -45,8 +44,8 @@ import org.springframework.web.reactive.function.server.RouterFunctions; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.awaitility.Awaitility.await; +import static org.mockito.Mockito.mock; public class WebFluxSseIntegrationTests { @@ -85,109 +84,100 @@ public void after() { // --------------------------------------- // Sampling Tests // --------------------------------------- - // TODO implement within a tool execution - // @Test - // void testCreateMessageWithoutInitialization() { - // var mcpAsyncServer = - // McpServer.async(mcpServerTransportProvider).serverInfo("test-server", - // "1.0.0").build(); - // - // var messages = List - // .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new - // McpSchema.TextContent("Test message"))); - // var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); - // - // var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, - // McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), - // Map.of()); - // - // StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error - // -> { - // assertThat(error).isInstanceOf(McpError.class) - // .hasMessage("Client must be initialized. Call the initialize method first!"); - // }); - // } - // - // @ParameterizedTest(name = "{0} : {displayName} ") - // @ValueSource(strings = { "httpclient", "webflux" }) - // void testCreateMessageWithoutSamplingCapabilities(String clientType) { - // - // var mcpAsyncServer = - // McpServer.async(mcpServerTransportProvider).serverInfo("test-server", - // "1.0.0").build(); - // - // var clientBuilder = clientBulders.get(clientType); - // - // var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", - // "0.0.0")).build(); - // - // InitializeResult initResult = client.initialize(); - // assertThat(initResult).isNotNull(); - // - // var messages = List - // .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new - // McpSchema.TextContent("Test message"))); - // var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); - // - // var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, - // McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), - // Map.of()); - // - // StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error - // -> { - // assertThat(error).isInstanceOf(McpError.class) - // .hasMessage("Client must be configured with sampling capabilities"); - // }); - // } - // - // @ParameterizedTest(name = "{0} : {displayName} ") - // @ValueSource(strings = { "httpclient", "webflux" }) - // void testCreateMessageSuccess(String clientType) throws InterruptedException { - // - // var clientBuilder = clientBulders.get(clientType); - // - // var mcpAsyncServer = - // McpServer.async(mcpServerTransportProvider).serverInfo("test-server", - // "1.0.0").build(); - // - // Function samplingHandler = request -> { - // assertThat(request.messages()).hasSize(1); - // assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); - // - // return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test - // message"), "MockModelName", - // CreateMessageResult.StopReason.STOP_SEQUENCE); - // }; - // - // var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", - // "0.0.0")) - // .capabilities(ClientCapabilities.builder().sampling().build()) - // .sampling(samplingHandler) - // .build(); - // - // InitializeResult initResult = client.initialize(); - // assertThat(initResult).isNotNull(); - // - // var messages = List - // .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new - // McpSchema.TextContent("Test message"))); - // var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); - // - // var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, - // McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), - // Map.of()); - // - // StepVerifier.create(mcpAsyncServer.createMessage(request)).consumeNextWith(result - // -> { - // assertThat(result).isNotNull(); - // assertThat(result.role()).isEqualTo(Role.USER); - // assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - // assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test - // message"); - // assertThat(result.model()).isEqualTo("MockModelName"); - // assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - // }).verifyComplete(); - // } + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateMessageWithoutSamplingCapabilities(String clientType) { + + var clientBuilder = clientBulders.get(clientType); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + exchange.createMessage(mock(McpSchema.CreateMessageRequest.class)).block(); + + return Mono.just(mock(CallToolResult.class)); + }); + + McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); + + // Create client without sampling capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")).build(); + + assertThat(client.initialize()).isNotNull(); + + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with sampling capabilities"); + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateMessageSuccess(String clientType) throws InterruptedException { + + // Client + var clientBuilder = clientBulders.get(clientType); + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var messages = List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message"))); + var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); + + var craeteMessageRequest = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, + McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), + Map.of()); + + StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.close(); + mcpServer.close(); + } // --------------------------------------- // Roots Tests @@ -238,43 +228,44 @@ void testRootsSuccess(String clientType) { mcpServer.close(); } - // @ParameterizedTest(name = "{0} : {displayName} ") - // @ValueSource(strings = { "httpclient", "webflux" }) - // void testRootsWithoutCapability(String clientType) { - // var clientBuilder = clientBulders.get(clientType); - // AtomicReference errorRef = new AtomicReference<>(); - // - // var mcpServer = - // McpServer.sync(mcpServerTransportProvider) - // // TODO: implement tool handling and try to list roots - // .tool(tool, (exchange, args) -> { - // try { - // exchange.listRoots(); - // } catch (Exception e) { - // errorRef.set(e); - // } - // }).build(); - // - // // Create client without roots capability - // var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()) // - // No - // // roots - // // capability - // .build(); - // - // InitializeResult initResult = mcpClient.initialize(); - // assertThat(initResult).isNotNull(); - // - // assertThat(errorRef.get()).isInstanceOf(McpError.class).hasMessage("Roots not - // supported"); - // - // mcpClient.close(); - // mcpServer.close(); - // } + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testRootsWithoutCapability(String clientType) { + + var clientBuilder = clientBulders.get(clientType); + + McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + exchange.listRoots(); // try to list roots + + return mock(CallToolResult.class); + }); + + var mcpServer = McpServer.sync(mcpServerTransportProvider).rootsChangeHandler((exchange, rootsUpdate) -> { + }).tools(tool).build(); + + // Create client without roots capability + // No roots capability + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build(); + + assertThat(mcpClient.initialize()).isNotNull(); + + // Attempt to list roots should fail + try { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); + } + + mcpClient.close(); + mcpServer.close(); + } @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "httpclient", "webflux" }) - void testRootsWithEmptyRootsList(String clientType) { + void testRootsNotifciationWithEmptyRootsList(String clientType) { var clientBuilder = clientBulders.get(clientType); AtomicReference> rootsRef = new AtomicReference<>(); @@ -474,8 +465,8 @@ void testToolListChangeHandlingSuccess(String clientType) { }); // Add a new tool - McpServerFeatures.SyncToolRegistration tool2 = new McpServerFeatures.SyncToolRegistration( - new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), request -> callResponse); + McpServerFeatures.SyncToolSpecification tool2 = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), (exchange, request) -> callResponse); mcpServer.addTool(tool2); @@ -487,4 +478,21 @@ void testToolListChangeHandlingSuccess(String clientType) { mcpServer.close(); } + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testInitialize(String clientType) { + + var clientBuilder = clientBulders.get(clientType); + + var mcpServer = McpServer.sync(mcpServerTransportProvider).build(); + + var mcpClient = clientBuilder.build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.close(); + mcpServer.close(); + } + } From 8704863f3b621b4926d79b52197fa1e1a3a69bcd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Tue, 18 Mar 2025 12:40:34 +0100 Subject: [PATCH 17/20] Add documentation --- .../WebFluxSseServerTransportProvider.java | 68 ++------ .../server/McpAsyncServer.java | 161 ++++-------------- .../server/McpAsyncServerExchange.java | 44 ++++- .../server/McpServer.java | 158 ++++++++++++----- .../server/McpServerFeatures.java | 38 +++-- .../server/McpSyncServer.java | 10 ++ .../server/McpSyncServerExchange.java | 55 +++++- .../spec/McpClientTransport.java | 6 + .../spec/McpServerSession.java | 88 ++++++++++ .../spec/McpServerTransport.java | 6 + .../spec/McpServerTransportProvider.java | 53 ++++-- .../modelcontextprotocol/spec/McpSession.java | 15 +- 12 files changed, 435 insertions(+), 267 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index 13f5da31..cf3eeae0 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -38,8 +38,9 @@ *

      * Key features: *

        - *
      • Implements the {@link ServerMcpTransport} interface for MCP server transport - * functionality
      • + *
      • Implements the {@link McpServerTransportProvider} interface that allows managing + * {@link McpServerSession} instances and enabling their communication with the + * {@link McpServerTransport} abstraction.
      • *
      • Uses WebFlux for non-blocking request handling and SSE support
      • *
      • Maintains client sessions for reliable message delivery
      • *
      • Supports graceful shutdown with session cleanup
      • @@ -55,12 +56,13 @@ * *

        * This implementation is thread-safe and can handle multiple concurrent client - * connections. It uses {@link ConcurrentHashMap} for session management and Reactor's - * {@link Sinks} for thread-safe message broadcasting. + * connections. It uses {@link ConcurrentHashMap} for session management and Project + * Reactor's non-blocking APIs for message processing and delivery. * * @author Christian Tzolov * @author Alexandros Pappas - * @see ServerMcpTransport + * @author Dariusz Jędrzejczyk + * @see McpServerTransport * @see ServerSentEvent */ public class WebFluxSseServerTransportProvider implements McpServerTransportProvider { @@ -103,7 +105,7 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv private volatile boolean isClosing = false; /** - * Constructs a new WebFlux SSE server transport instance. + * Constructs a new WebFlux SSE server transport provider instance. * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization * of MCP messages. Must not be null. * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC @@ -126,8 +128,8 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messa } /** - * Constructs a new WebFlux SSE server transport instance with the default SSE - * endpoint. + * Constructs a new WebFlux SSE server transport provider instance with the default + * SSE endpoint. * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization * of MCP messages. Must not be null. * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC @@ -179,8 +181,10 @@ public Mono notifyClients(String method, Map params) { .then(); } + // FIXME: This javadoc makes claims about using isClosing flag but it's not actually + // doing that. /** - * Initiates a graceful shutdown of the transport. This method ensures all active + * Initiates a graceful shutdown of all the sessions. This method ensures all active * sessions are properly closed and cleaned up. * *

        @@ -220,18 +224,8 @@ public RouterFunction getRouterFunction() { /** * Handles new SSE connection requests from clients. Creates a new session for each * connection and sets up the SSE event stream. - * - *

        - * The handler performs the following steps: - *

          - *
        • Generates a unique session ID
        • - *
        • Creates a new ClientSession instance
        • - *
        • Sends the message endpoint URI as an initial event
        • - *
        • Sets up message forwarding for the session
        • - *
        • Handles connection cleanup on completion or errors
        • - *
        * @param request The incoming server request - * @return A response with the SSE event stream + * @return A Mono which emits a response with the SSE event stream */ private Mono handleSseConnection(ServerRequest request) { if (isClosing) { @@ -275,7 +269,7 @@ private Mono handleSseConnection(ServerRequest request) { *
      • Handles various error conditions with appropriate error responses
      • *
      * @param request The incoming server request containing the JSON-RPC message - * @return A response indicating the message processing result + * @return A Mono emitting the response indicating the message processing result */ private Mono handleMessage(ServerRequest request) { if (isClosing) { @@ -307,38 +301,6 @@ private Mono handleMessage(ServerRequest request) { }); } - /* - * Current: - * - * framework layer: var transport = new WebFluxSseServerTransport(objectMapper, - * "/mcp", "/sse"); McpServer.async(ServerMcpTransport transport) - * - * client connects -> WebFluxSseServerTransport creates a: - var sessionTransport = - * WebFluxMcpSessionTransport - ServerMcpSession(sessionId, sessionTransport) - * - * WebFluxSseServerTransport IS_A ServerMcpTransport IS_A McpTransport - * WebFluxMcpSessionTransport IS_A ServerMcpSessionTransport IS_A McpTransport - * - * McpTransport contains connect() which should be removed ClientMcpTransport should - * have connect() ServerMcpTransport should have setSessionFactory() - * - * Possible Future: var transportProvider = new - * WebFluxSseServerTransport(objectMapper, "/mcp", "/sse"); WebFluxSseServerTransport - * IS_A ServerMcpTransportProvider ? ServerMcpTransportProvider creates - * ServerMcpTransport - * - * // disadvantage - too much breaks, e.g. McpServer.async(ServerMcpTransportProvider - * transportProvider) - * - * // advantage - * - * ClientMcpTransport and ServerMcpTransport BOTH represent 1:1 relationship - * - * - * - * - */ - private class WebFluxMcpSessionTransport implements McpServerTransport { private final FluxSink> sink; diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index 0610e7bf..44536816 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -89,6 +89,9 @@ public class McpAsyncServer { * Create a new McpAsyncServer with the given transport and capabilities. * @param mcpTransport The transport layer implementation for MCP communication. * @param features The MCP server supported features. + * @deprecated This constructor will beremoved in 0.9.0. Use + * {@link #McpAsyncServer(McpServerTransportProvider, ObjectMapper, McpServerFeatures.Async)} + * instead. */ @Deprecated McpAsyncServer(ServerMcpTransport mcpTransport, McpServerFeatures.Async features) { @@ -96,10 +99,11 @@ public class McpAsyncServer { } /** - * Create a new McpAsyncServer with the given transport and capabilities. + * Create a new McpAsyncServer with the given transport provider and capabilities. * @param mcpTransportProvider The transport layer implementation for MCP * communication. * @param features The MCP server supported features. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization */ McpAsyncServer(McpServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper, McpServerFeatures.Async features) { @@ -125,7 +129,8 @@ public McpSchema.Implementation getServerInfo() { /** * Get the client capabilities that define the supported features and functionality. * @return The client capabilities - * @deprecated This will be removed in 0.9.0 + * @deprecated This will be removed in 0.9.0. Use + * {@link McpAsyncServerExchange#getClientCapabilities()}. */ @Deprecated public ClientCapabilities getClientCapabilities() { @@ -135,7 +140,8 @@ public ClientCapabilities getClientCapabilities() { /** * Get the client implementation information. * @return The client implementation details - * @deprecated This will be removed in 0.9.0 + * @deprecated This will be removed in 0.9.0. Use + * {@link McpAsyncServerExchange#getClientInfo()}. */ @Deprecated public McpSchema.Implementation getClientInfo() { @@ -160,6 +166,8 @@ public void close() { /** * Retrieves the list of all roots provided by the client. * @return A Mono that emits the list of roots result. + * @deprecated This will be removed in 0.9.0. Use + * {@link McpAsyncServerExchange#listRoots()}. */ @Deprecated public Mono listRoots() { @@ -170,7 +178,8 @@ public Mono listRoots() { * Retrieves a paginated list of roots provided by the server. * @param cursor Optional pagination cursor from a previous list request * @return A Mono that emits the list of roots result containing - * @deprecated This will be removed in 0.9.0 + * @deprecated This will be removed in 0.9.0. Use + * {@link McpAsyncServerExchange#listRoots(String)}. */ @Deprecated public Mono listRoots(String cursor) { @@ -339,7 +348,8 @@ public Mono loggingNotification(LoggingMessageNotification loggingMessageN * @see Sampling * Specification - * @deprecated This will be removed in 0.9.0 + * @deprecated This will be removed in 0.9.0. Use + * {@link McpAsyncServerExchange#createMessage(McpSchema.CreateMessageRequest)}. */ @Deprecated public Mono createMessage(McpSchema.CreateMessageRequest createMessageRequest) { @@ -365,9 +375,6 @@ private static class AsyncServerImpl extends McpAsyncServer { private final McpSchema.Implementation serverInfo; - /** - * Thread-safe list of tool handlers that can be modified at runtime. - */ private final CopyOnWriteArrayList tools = new CopyOnWriteArrayList<>(); private final CopyOnWriteArrayList resourceTemplates = new CopyOnWriteArrayList<>(); @@ -378,17 +385,8 @@ private static class AsyncServerImpl extends McpAsyncServer { private LoggingLevel minLoggingLevel = LoggingLevel.DEBUG; - /** - * Supported protocol versions. - */ private List protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION); - /** - * Create a new McpAsyncServer with the given transport and capabilities. - * @param mcpTransportProvider The transport layer implementation for MCP - * communication. - * @param features The MCP server supported features. - */ AsyncServerImpl(McpServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper, McpServerFeatures.Async features) { this.mcpTransportProvider = mcpTransportProvider; @@ -485,71 +483,43 @@ private Mono asyncInitializeRequestHandler( }); } - /** - * Get the server capabilities that define the supported features and - * functionality. - * @return The server capabilities - */ public McpSchema.ServerCapabilities getServerCapabilities() { return this.serverCapabilities; } - /** - * Get the server implementation information. - * @return The server implementation details - */ public McpSchema.Implementation getServerInfo() { return this.serverInfo; } - /** - * Get the client capabilities that define the supported features and - * functionality. - * @return The client capabilities - */ + @Override @Deprecated public ClientCapabilities getClientCapabilities() { throw new IllegalStateException("This method is deprecated and should not be called"); } - /** - * Get the client implementation information. - * @return The client implementation details - */ + @Override @Deprecated public McpSchema.Implementation getClientInfo() { throw new IllegalStateException("This method is deprecated and should not be called"); } - /** - * Gracefully closes the server, allowing any in-progress operations to complete. - * @return A Mono that completes when the server has been closed - */ + @Override public Mono closeGracefully() { return this.mcpTransportProvider.closeGracefully(); } - /** - * Close the server immediately. - */ + @Override public void close() { this.mcpTransportProvider.close(); } - /** - * Retrieves the list of all roots provided by the client. - * @return A Mono that emits the list of roots result. - */ + @Override @Deprecated public Mono listRoots() { return this.listRoots(null); } - /** - * Retrieves a paginated list of roots provided by the server. - * @param cursor Optional pagination cursor from a previous list request - * @return A Mono that emits the list of roots result containing - */ + @Override @Deprecated public Mono listRoots(String cursor) { return Mono.error(new RuntimeException("Not implemented")); @@ -571,11 +541,7 @@ private McpServerSession.NotificationHandler asyncRootsListChangedNotificationHa // Tool Management // --------------------------------------- - /** - * Add a new tool specification at runtime. - * @param toolSpecification The tool specification to add - * @return Mono that completes when clients have been notified of the change - */ + @Override public Mono addTool(McpServerFeatures.AsyncToolSpecification toolSpecification) { if (toolSpecification == null) { return Mono.error(new McpError("Tool specification must not be null")); @@ -612,11 +578,7 @@ public Mono addTool(McpServerFeatures.AsyncToolRegistration toolRegistrati return this.addTool(toolRegistration.toSpecification()); } - /** - * Remove a tool handler at runtime. - * @param toolName The name of the tool handler to remove - * @return Mono that completes when clients have been notified of the change - */ + @Override public Mono removeTool(String toolName) { if (toolName == null) { return Mono.error(new McpError("Tool name must not be null")); @@ -639,10 +601,7 @@ public Mono removeTool(String toolName) { }); } - /** - * Notifies clients that the list of available tools has changed. - * @return A Mono that completes when all clients have been notified - */ + @Override public Mono notifyToolsListChanged() { return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_TOOLS_LIST_CHANGED, null); } @@ -678,11 +637,7 @@ private McpServerSession.RequestHandler toolsCallRequestHandler( // Resource Management // --------------------------------------- - /** - * Add a new resource handler at runtime. - * @param resourceSpecification The resource handler to add - * @return Mono that completes when clients have been notified of the change - */ + @Override public Mono addResource(McpServerFeatures.AsyncResourceSpecification resourceSpecification) { if (resourceSpecification == null || resourceSpecification.resource() == null) { return Mono.error(new McpError("Resource must not be null")); @@ -710,11 +665,7 @@ public Mono addResource(McpServerFeatures.AsyncResourceRegistration resour return this.addResource(resourceHandler.toSpecification()); } - /** - * Remove a resource handler at runtime. - * @param resourceUri The URI of the resource handler to remove - * @return Mono that completes when clients have been notified of the change - */ + @Override public Mono removeResource(String resourceUri) { if (resourceUri == null) { return Mono.error(new McpError("Resource URI must not be null")); @@ -736,10 +687,7 @@ public Mono removeResource(String resourceUri) { }); } - /** - * Notifies clients that the list of available resources has changed. - * @return A Mono that completes when all clients have been notified - */ + @Override public Mono notifyResourcesListChanged() { return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED, null); } @@ -778,11 +726,7 @@ private McpServerSession.RequestHandler resourcesR // Prompt Management // --------------------------------------- - /** - * Add a new prompt handler at runtime. - * @param promptSpecification The prompt handler to add - * @return Mono that completes when clients have been notified of the change - */ + @Override public Mono addPrompt(McpServerFeatures.AsyncPromptSpecification promptSpecification) { if (promptSpecification == null) { return Mono.error(new McpError("Prompt specification must not be null")); @@ -816,11 +760,7 @@ public Mono addPrompt(McpServerFeatures.AsyncPromptRegistration promptRegi return this.addPrompt(promptRegistration.toSpecification()); } - /** - * Remove a prompt handler at runtime. - * @param promptName The name of the prompt handler to remove - * @return Mono that completes when clients have been notified of the change - */ + @Override public Mono removePrompt(String promptName) { if (promptName == null) { return Mono.error(new McpError("Prompt name must not be null")); @@ -845,10 +785,7 @@ public Mono removePrompt(String promptName) { }); } - /** - * Notifies clients that the list of available prompts has changed. - * @return A Mono that completes when all clients have been notified - */ + @Override public Mono notifyPromptsListChanged() { return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED, null); } @@ -889,12 +826,7 @@ private McpServerSession.RequestHandler promptsGetReq // Logging Management // --------------------------------------- - /** - * Send a logging message notification to all connected clients. Messages below - * the current minimum logging level will be filtered out. - * @param loggingMessageNotification The logging message to send - * @return A Mono that completes when the notification has been sent - */ + @Override public Mono loggingNotification(LoggingMessageNotification loggingMessageNotification) { if (loggingMessageNotification == null) { @@ -912,11 +844,6 @@ public Mono loggingNotification(LoggingMessageNotification loggingMessageN return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_MESSAGE, params); } - /** - * Handles requests to set the minimum logging level. Messages below this level - * will not be sent. - * @return A handler that processes logging level change requests - */ private McpServerSession.RequestHandler setLoggerRequestHandler() { return (exchange, params) -> { this.minLoggingLevel = objectMapper.convertValue(params, new TypeReference() { @@ -930,35 +857,13 @@ private McpServerSession.RequestHandler setLoggerRequestHandler() { // Sampling // --------------------------------------- - /** - * Create a new message using the sampling capabilities of the client. The Model - * Context Protocol (MCP) provides a standardized way for servers to request LLM - * sampling (“completions” or “generations”) from language models via clients. - * This flow allows clients to maintain control over model access, selection, and - * permissions while enabling servers to leverage AI capabilities—with no server - * API keys necessary. Servers can request text or image-based interactions and - * optionally include context from MCP servers in their prompts. - * @param createMessageRequest The request to create a new message - * @return A Mono that completes when the message has been created - * @throws McpError if the client has not been initialized or does not support - * sampling capabilities - * @throws McpError if the client does not support the createMessage method - * @see McpSchema.CreateMessageRequest - * @see McpSchema.CreateMessageResult - * @see Sampling - * Specification - */ + @Override @Deprecated public Mono createMessage(McpSchema.CreateMessageRequest createMessageRequest) { return Mono.error(new RuntimeException("Not implemented")); } - /** - * This method is package-private and used for test only. Should not be called by - * user code. - * @param protocolVersions the Client supported protocol versions. - */ + @Override void setProtocolVersions(List protocolVersions) { this.protocolVersions = protocolVersions; } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java index 8959c293..65862844 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java @@ -6,6 +6,12 @@ import io.modelcontextprotocol.spec.McpServerSession; import reactor.core.publisher.Mono; +/** + * Represents an asynchronous exchange with a Model Context Protocol (MCP) client. The + * exchange provides methods to interact with the client and query its capabilities. + * + * @author Dariusz Jędrzejczyk + */ public class McpAsyncServerExchange { private final McpServerSession session; @@ -14,6 +20,19 @@ public class McpAsyncServerExchange { private final McpSchema.Implementation clientInfo; + private static final TypeReference CREATE_MESSAGE_RESULT_TYPE_REF = new TypeReference<>() { + }; + + private static final TypeReference LIST_ROOTS_RESULT_TYPE_REF = new TypeReference<>() { + }; + + /** + * Create a new asynchronous exchange with the client. + * @param session The server session representing a 1-1 interaction. + * @param clientCapabilities The client capabilities that define the supported + * features and functionality. + * @param clientInfo The client implementation information. + */ public McpAsyncServerExchange(McpServerSession session, McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo) { this.session = session; @@ -21,8 +40,21 @@ public McpAsyncServerExchange(McpServerSession session, McpSchema.ClientCapabili this.clientInfo = clientInfo; } - private static final TypeReference CREATE_MESSAGE_RESULT_TYPE_REF = new TypeReference<>() { - }; + /** + * Get the client capabilities that define the supported features and functionality. + * @return The client capabilities + */ + public McpSchema.ClientCapabilities getClientCapabilities() { + return this.clientCapabilities; + } + + /** + * Get the client implementation information. + * @return The client implementation details + */ + public McpSchema.Implementation getClientInfo() { + return this.clientInfo; + } /** * Create a new message using the sampling capabilities of the client. The Model @@ -34,9 +66,6 @@ public McpAsyncServerExchange(McpServerSession session, McpSchema.ClientCapabili * include context from MCP servers in their prompts. * @param createMessageRequest The request to create a new message * @return A Mono that completes when the message has been created - * @throws McpError if the client has not been initialized or does not support - * sampling capabilities - * @throws McpError if the client does not support the createMessage method * @see McpSchema.CreateMessageRequest * @see McpSchema.CreateMessageResult * @see createMessage(McpSchema.CreateMessage CREATE_MESSAGE_RESULT_TYPE_REF); } - private static final TypeReference LIST_ROOTS_RESULT_TYPE_REF = new TypeReference<>() { - }; - /** * Retrieves the list of all roots provided by the client. * @return A Mono that emits the list of roots result. @@ -66,7 +92,7 @@ public Mono listRoots() { } /** - * Retrieves a paginated list of roots provided by the server. + * Retrieves a paginated list of roots provided by the client. * @param cursor Optional pagination cursor from a previous list request * @return A Mono that emits the list of roots result containing */ diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java index 81a0ed44..d8dfcb01 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java @@ -18,7 +18,6 @@ import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpServerTransportProvider; -import io.modelcontextprotocol.spec.McpTransport; import io.modelcontextprotocol.spec.ServerMcpTransport; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate; @@ -55,45 +54,50 @@ *

      * The class provides factory methods to create either: *

        - *
      • {@link McpAsyncServer} for non-blocking operations with CompletableFuture responses + *
      • {@link McpAsyncServer} for non-blocking operations with reactive responses *
      • {@link McpSyncServer} for blocking operations with direct responses *
      * *

      * Example of creating a basic synchronous server:

      {@code
      - * McpServer.sync(transport)
      + * McpServer.sync(transportProvider)
        *     .serverInfo("my-server", "1.0.0")
        *     .tool(new Tool("calculator", "Performs calculations", schema),
      - *           args -> new CallToolResult("Result: " + calculate(args)))
      + *           (exchange, args) -> new CallToolResult("Result: " + calculate(args)))
        *     .build();
        * }
      * * Example of creating a basic asynchronous server:
      {@code
      - * McpServer.async(transport)
      + * McpServer.async(transportProvider)
        *     .serverInfo("my-server", "1.0.0")
        *     .tool(new Tool("calculator", "Performs calculations", schema),
      - *           args -> Mono.just(new CallToolResult("Result: " + calculate(args))))
      + *           (exchange, args) -> Mono.fromSupplier(() -> calculate(args))
      + *               .map(result -> new CallToolResult("Result: " + result)))
        *     .build();
        * }
      * *

      * Example with comprehensive asynchronous configuration:

      {@code
      - * McpServer.async(transport)
      + * McpServer.async(transportProvider)
        *     .serverInfo("advanced-server", "2.0.0")
        *     .capabilities(new ServerCapabilities(...))
        *     // Register tools
        *     .tools(
      - *         new McpServerFeatures.AsyncToolRegistration(calculatorTool,
      - *             args -> Mono.just(new CallToolResult("Result: " + calculate(args)))),
      - *         new McpServerFeatures.AsyncToolRegistration(weatherTool,
      - *             args -> Mono.just(new CallToolResult("Weather: " + getWeather(args))))
      + *         new McpServerFeatures.AsyncToolSpecification(calculatorTool,
      + *             (exchange, args) -> Mono.fromSupplier(() -> calculate(args))
      + *                 .map(result -> new CallToolResult("Result: " + result))),
      + *         new McpServerFeatures.AsyncToolSpecification(weatherTool,
      + *             (exchange, args) -> Mono.fromSupplier(() -> getWeather(args))
      + *                 .map(result -> new CallToolResult("Weather: " + result)))
        *     )
        *     // Register resources
        *     .resources(
      - *         new McpServerFeatures.AsyncResourceRegistration(fileResource,
      - *             req -> Mono.just(new ReadResourceResult(readFile(req)))),
      - *         new McpServerFeatures.AsyncResourceRegistration(dbResource,
      - *             req -> Mono.just(new ReadResourceResult(queryDb(req))))
      + *         new McpServerFeatures.AsyncResourceSpecification(fileResource,
      + *             (exchange, req) -> Mono.fromSupplier(() -> readFile(req))
      + *                 .map(ReadResourceResult::new)),
      + *         new McpServerFeatures.AsyncResourceSpecification(dbResource,
      + *             (exchange, req) -> Mono.fromSupplier(() -> queryDb(req))
      + *                 .map(ReadResourceResult::new))
        *     )
        *     // Add resource templates
        *     .resourceTemplates(
      @@ -102,10 +106,12 @@
        *     )
        *     // Register prompts
        *     .prompts(
      - *         new McpServerFeatures.AsyncPromptRegistration(analysisPrompt,
      - *             req -> Mono.just(new GetPromptResult(generateAnalysisPrompt(req)))),
      + *         new McpServerFeatures.AsyncPromptSpecification(analysisPrompt,
      + *             (exchange, req) -> Mono.fromSupplier(() -> generateAnalysisPrompt(req))
      + *                 .map(GetPromptResult::new)),
        *         new McpServerFeatures.AsyncPromptRegistration(summaryPrompt,
      - *             req -> Mono.just(new GetPromptResult(generateSummaryPrompt(req))))
      + *             (exchange, req) -> Mono.fromSupplier(() -> generateSummaryPrompt(req))
      + *                 .map(GetPromptResult::new))
        *     )
        *     .build();
        * }
      @@ -114,15 +120,27 @@ * @author Dariusz Jędrzejczyk * @see McpAsyncServer * @see McpSyncServer - * @see McpTransport + * @see McpServerTransportProvider */ public interface McpServer { /** * Starts building a synchronous MCP server that provides blocking operations. - * Synchronous servers process each request to completion before handling the next - * one, making them simpler to implement but potentially less performant for - * concurrent operations. + * Synchronous servers block the current Thread's execution upon each request before + * giving the control back to the caller, making them simpler to implement but + * potentially less scalable for concurrent operations. + * @param transportProvider The transport layer implementation for MCP communication. + * @return A new instance of {@link SyncSpecification} for configuring the server. + */ + static SyncSpecification sync(McpServerTransportProvider transportProvider) { + return new SyncSpecification(transportProvider); + } + + /** + * Starts building a synchronous MCP server that provides blocking operations. + * Synchronous servers block the current Thread's execution upon each request before + * giving the control back to the caller, making them simpler to implement but + * potentially less scalable for concurrent operations. * @param transport The transport layer implementation for MCP communication * @return A new instance of {@link SyncSpec} for configuring the server. * @deprecated This method will be removed in 0.9.0. Use @@ -133,15 +151,23 @@ static SyncSpec sync(ServerMcpTransport transport) { return new SyncSpec(transport); } - static SyncSpecification sync(McpServerTransportProvider transportProvider) { - return new SyncSpecification(transportProvider); + /** + * Starts building an asynchronous MCP server that provides non-blocking operations. + * Asynchronous servers can handle multiple requests concurrently on a single Thread + * using a functional paradigm with non-blocking server transports, making them more + * scalable for high-concurrency scenarios but more complex to implement. + * @param transportProvider The transport layer implementation for MCP communication. + * @return A new instance of {@link AsyncSpecification} for configuring the server. + */ + static AsyncSpecification async(McpServerTransportProvider transportProvider) { + return new AsyncSpecification(transportProvider); } /** - * Starts building an asynchronous MCP server that provides blocking operations. - * Asynchronous servers can handle multiple requests concurrently using a functional - * paradigm with non-blocking server transports, making them more efficient for - * high-concurrency scenarios but more complex to implement. + * Starts building an asynchronous MCP server that provides non-blocking operations. + * Asynchronous servers can handle multiple requests concurrently on a single Thread + * using a functional paradigm with non-blocking server transports, making them more + * scalable for high-concurrency scenarios but more complex to implement. * @param transport The transport layer implementation for MCP communication * @return A new instance of {@link AsyncSpec} for configuring the server. * @deprecated This method will be removed in 0.9.0. Use @@ -152,10 +178,6 @@ static AsyncSpec async(ServerMcpTransport transport) { return new AsyncSpec(transport); } - static AsyncSpecification async(McpServerTransportProvider transportProvider) { - return new AsyncSpecification(transportProvider); - } - /** * Asynchronous server specification. */ @@ -248,8 +270,6 @@ public AsyncSpecification serverInfo(String name, String version) { *
    • Tool execution *
    • Resource access *
    • Prompt handling - *
    • Streaming responses - *
    • Batch operations *
    * @param serverCapabilities The server capabilities configuration. Must not be * null. @@ -257,6 +277,7 @@ public AsyncSpecification serverInfo(String name, String version) { * @throws IllegalArgumentException if serverCapabilities is null */ public AsyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabilities) { + Assert.notNull(serverCapabilities, "Server capabilities must not be null"); this.serverCapabilities = serverCapabilities; return this; } @@ -270,12 +291,16 @@ public AsyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabi * Example usage:
    {@code
     		 * .tool(
     		 *     new Tool("calculator", "Performs calculations", schema),
    -		 *     args -> Mono.just(new CallToolResult("Result: " + calculate(args)))
    +		 *     (exchange, args) -> Mono.fromSupplier(() -> calculate(args))
    +		 *         .map(result -> new CallToolResult("Result: " + result))
     		 * )
     		 * }
    * @param tool The tool definition including name, description, and schema. Must * not be null. * @param handler The function that implements the tool's logic. Must not be null. + * The function's first argument is an {@link McpAsyncServerExchange} upon which + * the server can interact with the connected client. The second argument is the + * map of arguments passed to the tool. * @return This builder instance for method chaining * @throws IllegalArgumentException if tool or handler is null */ @@ -323,6 +348,7 @@ public AsyncSpecification tools(List t * @see #tools(List) */ public AsyncSpecification tools(McpServerFeatures.AsyncToolSpecification... toolSpecifications) { + Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); for (McpServerFeatures.AsyncToolSpecification tool : toolSpecifications) { this.tools.add(tool); } @@ -402,9 +428,11 @@ public AsyncSpecification resources(McpServerFeatures.AsyncResourceSpecification * @param resourceTemplates List of resource templates. If null, clears existing * templates. * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceTemplates is null. * @see #resourceTemplates(ResourceTemplate...) */ public AsyncSpecification resourceTemplates(List resourceTemplates) { + Assert.notNull(resourceTemplates, "Resource templates must not be null"); this.resourceTemplates.addAll(resourceTemplates); return this; } @@ -414,9 +442,11 @@ public AsyncSpecification resourceTemplates(List resourceTempl * alternative to {@link #resourceTemplates(List)}. * @param resourceTemplates The resource templates to set. * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceTemplates is null. * @see #resourceTemplates(List) */ public AsyncSpecification resourceTemplates(ResourceTemplate... resourceTemplates) { + Assert.notNull(resourceTemplates, "Resource templates must not be null"); for (ResourceTemplate resourceTemplate : resourceTemplates) { this.resourceTemplates.add(resourceTemplate); } @@ -432,7 +462,8 @@ public AsyncSpecification resourceTemplates(ResourceTemplate... resourceTemplate * Example usage:
    {@code
     		 * .prompts(Map.of("analysis", new McpServerFeatures.AsyncPromptSpecification(
     		 *     new Prompt("analysis", "Code analysis template"),
    -		 *     request -> Mono.just(new GetPromptResult(generateAnalysisPrompt(request)))
    +		 *     request -> Mono.fromSupplier(() -> generateAnalysisPrompt(request))
    +		 *         .map(GetPromptResult::new)
     		 * )));
     		 * }
    * @param prompts Map of prompt name to specification. Must not be null. @@ -440,6 +471,7 @@ public AsyncSpecification resourceTemplates(ResourceTemplate... resourceTemplate * @throws IllegalArgumentException if prompts is null */ public AsyncSpecification prompts(Map prompts) { + Assert.notNull(prompts, "Prompts map must not be null"); this.prompts.putAll(prompts); return this; } @@ -453,6 +485,7 @@ public AsyncSpecification prompts(Map prompts) { + Assert.notNull(prompts, "Prompts list must not be null"); for (McpServerFeatures.AsyncPromptSpecification prompt : prompts) { this.prompts.put(prompt.prompt().name(), prompt); } @@ -476,6 +509,7 @@ public AsyncSpecification prompts(List, Mono>> handlers) { @@ -519,13 +556,22 @@ public AsyncSpecification rootsChangeHandlers( * @param handlers The handlers to register. Must not be null. * @return This builder instance for method chaining * @throws IllegalArgumentException if consumers is null + * @see #rootsChangeHandlers(List) */ public AsyncSpecification rootsChangeHandlers( @SuppressWarnings("unchecked") BiFunction, Mono>... handlers) { + Assert.notNull(handlers, "Handlers list must not be null"); return this.rootsChangeHandlers(Arrays.asList(handlers)); } + /** + * Sets the object mapper to use for serializing and deserializing JSON messages. + * @param objectMapper the instance to use. Must not be null. + * @return This builder instance for method chaining. + * @throws IllegalArgumentException if objectMapper is null + */ public AsyncSpecification objectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); this.objectMapper = objectMapper; return this; } @@ -533,7 +579,7 @@ public AsyncSpecification objectMapper(ObjectMapper objectMapper) { /** * Builds an asynchronous MCP server that provides non-blocking operations. * @return A new instance of {@link McpAsyncServer} configured with this builder's - * settings + * settings. */ public McpAsyncServer build() { var features = new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, @@ -636,8 +682,6 @@ public SyncSpecification serverInfo(String name, String version) { *
  • Tool execution *
  • Resource access *
  • Prompt handling - *
  • Streaming responses - *
  • Batch operations * * @param serverCapabilities The server capabilities configuration. Must not be * null. @@ -645,6 +689,7 @@ public SyncSpecification serverInfo(String name, String version) { * @throws IllegalArgumentException if serverCapabilities is null */ public SyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabilities) { + Assert.notNull(serverCapabilities, "Server capabilities must not be null"); this.serverCapabilities = serverCapabilities; return this; } @@ -658,12 +703,15 @@ public SyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabil * Example usage:
    {@code
     		 * .tool(
     		 *     new Tool("calculator", "Performs calculations", schema),
    -		 *     args -> new CallToolResult("Result: " + calculate(args))
    +		 *     (exchange, args) -> new CallToolResult("Result: " + calculate(args))
     		 * )
     		 * }
    * @param tool The tool definition including name, description, and schema. Must * not be null. * @param handler The function that implements the tool's logic. Must not be null. + * The function's first argument is an {@link McpSyncServerExchange} upon which + * the server can interact with the connected client. The second argument is the + * list of arguments passed to the tool. * @return This builder instance for method chaining * @throws IllegalArgumentException if tool or handler is null */ @@ -711,6 +759,7 @@ public SyncSpecification tools(List too * @see #tools(List) */ public SyncSpecification tools(McpServerFeatures.SyncToolSpecification... toolSpecifications) { + Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); for (McpServerFeatures.SyncToolSpecification tool : toolSpecifications) { this.tools.add(tool); } @@ -790,9 +839,11 @@ public SyncSpecification resources(McpServerFeatures.SyncResourceSpecification.. * @param resourceTemplates List of resource templates. If null, clears existing * templates. * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceTemplates is null. * @see #resourceTemplates(ResourceTemplate...) */ public SyncSpecification resourceTemplates(List resourceTemplates) { + Assert.notNull(resourceTemplates, "Resource templates must not be null"); this.resourceTemplates.addAll(resourceTemplates); return this; } @@ -802,9 +853,11 @@ public SyncSpecification resourceTemplates(List resourceTempla * alternative to {@link #resourceTemplates(List)}. * @param resourceTemplates The resource templates to set. * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceTemplates is null * @see #resourceTemplates(List) */ public SyncSpecification resourceTemplates(ResourceTemplate... resourceTemplates) { + Assert.notNull(resourceTemplates, "Resource templates must not be null"); for (ResourceTemplate resourceTemplate : resourceTemplates) { this.resourceTemplates.add(resourceTemplate); } @@ -821,7 +874,7 @@ public SyncSpecification resourceTemplates(ResourceTemplate... resourceTemplates * Map prompts = new HashMap<>(); * prompts.put("analysis", new PromptSpecification( * new Prompt("analysis", "Code analysis template"), - * request -> new GetPromptResult(generateAnalysisPrompt(request)) + * (exchange, request) -> new GetPromptResult(generateAnalysisPrompt(request)) * )); * .prompts(prompts) * } @@ -830,6 +883,7 @@ public SyncSpecification resourceTemplates(ResourceTemplate... resourceTemplates * @throws IllegalArgumentException if prompts is null */ public SyncSpecification prompts(Map prompts) { + Assert.notNull(prompts, "Prompts map must not be null"); this.prompts.putAll(prompts); return this; } @@ -843,6 +897,7 @@ public SyncSpecification prompts(Map prompts) { + Assert.notNull(prompts, "Prompts list must not be null"); for (McpServerFeatures.SyncPromptSpecification prompt : prompts) { this.prompts.put(prompt.prompt().name(), prompt); } @@ -866,6 +921,7 @@ public SyncSpecification prompts(List * @throws IllegalArgumentException if prompts is null */ public SyncSpecification prompts(McpServerFeatures.SyncPromptSpecification... prompts) { + Assert.notNull(prompts, "Prompts list must not be null"); for (McpServerFeatures.SyncPromptSpecification prompt : prompts) { this.prompts.put(prompt.prompt().name(), prompt); } @@ -876,7 +932,9 @@ public SyncSpecification prompts(McpServerFeatures.SyncPromptSpecification... pr * Registers a consumer that will be notified when the list of roots changes. This * is useful for updating resource availability dynamically, such as when new * files are added or removed. - * @param handler The handler to register. Must not be null. + * @param handler The handler to register. Must not be null. The function's first + * argument is an {@link McpSyncServerExchange} upon which the server can interact + * with the connected client. The second argument is the list of roots. * @return This builder instance for method chaining * @throws IllegalArgumentException if consumer is null */ @@ -893,6 +951,7 @@ public SyncSpecification rootsChangeHandler(BiConsumer>> handlers) { @@ -908,15 +967,22 @@ public SyncSpecification rootsChangeHandlers( * @param handlers The handlers to register. Must not be null. * @return This builder instance for method chaining * @throws IllegalArgumentException if consumers is null - * @deprecated This method will * be removed in 0.9.0. Use - * {@link #rootsChangeHandlers(BiConsumer[])}. + * @see #rootsChangeHandlers(List) */ public SyncSpecification rootsChangeHandlers( BiConsumer>... handlers) { + Assert.notNull(handlers, "Handlers list must not be null"); return this.rootsChangeHandlers(List.of(handlers)); } + /** + * Sets the object mapper to use for serializing and deserializing JSON messages. + * @param objectMapper the instance to use. Must not be null. + * @return This builder instance for method chaining. + * @throws IllegalArgumentException if objectMapper is null + */ public SyncSpecification objectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); this.objectMapper = objectMapper; return this; } @@ -924,7 +990,7 @@ public SyncSpecification objectMapper(ObjectMapper objectMapper) { /** * Builds a synchronous MCP server that provides blocking operations. * @return A new instance of {@link McpSyncServer} configured with this builder's - * settings + * settings. */ public McpSyncServer build() { McpServerFeatures.Sync syncFeatures = new McpServerFeatures.Sync(this.serverInfo, this.serverCapabilities, diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java index c8334bb4..5aeeadd7 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java @@ -202,14 +202,17 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se * ), * (exchange, args) -> { * String expr = (String) args.get("expression"); - * return Mono.just(new CallToolResult("Result: " + evaluate(expr))); + * return Mono.fromSupplier(() -> evaluate(expr)) + * .map(result -> new CallToolResult("Result: " + result)); * } * ) * } * * @param tool The tool definition including name, description, and parameter schema * @param call The function that implements the tool's logic, receiving arguments and - * returning results + * returning results. The function's first argument is an + * {@link McpAsyncServerExchange} upon which the server can interact with the + * connected client. The second arguments is a map of tool arguments. */ public record AsyncToolSpecification(McpSchema.Tool tool, BiFunction, Mono> call) { @@ -241,15 +244,17 @@ static AsyncToolSpecification fromSync(SyncToolSpecification tool) { * Example resource specification:
    {@code
     	 * new McpServerFeatures.AsyncResourceSpecification(
     	 *     new Resource("docs", "Documentation files", "text/markdown"),
    -	 *     (exchange, request) -> {
    -	 *         String content = readFile(request.getPath());
    -	 *         return Mono.just(new ReadResourceResult(content));
    -	 *     }
    +	 *     (exchange, request) ->
    +	 *         Mono.fromSupplier(() -> readFile(request.getPath()))
    +	 *             .map(ReadResourceResult::new)
     	 * )
     	 * }
    * * @param resource The resource definition including name, description, and MIME type - * @param readHandler The function that handles resource read requests + * @param readHandler The function that handles resource read requests. The function's + * first argument is an {@link McpAsyncServerExchange} upon which the server can + * interact with the connected client. The second arguments is a + * {@link io.modelcontextprotocol.spec.McpSchema.ReadResourceRequest}. */ public record AsyncResourceSpecification(McpSchema.Resource resource, BiFunction> readHandler) { @@ -292,7 +297,10 @@ static AsyncResourceSpecification fromSync(SyncResourceSpecification resource) { * * @param prompt The prompt definition including name and description * @param promptHandler The function that processes prompt requests and returns - * formatted templates + * formatted templates. The function's first argument is an + * {@link McpAsyncServerExchange} upon which the server can interact with the + * connected client. The second arguments is a + * {@link io.modelcontextprotocol.spec.McpSchema.GetPromptRequest}. */ public record AsyncPromptSpecification(McpSchema.Prompt prompt, BiFunction> promptHandler) { @@ -340,7 +348,9 @@ static AsyncPromptSpecification fromSync(SyncPromptSpecification prompt) { * * @param tool The tool definition including name, description, and parameter schema * @param call The function that implements the tool's logic, receiving arguments and - * returning results + * returning results. The function's first argument is an + * {@link McpSyncServerExchange} upon which the server can interact with the connected + * client. The second arguments is a map of arguments passed to the tool. */ public record SyncToolSpecification(McpSchema.Tool tool, BiFunction, McpSchema.CallToolResult> call) { @@ -369,7 +379,10 @@ public record SyncToolSpecification(McpSchema.Tool tool, * } * * @param resource The resource definition including name, description, and MIME type - * @param readHandler The function that handles resource read requests + * @param readHandler The function that handles resource read requests. The function's + * first argument is an {@link McpSyncServerExchange} upon which the server can + * interact with the connected client. The second arguments is a + * {@link io.modelcontextprotocol.spec.McpSchema.ReadResourceRequest}. */ public record SyncResourceSpecification(McpSchema.Resource resource, BiFunction readHandler) { @@ -401,7 +414,10 @@ public record SyncResourceSpecification(McpSchema.Resource resource, * * @param prompt The prompt definition including name and description * @param promptHandler The function that processes prompt requests and returns - * formatted templates + * formatted templates. The function's first argument is an + * {@link McpSyncServerExchange} upon which the server can interact with the connected + * client. The second arguments is a + * {@link io.modelcontextprotocol.spec.McpSchema.GetPromptRequest}. */ public record SyncPromptSpecification(McpSchema.Prompt prompt, BiFunction promptHandler) { diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java index bba5b059..60662d98 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java @@ -68,6 +68,8 @@ public McpSyncServer(McpAsyncServer asyncServer) { /** * Retrieves the list of all roots provided by the client. * @return The list of roots + * @deprecated This method will be removed in 0.9.0. Use + * {@link McpSyncServerExchange#listRoots()}. */ @Deprecated public McpSchema.ListRootsResult listRoots() { @@ -78,6 +80,8 @@ public McpSchema.ListRootsResult listRoots() { * Retrieves a paginated list of roots provided by the server. * @param cursor Optional pagination cursor from a previous list request * @return The list of roots + * @deprecated This method will be removed in 0.9.0. Use + * {@link McpSyncServerExchange#listRoots(String)}. */ @Deprecated public McpSchema.ListRootsResult listRoots(String cursor) { @@ -191,6 +195,8 @@ public McpSchema.Implementation getServerInfo() { /** * Get the client capabilities that define the supported features and functionality. * @return The client capabilities + * @deprecated This method will be removed in 0.9.0. Use + * {@link McpSyncServerExchange#getClientCapabilities()}. */ @Deprecated public ClientCapabilities getClientCapabilities() { @@ -200,6 +206,8 @@ public ClientCapabilities getClientCapabilities() { /** * Get the client implementation information. * @return The client implementation details + * @deprecated This method will be removed in 0.9.0. Use + * {@link McpSyncServerExchange#getClientInfo()}. */ @Deprecated public McpSchema.Implementation getClientInfo() { @@ -274,6 +282,8 @@ public McpAsyncServer getAsyncServer() { * @see
    Sampling * Specification + * @deprecated This method will be removed in 0.9.0. Use + * {@link McpSyncServerExchange#createMessage(McpSchema.CreateMessageRequest)}. */ @Deprecated public McpSchema.CreateMessageResult createMessage(McpSchema.CreateMessageRequest createMessageRequest) { diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java index 09d87111..f121db55 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java @@ -3,25 +3,74 @@ import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.spec.McpSchema; +/** + * Represents a synchronous exchange with a Model Context Protocol (MCP) client. The + * exchange provides methods to interact with the client and query its capabilities. + * + * @author Dariusz Jędrzejczyk + */ public class McpSyncServerExchange { private final McpAsyncServerExchange exchange; + /** + * Create a new synchronous exchange with the client using the provided asynchronous + * implementation as a delegate. + * @param exchange The asynchronous exchange to delegate to. + */ public McpSyncServerExchange(McpAsyncServerExchange exchange) { this.exchange = exchange; } + /** + * Get the client capabilities that define the supported features and functionality. + * @return The client capabilities + */ + public McpSchema.ClientCapabilities getClientCapabilities() { + return this.exchange.getClientCapabilities(); + } + + /** + * Get the client implementation information. + * @return The client implementation details + */ + public McpSchema.Implementation getClientInfo() { + return this.exchange.getClientInfo(); + } + + /** + * Create a new message using the sampling capabilities of the client. The Model + * Context Protocol (MCP) provides a standardized way for servers to request LLM + * sampling (“completions” or “generations”) from language models via clients. This + * flow allows clients to maintain control over model access, selection, and + * permissions while enabling servers to leverage AI capabilities—with no server API + * keys necessary. Servers can request text or image-based interactions and optionally + * include context from MCP servers in their prompts. + * @param createMessageRequest The request to create a new message + * @return A result containing the details of the sampling response + * @see McpSchema.CreateMessageRequest + * @see McpSchema.CreateMessageResult + * @see Sampling + * Specification + */ public McpSchema.CreateMessageResult createMessage(McpSchema.CreateMessageRequest createMessageRequest) { return this.exchange.createMessage(createMessageRequest).block(); } - private static final TypeReference LIST_ROOTS_RESULT_TYPE_REF = new TypeReference<>() { - }; - + /** + * Retrieves the list of all roots provided by the client. + * @return The list of roots result. + */ public McpSchema.ListRootsResult listRoots() { return this.exchange.listRoots().block(); } + /** + * Retrieves a paginated list of roots provided by the client. + * @param cursor Optional pagination cursor from a previous list request + * @return The list of roots result + */ public McpSchema.ListRootsResult listRoots(String cursor) { return this.exchange.listRoots(cursor).block(); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java index 63aa1dbf..45897965 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java @@ -7,6 +7,12 @@ import reactor.core.publisher.Mono; +/** + * Marker interface for the client-side MCP transport. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + */ public interface McpClientTransport extends ClientMcpTransport { @Override diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index 8304abd6..bcdf2248 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -15,6 +15,10 @@ import reactor.core.publisher.MonoSink; import reactor.core.publisher.Sinks; +/** + * Represents a Model Control Protocol (MCP) session on the server side. It manages + * bidirectional JSON-RPC communication with the client. + */ public class McpServerSession implements McpSession { private static final Logger logger = LoggerFactory.getLogger(McpServerSession.class); @@ -49,6 +53,18 @@ public class McpServerSession implements McpSession { private final AtomicInteger state = new AtomicInteger(STATE_UNINITIALIZED); + /** + * Creates a new server session with the given parameters and the transport to use. + * @param id session id + * @param transport the transport to use + * @param initHandler called when a + * {@link io.modelcontextprotocol.spec.McpSchema.InitializeRequest} is received by the + * server + * @param initNotificationHandler called when a + * {@link McpSchema.METHOD_NOTIFICATION_INITIALIZED} is received. + * @param requestHandlers map of request handlers to use + * @param notificationHandlers map of notification handlers to use + */ public McpServerSession(String id, McpServerTransport transport, InitRequestHandler initHandler, InitNotificationHandler initNotificationHandler, Map> requestHandlers, Map notificationHandlers) { @@ -60,10 +76,24 @@ public McpServerSession(String id, McpServerTransport transport, InitRequestHand this.notificationHandlers = notificationHandlers; } + /** + * Retrieve the session id. + * @return session id + */ public String getId() { return this.id; } + /** + * Called upon successful initialization sequence between the client and the server + * with the client capabilities and information. + * + * Initialization + * Spec + * @param clientCapabilities the capabilities the connected client provides + * @param clientInfo the information about the connected client + */ public void init(McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo) { this.clientCapabilities.lazySet(clientCapabilities); this.clientInfo.lazySet(clientInfo); @@ -73,6 +103,7 @@ private String generateRequestId() { return this.id + "-" + this.requestCounter.getAndIncrement(); } + @Override public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { String requestId = this.generateRequestId(); @@ -107,6 +138,16 @@ public Mono sendNotification(String method, Map params) { return this.transport.sendMessage(jsonrpcNotification); } + /** + * Called by the {@link McpServerTransportProvider} once the session is determined. + * The purpose of this method is to dispatch the message to an appropriate handler as + * specified by the MCP server implementation + * ({@link io.modelcontextprotocol.server.McpAsyncServer} or + * {@link io.modelcontextprotocol.server.McpSyncServer}) via + * {@link McpServerSession.Factory} that the server creates. + * @param message the incoming JSON-RPC message + * @return a Mono that completes when the message is processed + */ public Mono handle(McpSchema.JSONRPCMessage message) { return Mono.defer(() -> { // TODO handle errors for communication to without initialization happening @@ -232,33 +273,80 @@ public void close() { this.transport.close(); } + /** + * Request handler for the initialization request. + */ public interface InitRequestHandler { + /** + * Handles the initialization request. + * @param initializeRequest the initialization request by the client + * @return a Mono that will emit the result of the initialization + */ Mono handle(McpSchema.InitializeRequest initializeRequest); } + /** + * Notification handler for the initialization notification from the client. + */ public interface InitNotificationHandler { + /** + * Specifies an action to take upon successful initialization. + * @return a Mono that will complete when the initialization is acted upon. + */ Mono handle(); } + /** + * A handler for client-initiated notifications. + */ public interface NotificationHandler { + /** + * Handles a notification from the client. + * @param exchange the exchange associated with the client that allows calling + * back to the connected client or inspecting its capabilities. + * @param params the parameters of the notification. + * @return a Mono that completes once the notification is handled. + */ Mono handle(McpAsyncServerExchange exchange, Object params); } + /** + * A handler for client-initiated requests. + * + * @param the type of the response that is expected as a result of handling the + * request. + */ public interface RequestHandler { + /** + * Handles a request from the client. + * @param exchange the exchange associated with the client that allows calling + * back to the connected client or inspecting its capabilities. + * @param params the parameters of the request. + * @return a Mono that will emit the response to the request. + */ Mono handle(McpAsyncServerExchange exchange, Object params); } + /** + * Factory for creating server sessions which delegate to a provided 1:1 transport + * with a connected client. + */ @FunctionalInterface public interface Factory { + /** + * Creates a new 1:1 representation of the client-server interaction. + * @param sessionTransport the transport to use for communication with the client. + * @return a new server session. + */ McpServerSession create(McpServerTransport sessionTransport); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransport.java index ef5f5c6f..632b8cee 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransport.java @@ -1,5 +1,11 @@ package io.modelcontextprotocol.spec; +/** + * Marker interface for the server-side MCP transport. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + */ public interface McpServerTransport extends McpTransport { } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java index 41b07fdb..dba8cc43 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java @@ -4,29 +4,62 @@ import reactor.core.publisher.Mono; +/** + * The core building block providing the server-side MCP transport. Implement this + * interface to bridge between a particular server-side technology and the MCP server + * transport layer. + * + *

    + * The lifecycle of the provider dictates that it be created first, upon application + * startup, and then passed into either + * {@link io.modelcontextprotocol.server.McpServer#sync(McpServerTransportProvider)} or + * {@link io.modelcontextprotocol.server.McpServer#async(McpServerTransportProvider)}. As + * a result of the MCP server creation, the provider will be notified of a + * {@link McpServerSession.Factory} which will be used to handle a 1:1 communication + * between a newly connected client and the server. The provider's responsibility is to + * create instances of {@link McpServerTransport} that the session will utilise during the + * session lifetime. + * + *

    + * Finally, the {@link McpServerTransport}s can be closed in bulk when {@link #close()} or + * {@link #closeGracefully()} are called as part of the normal application shutdown event. + * Individual {@link McpServerTransport}s can also be closed on a per-session basis, where + * the {@link McpServerSession#close()} or {@link McpServerSession#closeGracefully()} + * closes the provided transport. + * + * @author Dariusz Jędrzejczyk + */ public interface McpServerTransportProvider { - // TODO: Consider adding a ProviderFactory that gets the Session Factory + /** + * Sets the session factory that will be used to create sessions for new clients. An + * implementation of the MCP server MUST call this method before any MCP interactions + * take place. + * @param sessionFactory the session factory to be used for initiating client sessions + */ void setSessionFactory(McpServerSession.Factory sessionFactory); + /** + * Sends a notification to all connected clients. + * @param method the name of the notification method to be called on the clients + * @param params a map of parameters to be sent with the notification + * @return a Mono that completes when the notification has been broadcast + * @see McpSession#sendNotification(String, Map) + */ Mono notifyClients(String method, Map params); /** - * Closes the transport connection and releases any associated resources. - * - *

    - * This method ensures proper cleanup of resources when the transport is no longer - * needed. It should handle the graceful shutdown of any active connections. - *

    + * Immediately closes all the transports with connected clients and releases any + * associated resources. */ default void close() { this.closeGracefully().subscribe(); } /** - * Closes the transport connection and releases any associated resources - * asynchronously. - * @return a {@link Mono} that completes when the connection has been closed. + * Gracefully closes all the transports with connected clients and releases any + * associated resources asynchronously. + * @return a {@link Mono} that completes when the connections have been closed. */ Mono closeGracefully(); diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java index 92b46075..b97c3ccc 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java @@ -26,14 +26,15 @@ public interface McpSession { /** - * Sends a request to the model server and expects a response of type T. + * Sends a request to the model counterparty and expects a response of type T. * *

    * This method handles the request-response pattern where a response is expected from - * the server. The response type is determined by the provided TypeReference. + * the client or server. The response type is determined by the provided + * TypeReference. *

    * @param the type of the expected response - * @param method the name of the method to be called on the server + * @param method the name of the method to be called on the counterparty * @param requestParams the parameters to be sent with the request * @param typeRef the TypeReference describing the expected response type * @return a Mono that will emit the response when received @@ -41,11 +42,11 @@ public interface McpSession { Mono sendRequest(String method, Object requestParams, TypeReference typeRef); /** - * Sends a notification to the model server without parameters. + * Sends a notification to the model client or server without parameters. * *

    * This method implements the notification pattern where no response is expected from - * the server. It's useful for fire-and-forget scenarios. + * the counterparty. It's useful for fire-and-forget scenarios. *

    * @param method the name of the notification method to be called on the server * @return a Mono that completes when the notification has been sent @@ -55,13 +56,13 @@ default Mono sendNotification(String method) { } /** - * Sends a notification to the model server with parameters. + * Sends a notification to the model client or server with parameters. * *

    * Similar to {@link #sendNotification(String)} but allows sending additional * parameters with the notification. *

    - * @param method the name of the notification method to be called on the server + * @param method the name of the notification method to be sent to the counterparty * @param params a map of parameters to be sent with the notification * @return a Mono that completes when the notification has been sent */ From 3daab08fa8270fb0a1b0a02c18e64ac0063b1d14 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Thu, 20 Mar 2025 12:01:50 +0100 Subject: [PATCH 18/20] refactor: replace ClientMcpTransport to McpClientTransport and DefaultMcpSession to McpClientSession - Replace the ClientMcpTransport interface with McpClientTransport and deprecate former - Replace the DefaultMcpSession with McpClientSession and deprecate DefaultMcpSession - Change all client transport implementation to implement the McpClientTransport instead of ClientMcpTransport - The McpMockTransport implements McpClientTransport instead of ClientMcpTransport Signed-off-by: Christian Tzolov --- .../transport/WebFluxSseClientTransport.java | 4 +- .../client/WebFluxSseMcpAsyncClientTests.java | 4 +- .../client/WebFluxSseMcpSyncClientTests.java | 4 +- .../MockMcpTransport.java | 8 +- .../client/AbstractMcpAsyncClientTests.java | 12 +- .../client/AbstractMcpSyncClientTests.java | 12 +- .../client/McpAsyncClient.java | 12 +- .../client/McpClient.java | 41 +++ .../client/McpSyncClient.java | 5 +- .../HttpClientSseClientTransport.java | 6 +- .../transport/StdioClientTransport.java | 5 +- .../server/McpAsyncServer.java | 32 +- .../transport/StdioServerTransport.java | 2 +- .../spec/DefaultMcpSession.java | 4 +- .../spec/McpClientSession.java | 288 ++++++++++++++++++ .../MockMcpTransport.java | 6 +- .../client/AbstractMcpAsyncClientTests.java | 12 +- .../client/AbstractMcpSyncClientTests.java | 12 +- .../client/HttpSseMcpAsyncClientTests.java | 6 +- .../client/HttpSseMcpSyncClientTests.java | 6 +- .../client/StdioMcpAsyncClientTests.java | 4 +- .../client/StdioMcpSyncClientTests.java | 6 +- ...nTests.java => McpClientSessionTests.java} | 20 +- 23 files changed, 419 insertions(+), 92 deletions(-) create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java rename mcp/src/test/java/io/modelcontextprotocol/spec/{DefaultMcpSessionTests.java => McpClientSessionTests.java} (90%) diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java index 8ea65fd7..b0dfa89c 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java @@ -9,7 +9,7 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; @@ -58,7 +58,7 @@ * "https://spec.modelcontextprotocol.io/specification/basic/transports/#http-with-sse">MCP * HTTP with SSE Transport Specification */ -public class WebFluxSseClientTransport implements ClientMcpTransport { +public class WebFluxSseClientTransport implements McpClientTransport { private static final Logger logger = LoggerFactory.getLogger(WebFluxSseClientTransport.class); diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java index 0dccb27a..2dd587d4 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java @@ -7,7 +7,7 @@ import java.time.Duration; import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import org.junit.jupiter.api.Timeout; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; @@ -32,7 +32,7 @@ class WebFluxSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests { .waitingFor(Wait.forHttp("/").forStatusCode(404)); @Override - protected ClientMcpTransport createMcpTransport() { + protected McpClientTransport createMcpTransport() { return new WebFluxSseClientTransport(WebClient.builder().baseUrl(host)); } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java index f5cab7b7..72b390dd 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java @@ -7,7 +7,7 @@ import java.time.Duration; import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import org.junit.jupiter.api.Timeout; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; @@ -32,7 +32,7 @@ class WebFluxSseMcpSyncClientTests extends AbstractMcpSyncClientTests { .waitingFor(Wait.forHttp("/").forStatusCode(404)); @Override - protected ClientMcpTransport createMcpTransport() { + protected McpClientTransport createMcpTransport() { return new WebFluxSseClientTransport(WebClient.builder().baseUrl(host)); } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java b/mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java index d4e48ea7..cef3fb9f 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java @@ -11,19 +11,19 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.ServerMcpTransport; import io.modelcontextprotocol.spec.McpSchema.JSONRPCNotification; import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; +import io.modelcontextprotocol.spec.ServerMcpTransport; import reactor.core.publisher.Mono; import reactor.core.publisher.Sinks; /** - * A mock implementation of the {@link ClientMcpTransport} and {@link ServerMcpTransport} + * A mock implementation of the {@link McpClientTransport} and {@link ServerMcpTransport} * interfaces. */ -public class MockMcpTransport implements ClientMcpTransport, ServerMcpTransport { +public class MockMcpTransport implements McpClientTransport, ServerMcpTransport { private final Sinks.Many inbound = Sinks.many().unicast().onBackpressureBuffer(); diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index 02aa23d8..71356351 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -12,7 +12,7 @@ import java.util.function.Consumer; import java.util.function.Function; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; @@ -49,7 +49,7 @@ public abstract class AbstractMcpAsyncClientTests { private static final String ECHO_TEST_MESSAGE = "Hello MCP Spring AI!"; - abstract protected ClientMcpTransport createMcpTransport(); + abstract protected McpClientTransport createMcpTransport(); protected void onStart() { } @@ -65,11 +65,11 @@ protected Duration getInitializationTimeout() { return Duration.ofSeconds(2); } - McpAsyncClient client(ClientMcpTransport transport) { + McpAsyncClient client(McpClientTransport transport) { return client(transport, Function.identity()); } - McpAsyncClient client(ClientMcpTransport transport, Function customizer) { + McpAsyncClient client(McpClientTransport transport, Function customizer) { AtomicReference client = new AtomicReference<>(); assertThatCode(() -> { @@ -84,11 +84,11 @@ McpAsyncClient client(ClientMcpTransport transport, Function c) { + void withClient(McpClientTransport transport, Consumer c) { withClient(transport, Function.identity(), c); } - void withClient(ClientMcpTransport transport, Function customizer, + void withClient(McpClientTransport transport, Function customizer, Consumer c) { var client = client(transport, customizer); try { diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java index 191de23b..128441f8 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -11,7 +11,7 @@ import java.util.function.Consumer; import java.util.function.Function; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; @@ -49,7 +49,7 @@ public abstract class AbstractMcpSyncClientTests { private static final String TEST_MESSAGE = "Hello MCP Spring AI!"; - abstract protected ClientMcpTransport createMcpTransport(); + abstract protected McpClientTransport createMcpTransport(); protected void onStart() { } @@ -65,11 +65,11 @@ protected Duration getInitializationTimeout() { return Duration.ofSeconds(2); } - McpSyncClient client(ClientMcpTransport transport) { + McpSyncClient client(McpClientTransport transport) { return client(transport, Function.identity()); } - McpSyncClient client(ClientMcpTransport transport, Function customizer) { + McpSyncClient client(McpClientTransport transport, Function customizer) { AtomicReference client = new AtomicReference<>(); assertThatCode(() -> { @@ -84,11 +84,11 @@ McpSyncClient client(ClientMcpTransport transport, Function c) { + void withClient(McpClientTransport transport, Consumer c) { withClient(transport, Function.identity(), c); } - void withClient(ClientMcpTransport transport, Function customizer, + void withClient(McpClientTransport transport, Function customizer, Consumer c) { var client = client(transport, customizer); try { diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index 278e360d..9cbef050 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -15,9 +15,9 @@ import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.spec.ClientMcpTransport; -import io.modelcontextprotocol.spec.DefaultMcpSession; -import io.modelcontextprotocol.spec.DefaultMcpSession.NotificationHandler; -import io.modelcontextprotocol.spec.DefaultMcpSession.RequestHandler; +import io.modelcontextprotocol.spec.McpClientSession; +import io.modelcontextprotocol.spec.McpClientSession.NotificationHandler; +import io.modelcontextprotocol.spec.McpClientSession.RequestHandler; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; @@ -73,7 +73,7 @@ * @author Christian Tzolov * @see McpClient * @see McpSchema - * @see DefaultMcpSession + * @see McpClientSession */ public class McpAsyncClient { @@ -95,7 +95,7 @@ public class McpAsyncClient { * The MCP session implementation that manages bidirectional JSON-RPC communication * between clients and servers. */ - private final DefaultMcpSession mcpSession; + private final McpClientSession mcpSession; /** * Client capabilities. @@ -228,7 +228,7 @@ public class McpAsyncClient { notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_MESSAGE, asyncLoggingNotificationHandler(loggingConsumersFinal)); - this.mcpSession = new DefaultMcpSession(requestTimeout, transport, requestHandlers, notificationHandlers); + this.mcpSession = new McpClientSession(requestTimeout, transport, requestHandlers, notificationHandlers); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java index fa2690dc..9c5f7b01 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java @@ -13,6 +13,7 @@ import java.util.function.Function; import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpTransport; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; @@ -113,11 +114,31 @@ public interface McpClient { * and {@code SseClientTransport} for SSE-based communication. * @return A new builder instance for configuring the client * @throws IllegalArgumentException if transport is null + * @deprecated This method will be removed in 0.9.0. Use + * {@link #sync(McpClientTransport)} */ + @Deprecated static SyncSpec sync(ClientMcpTransport transport) { return new SyncSpec(transport); } + /** + * Start building a synchronous MCP client with the specified transport layer. The + * synchronous MCP client provides blocking operations. Synchronous clients wait for + * each operation to complete before returning, making them simpler to use but + * potentially less performant for concurrent operations. The transport layer handles + * the low-level communication between client and server using protocols like stdio or + * Server-Sent Events (SSE). + * @param transport The transport layer implementation for MCP communication. Common + * implementations include {@code StdioClientTransport} for stdio-based communication + * and {@code SseClientTransport} for SSE-based communication. + * @return A new builder instance for configuring the client + * @throws IllegalArgumentException if transport is null + */ + static SyncSpec sync(McpClientTransport transport) { + return new SyncSpec(transport); + } + /** * Start building an asynchronous MCP client with the specified transport layer. The * asynchronous MCP client provides non-blocking operations. Asynchronous clients @@ -130,11 +151,31 @@ static SyncSpec sync(ClientMcpTransport transport) { * and {@code SseClientTransport} for SSE-based communication. * @return A new builder instance for configuring the client * @throws IllegalArgumentException if transport is null + * @deprecated This method will be removed in 0.9.0. Use + * {@link #async(McpClientTransport)} */ + @Deprecated static AsyncSpec async(ClientMcpTransport transport) { return new AsyncSpec(transport); } + /** + * Start building an asynchronous MCP client with the specified transport layer. The + * asynchronous MCP client provides non-blocking operations. Asynchronous clients + * return reactive primitives (Mono/Flux) immediately, allowing for concurrent + * operations and reactive programming patterns. The transport layer handles the + * low-level communication between client and server using protocols like stdio or + * Server-Sent Events (SSE). + * @param transport The transport layer implementation for MCP communication. Common + * implementations include {@code StdioClientTransport} for stdio-based communication + * and {@code SseClientTransport} for SSE-based communication. + * @return A new builder instance for configuring the client + * @throws IllegalArgumentException if transport is null + */ + static AsyncSpec async(McpClientTransport transport) { + return new AsyncSpec(transport); + } + /** * Synchronous client specification. This class follows the builder pattern to provide * a fluent API for setting up clients with custom configurations. diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java index e5d964b7..ec0a0dfd 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java @@ -6,7 +6,7 @@ import java.time.Duration; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest; @@ -66,7 +66,8 @@ public class McpSyncClient implements AutoCloseable { * Create a new McpSyncClient with the given delegate. * @param delegate the asynchronous kernel on top of which this synchronous client * provides a blocking API. - * @deprecated Use {@link McpClient#sync(ClientMcpTransport)} to obtain an instance. + * @deprecated This method will be removed in 0.9.0. Use + * {@link McpClient#sync(McpClientTransport)} to obtain an instance. */ @Deprecated // TODO make the constructor package private post-deprecation diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java index 1e7df31a..ca1b0e87 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -18,7 +18,7 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.client.transport.FlowSseClient.SseEvent; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; @@ -52,9 +52,9 @@ * * @author Christian Tzolov * @see io.modelcontextprotocol.spec.McpTransport - * @see io.modelcontextprotocol.spec.ClientMcpTransport + * @see io.modelcontextprotocol.spec.McpClientTransport */ -public class HttpClientSseClientTransport implements ClientMcpTransport { +public class HttpClientSseClientTransport implements McpClientTransport { private static final Logger logger = LoggerFactory.getLogger(HttpClientSseClientTransport.class); diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java index d35db3f8..f9a97849 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java @@ -11,14 +11,13 @@ import java.time.Duration; import java.util.ArrayList; import java.util.List; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executors; import java.util.function.Consumer; import java.util.function.Function; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; import io.modelcontextprotocol.util.Assert; @@ -38,7 +37,7 @@ * @author Christian Tzolov * @author Dariusz Jędrzejczyk */ -public class StdioClientTransport implements ClientMcpTransport { +public class StdioClientTransport implements McpClientTransport { private static final Logger logger = LoggerFactory.getLogger(StdioClientTransport.class); diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index 44536816..07a9f154 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -17,7 +17,7 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.DefaultMcpSession; +import io.modelcontextprotocol.spec.McpClientSession; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpServerTransportProvider; @@ -73,7 +73,7 @@ * @author Dariusz Jędrzejczyk * @see McpServer * @see McpSchema - * @see DefaultMcpSession + * @see McpClientSession */ public class McpAsyncServer { @@ -876,7 +876,7 @@ private static final class LegacyAsyncServer extends McpAsyncServer { * The MCP session implementation that manages bidirectional JSON-RPC * communication between clients and servers. */ - private final DefaultMcpSession mcpSession; + private final McpClientSession mcpSession; private final ServerMcpTransport transport; @@ -920,7 +920,7 @@ private static final class LegacyAsyncServer extends McpAsyncServer { this.resourceTemplates.addAll(features.resourceTemplates()); this.prompts.putAll(features.prompts()); - Map> requestHandlers = new HashMap<>(); + Map> requestHandlers = new HashMap<>(); // Initialize request handlers for standard MCP methods requestHandlers.put(McpSchema.METHOD_INITIALIZE, asyncInitializeRequestHandler()); @@ -952,7 +952,7 @@ private static final class LegacyAsyncServer extends McpAsyncServer { requestHandlers.put(McpSchema.METHOD_LOGGING_SET_LEVEL, setLoggerRequestHandler()); } - Map notificationHandlers = new HashMap<>(); + Map notificationHandlers = new HashMap<>(); notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (params) -> Mono.empty()); @@ -972,7 +972,7 @@ private static final class LegacyAsyncServer extends McpAsyncServer { asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); this.transport = mcpTransport; - this.mcpSession = new DefaultMcpSession(Duration.ofSeconds(10), mcpTransport, requestHandlers, + this.mcpSession = new McpClientSession(Duration.ofSeconds(10), mcpTransport, requestHandlers, notificationHandlers); } @@ -997,7 +997,7 @@ public Mono addPrompt(McpServerFeatures.AsyncPromptSpecification promptSpe // --------------------------------------- // Lifecycle Management // --------------------------------------- - private DefaultMcpSession.RequestHandler asyncInitializeRequestHandler() { + private McpClientSession.RequestHandler asyncInitializeRequestHandler() { return params -> { McpSchema.InitializeRequest initializeRequest = transport.unmarshalFrom(params, new TypeReference() { @@ -1100,7 +1100,7 @@ public Mono listRoots(String cursor) { LIST_ROOTS_RESULT_TYPE_REF); } - private DefaultMcpSession.NotificationHandler asyncRootsListChangedNotificationHandler( + private McpClientSession.NotificationHandler asyncRootsListChangedNotificationHandler( List, Mono>> rootsChangeConsumers) { return params -> listRoots().flatMap(listRootsResult -> Flux.fromIterable(rootsChangeConsumers) .flatMap(consumer -> consumer.apply(listRootsResult.roots())) @@ -1187,7 +1187,7 @@ public Mono notifyToolsListChanged() { return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_TOOLS_LIST_CHANGED, null); } - private DefaultMcpSession.RequestHandler toolsListRequestHandler() { + private McpClientSession.RequestHandler toolsListRequestHandler() { return params -> { List tools = this.tools.stream().map(McpServerFeatures.AsyncToolSpecification::tool).toList(); @@ -1195,7 +1195,7 @@ private DefaultMcpSession.RequestHandler toolsListReq }; } - private DefaultMcpSession.RequestHandler toolsCallRequestHandler() { + private McpClientSession.RequestHandler toolsCallRequestHandler() { return params -> { McpSchema.CallToolRequest callToolRequest = transport.unmarshalFrom(params, new TypeReference() { @@ -1281,7 +1281,7 @@ public Mono notifyResourcesListChanged() { return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED, null); } - private DefaultMcpSession.RequestHandler resourcesListRequestHandler() { + private McpClientSession.RequestHandler resourcesListRequestHandler() { return params -> { var resourceList = this.resources.values() .stream() @@ -1291,12 +1291,12 @@ private DefaultMcpSession.RequestHandler resource }; } - private DefaultMcpSession.RequestHandler resourceTemplateListRequestHandler() { + private McpClientSession.RequestHandler resourceTemplateListRequestHandler() { return params -> Mono.just(new McpSchema.ListResourceTemplatesResult(this.resourceTemplates, null)); } - private DefaultMcpSession.RequestHandler resourcesReadRequestHandler() { + private McpClientSession.RequestHandler resourcesReadRequestHandler() { return params -> { McpSchema.ReadResourceRequest resourceRequest = transport.unmarshalFrom(params, new TypeReference() { @@ -1385,7 +1385,7 @@ public Mono notifyPromptsListChanged() { return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED, null); } - private DefaultMcpSession.RequestHandler promptsListRequestHandler() { + private McpClientSession.RequestHandler promptsListRequestHandler() { return params -> { // TODO: Implement pagination // McpSchema.PaginatedRequest request = transport.unmarshalFrom(params, @@ -1401,7 +1401,7 @@ private DefaultMcpSession.RequestHandler promptsLis }; } - private DefaultMcpSession.RequestHandler promptsGetRequestHandler() { + private McpClientSession.RequestHandler promptsGetRequestHandler() { return params -> { McpSchema.GetPromptRequest promptRequest = transport.unmarshalFrom(params, new TypeReference() { @@ -1449,7 +1449,7 @@ public Mono loggingNotification(LoggingMessageNotification loggingMessageN * will not be sent. * @return A handler that processes logging level change requests */ - private DefaultMcpSession.RequestHandler setLoggerRequestHandler() { + private McpClientSession.RequestHandler setLoggerRequestHandler() { return params -> { this.minLoggingLevel = transport.unmarshalFrom(params, new TypeReference() { }); diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransport.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransport.java index 14129c52..78264ca3 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransport.java @@ -33,7 +33,7 @@ * over stdin/stdout, with errors and debug information sent to stderr. * * @author Christian Tzolov - * @deprecated Use + * @deprecated This method will be removed in 0.9.0. Use * {@link io.modelcontextprotocol.server.transport.StdioServerTransportProvider} instead. */ @Deprecated diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpSession.java index add33d7a..83de4c09 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpSession.java @@ -34,8 +34,10 @@ * * @author Christian Tzolov * @author Dariusz Jędrzejczyk + * @deprecated This method will be removed in 0.9.0. Use {@link McpClientSession} instead */ -// TODO: DefaultMcpSession is only relevant to the client-side. +@Deprecated + public class DefaultMcpSession implements McpSession { /** Logger for this class */ diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java new file mode 100644 index 00000000..6657e362 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -0,0 +1,288 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import java.time.Duration; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; + +import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.util.Assert; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.Disposable; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoSink; + +/** + * Default implementation of the MCP (Model Context Protocol) session that manages + * bidirectional JSON-RPC communication between clients and servers. This implementation + * follows the MCP specification for message exchange and transport handling. + * + *

    + * The session manages: + *

      + *
    • Request/response handling with unique message IDs
    • + *
    • Notification processing
    • + *
    • Message timeout management
    • + *
    • Transport layer abstraction
    • + *
    + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + */ +public class McpClientSession implements McpSession { + + /** Logger for this class */ + private static final Logger logger = LoggerFactory.getLogger(McpClientSession.class); + + /** Duration to wait for request responses before timing out */ + private final Duration requestTimeout; + + /** Transport layer implementation for message exchange */ + private final McpTransport transport; + + /** Map of pending responses keyed by request ID */ + private final ConcurrentHashMap> pendingResponses = new ConcurrentHashMap<>(); + + /** Map of request handlers keyed by method name */ + private final ConcurrentHashMap> requestHandlers = new ConcurrentHashMap<>(); + + /** Map of notification handlers keyed by method name */ + private final ConcurrentHashMap notificationHandlers = new ConcurrentHashMap<>(); + + /** Session-specific prefix for request IDs */ + private final String sessionPrefix = UUID.randomUUID().toString().substring(0, 8); + + /** Atomic counter for generating unique request IDs */ + private final AtomicLong requestCounter = new AtomicLong(0); + + private final Disposable connection; + + /** + * Functional interface for handling incoming JSON-RPC requests. Implementations + * should process the request parameters and return a response. + * + * @param Response type + */ + @FunctionalInterface + public interface RequestHandler { + + /** + * Handles an incoming request with the given parameters. + * @param params The request parameters + * @return A Mono containing the response object + */ + Mono handle(Object params); + + } + + /** + * Functional interface for handling incoming JSON-RPC notifications. Implementations + * should process the notification parameters without returning a response. + */ + @FunctionalInterface + public interface NotificationHandler { + + /** + * Handles an incoming notification with the given parameters. + * @param params The notification parameters + * @return A Mono that completes when the notification is processed + */ + Mono handle(Object params); + + } + + /** + * Creates a new McpClientSession with the specified configuration and handlers. + * @param requestTimeout Duration to wait for responses + * @param transport Transport implementation for message exchange + * @param requestHandlers Map of method names to request handlers + * @param notificationHandlers Map of method names to notification handlers + */ + public McpClientSession(Duration requestTimeout, McpTransport transport, + Map> requestHandlers, Map notificationHandlers) { + + Assert.notNull(requestTimeout, "The requstTimeout can not be null"); + Assert.notNull(transport, "The transport can not be null"); + Assert.notNull(requestHandlers, "The requestHandlers can not be null"); + Assert.notNull(notificationHandlers, "The notificationHandlers can not be null"); + + this.requestTimeout = requestTimeout; + this.transport = transport; + this.requestHandlers.putAll(requestHandlers); + this.notificationHandlers.putAll(notificationHandlers); + + // TODO: consider mono.transformDeferredContextual where the Context contains + // the + // Observation associated with the individual message - it can be used to + // create child Observation and emit it together with the message to the + // consumer + this.connection = this.transport.connect(mono -> mono.doOnNext(message -> { + if (message instanceof McpSchema.JSONRPCResponse response) { + logger.debug("Received Response: {}", response); + var sink = pendingResponses.remove(response.id()); + if (sink == null) { + logger.warn("Unexpected response for unkown id {}", response.id()); + } + else { + sink.success(response); + } + } + else if (message instanceof McpSchema.JSONRPCRequest request) { + logger.debug("Received request: {}", request); + handleIncomingRequest(request).subscribe(response -> transport.sendMessage(response).subscribe(), + error -> { + var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), + null, new McpSchema.JSONRPCResponse.JSONRPCError( + McpSchema.ErrorCodes.INTERNAL_ERROR, error.getMessage(), null)); + transport.sendMessage(errorResponse).subscribe(); + }); + } + else if (message instanceof McpSchema.JSONRPCNotification notification) { + logger.debug("Received notification: {}", notification); + handleIncomingNotification(notification).subscribe(null, + error -> logger.error("Error handling notification: {}", error.getMessage())); + } + })).subscribe(); + } + + /** + * Handles an incoming JSON-RPC request by routing it to the appropriate handler. + * @param request The incoming JSON-RPC request + * @return A Mono containing the JSON-RPC response + */ + private Mono handleIncomingRequest(McpSchema.JSONRPCRequest request) { + return Mono.defer(() -> { + var handler = this.requestHandlers.get(request.method()); + if (handler == null) { + MethodNotFoundError error = getMethodNotFoundError(request.method()); + return Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, + new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.METHOD_NOT_FOUND, + error.message(), error.data()))); + } + + return handler.handle(request.params()) + .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null)) + .onErrorResume(error -> Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), + null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + error.getMessage(), null)))); // TODO: add error message + // through the data field + }); + } + + record MethodNotFoundError(String method, String message, Object data) { + } + + public static MethodNotFoundError getMethodNotFoundError(String method) { + switch (method) { + case McpSchema.METHOD_ROOTS_LIST: + return new MethodNotFoundError(method, "Roots not supported", + Map.of("reason", "Client does not have roots capability")); + default: + return new MethodNotFoundError(method, "Method not found: " + method, null); + } + } + + /** + * Handles an incoming JSON-RPC notification by routing it to the appropriate handler. + * @param notification The incoming JSON-RPC notification + * @return A Mono that completes when the notification is processed + */ + private Mono handleIncomingNotification(McpSchema.JSONRPCNotification notification) { + return Mono.defer(() -> { + var handler = notificationHandlers.get(notification.method()); + if (handler == null) { + logger.error("No handler registered for notification method: {}", notification.method()); + return Mono.empty(); + } + return handler.handle(notification.params()); + }); + } + + /** + * Generates a unique request ID in a non-blocking way. Combines a session-specific + * prefix with an atomic counter to ensure uniqueness. + * @return A unique request ID string + */ + private String generateRequestId() { + return this.sessionPrefix + "-" + this.requestCounter.getAndIncrement(); + } + + /** + * Sends a JSON-RPC request and returns the response. + * @param The expected response type + * @param method The method name to call + * @param requestParams The request parameters + * @param typeRef Type reference for response deserialization + * @return A Mono containing the response + */ + @Override + public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { + String requestId = this.generateRequestId(); + + return Mono.create(sink -> { + this.pendingResponses.put(requestId, sink); + McpSchema.JSONRPCRequest jsonrpcRequest = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, method, + requestId, requestParams); + this.transport.sendMessage(jsonrpcRequest) + // TODO: It's most efficient to create a dedicated Subscriber here + .subscribe(v -> { + }, error -> { + this.pendingResponses.remove(requestId); + sink.error(error); + }); + }).timeout(this.requestTimeout).handle((jsonRpcResponse, sink) -> { + if (jsonRpcResponse.error() != null) { + sink.error(new McpError(jsonRpcResponse.error())); + } + else { + if (typeRef.getType().equals(Void.class)) { + sink.complete(); + } + else { + sink.next(this.transport.unmarshalFrom(jsonRpcResponse.result(), typeRef)); + } + } + }); + } + + /** + * Sends a JSON-RPC notification. + * @param method The method name for the notification + * @param params The notification parameters + * @return A Mono that completes when the notification is sent + */ + @Override + public Mono sendNotification(String method, Map params) { + McpSchema.JSONRPCNotification jsonrpcNotification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, + method, params); + return this.transport.sendMessage(jsonrpcNotification); + } + + /** + * Closes the session gracefully, allowing pending operations to complete. + * @return A Mono that completes when the session is closed + */ + @Override + public Mono closeGracefully() { + return Mono.defer(() -> { + this.connection.dispose(); + return transport.closeGracefully(); + }); + } + + /** + * Closes the session immediately, potentially interrupting pending operations. + */ + @Override + public void close() { + this.connection.dispose(); + transport.close(); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/MockMcpTransport.java b/mcp/src/test/java/io/modelcontextprotocol/MockMcpTransport.java index d4e48ea7..12f30d12 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/MockMcpTransport.java +++ b/mcp/src/test/java/io/modelcontextprotocol/MockMcpTransport.java @@ -11,7 +11,7 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.ServerMcpTransport; import io.modelcontextprotocol.spec.McpSchema.JSONRPCNotification; @@ -20,10 +20,10 @@ import reactor.core.publisher.Sinks; /** - * A mock implementation of the {@link ClientMcpTransport} and {@link ServerMcpTransport} + * A mock implementation of the {@link McpClientTransport} and {@link ServerMcpTransport} * interfaces. */ -public class MockMcpTransport implements ClientMcpTransport, ServerMcpTransport { +public class MockMcpTransport implements McpClientTransport, ServerMcpTransport { private final Sinks.Many inbound = Sinks.many().unicast().onBackpressureBuffer(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index f7a0a492..ac7b9e5e 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -12,7 +12,7 @@ import java.util.function.Consumer; import java.util.function.Function; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; @@ -50,7 +50,7 @@ public abstract class AbstractMcpAsyncClientTests { private static final String ECHO_TEST_MESSAGE = "Hello MCP Spring AI!"; - abstract protected ClientMcpTransport createMcpTransport(); + abstract protected McpClientTransport createMcpTransport(); protected void onStart() { } @@ -66,11 +66,11 @@ protected Duration getInitializationTimeout() { return Duration.ofSeconds(2); } - McpAsyncClient client(ClientMcpTransport transport) { + McpAsyncClient client(McpClientTransport transport) { return client(transport, Function.identity()); } - McpAsyncClient client(ClientMcpTransport transport, Function customizer) { + McpAsyncClient client(McpClientTransport transport, Function customizer) { AtomicReference client = new AtomicReference<>(); assertThatCode(() -> { @@ -85,11 +85,11 @@ McpAsyncClient client(ClientMcpTransport transport, Function c) { + void withClient(McpClientTransport transport, Consumer c) { withClient(transport, Function.identity(), c); } - void withClient(ClientMcpTransport transport, Function customizer, + void withClient(McpClientTransport transport, Function customizer, Consumer c) { var client = client(transport, customizer); try { diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java index f4d8dbdb..24c161eb 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -11,7 +11,7 @@ import java.util.function.Consumer; import java.util.function.Function; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; @@ -50,7 +50,7 @@ public abstract class AbstractMcpSyncClientTests { private static final String TEST_MESSAGE = "Hello MCP Spring AI!"; - abstract protected ClientMcpTransport createMcpTransport(); + abstract protected McpClientTransport createMcpTransport(); protected void onStart() { } @@ -66,11 +66,11 @@ protected Duration getInitializationTimeout() { return Duration.ofSeconds(2); } - McpSyncClient client(ClientMcpTransport transport) { + McpSyncClient client(McpClientTransport transport) { return client(transport, Function.identity()); } - McpSyncClient client(ClientMcpTransport transport, Function customizer) { + McpSyncClient client(McpClientTransport transport, Function customizer) { AtomicReference client = new AtomicReference<>(); assertThatCode(() -> { @@ -85,11 +85,11 @@ McpSyncClient client(ClientMcpTransport transport, Function c) { + void withClient(McpClientTransport transport, Consumer c) { withClient(transport, Function.identity(), c); } - void withClient(ClientMcpTransport transport, Function customizer, + void withClient(McpClientTransport transport, Function customizer, Consumer c) { var client = client(transport, customizer); try { diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java index ac0fef24..15749d4f 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java @@ -4,10 +4,8 @@ package io.modelcontextprotocol.client; -import java.time.Duration; - import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import org.junit.jupiter.api.Timeout; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; @@ -30,7 +28,7 @@ class HttpSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests { .waitingFor(Wait.forHttp("/").forStatusCode(404)); @Override - protected ClientMcpTransport createMcpTransport() { + protected McpClientTransport createMcpTransport() { return new HttpClientSseClientTransport(host); } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java index 8772e620..067f9295 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java @@ -4,10 +4,8 @@ package io.modelcontextprotocol.client; -import java.time.Duration; - import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import org.junit.jupiter.api.Timeout; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; @@ -30,7 +28,7 @@ class HttpSseMcpSyncClientTests extends AbstractMcpSyncClientTests { .waitingFor(Wait.forHttp("/").forStatusCode(404)); @Override - protected ClientMcpTransport createMcpTransport() { + protected McpClientTransport createMcpTransport() { return new HttpClientSseClientTransport(host); } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java index c285e2c6..95230942 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java @@ -8,7 +8,7 @@ import io.modelcontextprotocol.client.transport.ServerParameters; import io.modelcontextprotocol.client.transport.StdioClientTransport; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import org.junit.jupiter.api.Timeout; /** @@ -21,7 +21,7 @@ class StdioMcpAsyncClientTests extends AbstractMcpAsyncClientTests { @Override - protected ClientMcpTransport createMcpTransport() { + protected McpClientTransport createMcpTransport() { ServerParameters stdioParams = ServerParameters.builder("npx") .args("-y", "@modelcontextprotocol/server-everything", "dir") .build(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java index ebf10b9a..925852b5 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java @@ -11,7 +11,7 @@ import io.modelcontextprotocol.client.transport.ServerParameters; import io.modelcontextprotocol.client.transport.StdioClientTransport; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; import reactor.core.publisher.Sinks; @@ -29,7 +29,7 @@ class StdioMcpSyncClientTests extends AbstractMcpSyncClientTests { @Override - protected ClientMcpTransport createMcpTransport() { + protected McpClientTransport createMcpTransport() { ServerParameters stdioParams = ServerParameters.builder("npx") .args("-y", "@modelcontextprotocol/server-everything", "dir") .build(); @@ -42,7 +42,7 @@ void customErrorHandlerShouldReceiveErrors() throws InterruptedException { CountDownLatch latch = new CountDownLatch(1); AtomicReference receivedError = new AtomicReference<>(); - ClientMcpTransport transport = createMcpTransport(); + McpClientTransport transport = createMcpTransport(); StepVerifier.create(transport.connect(msg -> msg)).verifyComplete(); ((StdioClientTransport) transport).setStdErrorHandler(error -> { diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/DefaultMcpSessionTests.java b/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java similarity index 90% rename from mcp/src/test/java/io/modelcontextprotocol/spec/DefaultMcpSessionTests.java rename to mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java index 9d011aff..79a1d0d9 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/spec/DefaultMcpSessionTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java @@ -22,14 +22,14 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; /** - * Test suite for {@link DefaultMcpSession} that verifies its JSON-RPC message handling, + * Test suite for {@link McpClientSession} that verifies its JSON-RPC message handling, * request-response correlation, and notification processing. * * @author Christian Tzolov */ -class DefaultMcpSessionTests { +class McpClientSessionTests { - private static final Logger logger = LoggerFactory.getLogger(DefaultMcpSessionTests.class); + private static final Logger logger = LoggerFactory.getLogger(McpClientSessionTests.class); private static final Duration TIMEOUT = Duration.ofSeconds(5); @@ -39,14 +39,14 @@ class DefaultMcpSessionTests { private static final String ECHO_METHOD = "echo"; - private DefaultMcpSession session; + private McpClientSession session; private MockMcpTransport transport; @BeforeEach void setUp() { transport = new MockMcpTransport(); - session = new DefaultMcpSession(TIMEOUT, transport, Map.of(), + session = new McpClientSession(TIMEOUT, transport, Map.of(), Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: " + params)))); } @@ -59,11 +59,11 @@ void tearDown() { @Test void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> new DefaultMcpSession(null, transport, Map.of(), Map.of())) + assertThatThrownBy(() -> new McpClientSession(null, transport, Map.of(), Map.of())) .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("requstTimeout can not be null"); - assertThatThrownBy(() -> new DefaultMcpSession(TIMEOUT, null, Map.of(), Map.of())) + assertThatThrownBy(() -> new McpClientSession(TIMEOUT, null, Map.of(), Map.of())) .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("transport can not be null"); } @@ -137,10 +137,10 @@ void testSendNotification() { @Test void testRequestHandling() { String echoMessage = "Hello MCP!"; - Map> requestHandlers = Map.of(ECHO_METHOD, + Map> requestHandlers = Map.of(ECHO_METHOD, params -> Mono.just(params)); transport = new MockMcpTransport(); - session = new DefaultMcpSession(TIMEOUT, transport, requestHandlers, Map.of()); + session = new McpClientSession(TIMEOUT, transport, requestHandlers, Map.of()); // Simulate incoming request McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, ECHO_METHOD, @@ -160,7 +160,7 @@ void testNotificationHandling() { Sinks.One receivedParams = Sinks.one(); transport = new MockMcpTransport(); - session = new DefaultMcpSession(TIMEOUT, transport, Map.of(), + session = new McpClientSession(TIMEOUT, transport, Map.of(), Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> receivedParams.tryEmitValue(params)))); // Simulate incoming notification from the server From ac6cafbd987451cbce0659de50c6789aabaf0873 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Thu, 20 Mar 2025 14:17:48 +0100 Subject: [PATCH 19/20] Add migration guide markdown Signed-off-by: Christian Tzolov --- migration-0.8.0.md | 326 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 326 insertions(+) create mode 100644 migration-0.8.0.md diff --git a/migration-0.8.0.md b/migration-0.8.0.md new file mode 100644 index 00000000..2e0b859e --- /dev/null +++ b/migration-0.8.0.md @@ -0,0 +1,326 @@ +# MCP Java SDK Migration Guide: 0.7.0 to 0.8.0 + +This document outlines the breaking changes and provides guidance on how to migrate your code from version 0.7.0 to 0.8.0. + +The 0.8.0 refactoring introduces a robust session-based architecture for server-side MCP implementations, to improve the SDK's ability to handle multiple concurrent client connections and provide a more consistent API. The main changes include: + +1. Introduction of a session-based architecture +2. New transport provider abstraction +3. Exchange objects for client interaction +4. Renamed and reorganized interfaces +5. Updated handler signatures + +## Breaking Changes + +### 1. Interface Renaming + +Several interfaces have been renamed to better reflect their roles: + +| 0.7.0 (Old) | 0.8.0 (New) | +|-------------|-------------| +| `ClientMcpTransport` | `McpClientTransport` | +| `ServerMcpTransport` | `McpServerTransport` | +| `DefaultMcpSession` | `McpClientSession`, `McpServerSession` | + +### 2. New Server Transport Architecture + +The most significant change is the introduction of the `McpServerTransportProvider` interface, which replaces direct usage of `ServerMcpTransport` when creating servers. This new pattern separates the concerns of: + +1. **Transport Provider**: Manages connections with clients and creates individual transports for each connection +2. **Server Transport**: Handles communication with a specific client connection + +| 0.7.0 (Old) | 0.8.0 (New) | +|-------------|-------------| +| `ServerMcpTransport` | `McpServerTransportProvider` + `McpServerTransport` | +| Direct transport usage | Session-based transport usage | + +#### Before (0.7.0): + +```java +// Create a transport +ServerMcpTransport transport = new WebFluxSseServerTransport(objectMapper, "/mcp/message"); + +// Create a server with the transport +McpServer.sync(transport) + .serverInfo("my-server", "1.0.0") + .build(); +``` + +#### After (0.8.0): + +```java +// Create a transport provider +McpServerTransportProvider transportProvider = new WebFluxSseServerTransportProvider(objectMapper, "/mcp/message"); + +// Create a server with the transport provider +McpServer.sync(transportProvider) + .serverInfo("my-server", "1.0.0") + .build(); +``` + +### 3. Handler Method Signature Changes + +Tool, resource, and prompt handlers now receive an additional `exchange` parameter that provides access to client capabilities and methods to interact with the client: + +| 0.7.0 (Old) | 0.8.0 (New) | +|-------------|-------------| +| `(args) -> result` | `(exchange, args) -> result` | + +The exchange objects (`McpAsyncServerExchange` and `McpSyncServerExchange`) provide context for the current session and access to session-specific operations. + +#### Before (0.7.0): + +```java +// Tool handler +.tool(calculatorTool, args -> new CallToolResult("Result: " + calculate(args))) + +// Resource handler +.resource(fileResource, req -> new ReadResourceResult(readFile(req))) + +// Prompt handler +.prompt(analysisPrompt, req -> new GetPromptResult("Analysis prompt")) +``` + +#### After (0.8.0): + +```java +// Tool handler +.tool(calculatorTool, (exchange, args) -> new CallToolResult("Result: " + calculate(args))) + +// Resource handler +.resource(fileResource, (exchange, req) -> new ReadResourceResult(readFile(req))) + +// Prompt handler +.prompt(analysisPrompt, (exchange, req) -> new GetPromptResult("Analysis prompt")) +``` + +### 4. Registration vs. Specification + +The naming convention for handlers has changed from "Registration" to "Specification": + +| 0.7.0 (Old) | 0.8.0 (New) | +|-------------|-------------| +| `AsyncToolRegistration` | `AsyncToolSpecification` | +| `SyncToolRegistration` | `SyncToolSpecification` | +| `AsyncResourceRegistration` | `AsyncResourceSpecification` | +| `SyncResourceRegistration` | `SyncResourceSpecification` | +| `AsyncPromptRegistration` | `AsyncPromptSpecification` | +| `SyncPromptRegistration` | `SyncPromptSpecification` | + +### 5. Roots Change Handler Updates + +The roots change handlers now receive an exchange parameter: + +#### Before (0.7.0): + +```java +.rootsChangeConsumers(List.of( + roots -> { + // Process roots + } +)) +``` + +#### After (0.8.0): + +```java +.rootsChangeHandlers(List.of( + (exchange, roots) -> { + // Process roots with access to exchange + } +)) +``` + +### 6. Server Creation Method Changes + +The `McpServer` factory methods now accept `McpServerTransportProvider` instead of `ServerMcpTransport`: + +| 0.7.0 (Old) | 0.8.0 (New) | +|-------------|-------------| +| `McpServer.async(ServerMcpTransport)` | `McpServer.async(McpServerTransportProvider)` | +| `McpServer.sync(ServerMcpTransport)` | `McpServer.sync(McpServerTransportProvider)` | + +The method names for creating servers have been updated: + +Root change handlers now receive an exchange object: + +| 0.7.0 (Old) | 0.8.0 (New) | +|-------------|-------------| +| `rootsChangeConsumers(List>>)` | `rootsChangeHandlers(List>>)` | +| `rootsChangeConsumer(Consumer>)` | `rootsChangeHandler(BiConsumer>)` | + +### 7. Direct Server Methods Moving to Exchange + +Several methods that were previously available directly on the server are now accessed through the exchange object: + +| 0.7.0 (Old) | 0.8.0 (New) | +|-------------|-------------| +| `server.listRoots()` | `exchange.listRoots()` | +| `server.createMessage()` | `exchange.createMessage()` | +| `server.getClientCapabilities()` | `exchange.getClientCapabilities()` | +| `server.getClientInfo()` | `exchange.getClientInfo()` | + +The direct methods are deprecated and will be removed in 0.9.0: + +- `McpSyncServer.listRoots()` +- `McpSyncServer.getClientCapabilities()` +- `McpSyncServer.getClientInfo()` +- `McpSyncServer.createMessage()` +- `McpAsyncServer.listRoots()` +- `McpAsyncServer.getClientCapabilities()` +- `McpAsyncServer.getClientInfo()` +- `McpAsyncServer.createMessage()` + +## Deprecation Notices + +The following components are deprecated in 0.8.0 and will be removed in 0.9.0: + +- `ClientMcpTransport` interface (use `McpClientTransport` instead) +- `ServerMcpTransport` interface (use `McpServerTransport` instead) +- `DefaultMcpSession` class (use `McpClientSession` instead) +- `WebFluxSseServerTransport` class (use `WebFluxSseServerTransportProvider` instead) +- `WebMvcSseServerTransport` class (use `WebMvcSseServerTransportProvider` instead) +- `StdioServerTransport` class (use `StdioServerTransportProvider` instead) +- All `*Registration` classes (use corresponding `*Specification` classes instead) +- Direct server methods for client interaction (use exchange object instead) + +## Migration Examples + +### Example 1: Creating a Server + +#### Before (0.7.0): + +```java +// Create a transport +ServerMcpTransport transport = new WebFluxSseServerTransport(objectMapper, "/mcp/message"); + +// Create a server with the transport +var server = McpServer.sync(transport) + .serverInfo("my-server", "1.0.0") + .tool(calculatorTool, args -> new CallToolResult("Result: " + calculate(args))) + .rootsChangeConsumers(List.of( + roots -> System.out.println("Roots changed: " + roots) + )) + .build(); + +// Get client capabilities directly from server +ClientCapabilities capabilities = server.getClientCapabilities(); +``` + +#### After (0.8.0): + +```java +// Create a transport provider +McpServerTransportProvider transportProvider = new WebFluxSseServerTransportProvider(objectMapper, "/mcp/message"); + +// Create a server with the transport provider +var server = McpServer.sync(transportProvider) + .serverInfo("my-server", "1.0.0") + .tool(calculatorTool, (exchange, args) -> { + // Get client capabilities from exchange + ClientCapabilities capabilities = exchange.getClientCapabilities(); + return new CallToolResult("Result: " + calculate(args)); + }) + .rootsChangeHandlers(List.of( + (exchange, roots) -> System.out.println("Roots changed: " + roots) + )) + .build(); +``` + +### Example 2: Implementing a Tool with Client Interaction + +#### Before (0.7.0): + +```java +McpServerFeatures.SyncToolRegistration tool = new McpServerFeatures.SyncToolRegistration( + new Tool("weather", "Get weather information", schema), + args -> { + String location = (String) args.get("location"); + // Cannot interact with client from here + return new CallToolResult("Weather for " + location + ": Sunny"); + } +); + +var server = McpServer.sync(transport) + .tools(tool) + .build(); + +// Separate call to create a message +CreateMessageResult result = server.createMessage(new CreateMessageRequest(...)); +``` + +#### After (0.8.0): + +```java +McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification( + new Tool("weather", "Get weather information", schema), + (exchange, args) -> { + String location = (String) args.get("location"); + + // Can interact with client directly from the tool handler + CreateMessageResult result = exchange.createMessage(new CreateMessageRequest(...)); + + return new CallToolResult("Weather for " + location + ": " + result.content()); + } +); + +var server = McpServer.sync(transportProvider) + .tools(tool) + .build(); +``` + +### Example 3: Converting Existing Registration Classes + +If you have custom implementations of the registration classes, you can convert them to the new specification classes: + +#### Before (0.7.0): + +```java +McpServerFeatures.AsyncToolRegistration toolReg = new McpServerFeatures.AsyncToolRegistration( + tool, + args -> Mono.just(new CallToolResult("Result")) +); + +McpServerFeatures.AsyncResourceRegistration resourceReg = new McpServerFeatures.AsyncResourceRegistration( + resource, + req -> Mono.just(new ReadResourceResult(List.of())) +); +``` + +#### After (0.8.0): + +```java +// Option 1: Create new specification directly +McpServerFeatures.AsyncToolSpecification toolSpec = new McpServerFeatures.AsyncToolSpecification( + tool, + (exchange, args) -> Mono.just(new CallToolResult("Result")) +); + +// Option 2: Convert from existing registration (during transition) +McpServerFeatures.AsyncToolRegistration oldToolReg = /* existing registration */; +McpServerFeatures.AsyncToolSpecification toolSpec = oldToolReg.toSpecification(); + +// Similarly for resources +McpServerFeatures.AsyncResourceSpecification resourceSpec = new McpServerFeatures.AsyncResourceSpecification( + resource, + (exchange, req) -> Mono.just(new ReadResourceResult(List.of())) +); +``` + +## Architecture Changes + +### Session-Based Architecture + +In 0.8.0, the MCP Java SDK introduces a session-based architecture where each client connection has its own session. This allows for better isolation between clients and more efficient resource management. + +The `McpServerTransportProvider` is responsible for creating `McpServerTransport` instances for each session, and the `McpServerSession` manages the communication with a specific client. + +### Exchange Objects + +The new exchange objects (`McpAsyncServerExchange` and `McpSyncServerExchange`) provide access to client-specific information and methods. They are passed to handler functions as the first parameter, allowing handlers to interact with the specific client that made the request. + +## Conclusion + +The changes in version 0.8.0 represent a significant architectural improvement to the MCP Java SDK. While they require some code changes, the new design provides a more flexible and maintainable foundation for building MCP applications. + +For assistance with migration or to report issues, please open an issue on the GitHub repository. From e0e08fecadc82403941b1f29cdb799f8fd025743 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Thu, 20 Mar 2025 15:35:04 +0100 Subject: [PATCH 20/20] improve migration guide Signed-off-by: Christian Tzolov --- migration-0.8.0.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/migration-0.8.0.md b/migration-0.8.0.md index 2e0b859e..3ba29a10 100644 --- a/migration-0.8.0.md +++ b/migration-0.8.0.md @@ -2,7 +2,9 @@ This document outlines the breaking changes and provides guidance on how to migrate your code from version 0.7.0 to 0.8.0. -The 0.8.0 refactoring introduces a robust session-based architecture for server-side MCP implementations, to improve the SDK's ability to handle multiple concurrent client connections and provide a more consistent API. The main changes include: +The 0.8.0 refactoring introduces a session-based architecture for server-side MCP implementations. +It improves the SDK's ability to handle multiple concurrent client connections and provides an API better aligned with the MCP specification. +The main changes include: 1. Introduction of a session-based architecture 2. New transport provider abstraction