diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java index 8ea65fd7..b0dfa89c 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java @@ -9,7 +9,7 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; @@ -58,7 +58,7 @@ * "https://spec.modelcontextprotocol.io/specification/basic/transports/#http-with-sse">MCP * HTTP with SSE Transport Specification */ -public class WebFluxSseClientTransport implements ClientMcpTransport { +public class WebFluxSseClientTransport implements McpClientTransport { private static final Logger logger = LoggerFactory.getLogger(WebFluxSseClientTransport.class); diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransport.java index bed7293e..fb0b581e 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransport.java @@ -60,7 +60,10 @@ * @author Alexandros Pappas * @see ServerMcpTransport * @see ServerSentEvent + * @deprecated This class will be removed in 0.9.0. Use + * {@link WebFluxSseServerTransportProvider}. */ +@Deprecated public class WebFluxSseServerTransport implements ServerMcpTransport { private static final Logger logger = LoggerFactory.getLogger(WebFluxSseServerTransport.class); @@ -182,16 +185,16 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message) { try {// @formatter:off String jsonText = objectMapper.writeValueAsString(message); ServerSentEvent event = ServerSentEvent.builder() - .event(MESSAGE_EVENT_TYPE) - .data(jsonText) - .build(); + .event(MESSAGE_EVENT_TYPE) + .data(jsonText) + .build(); logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); List failedSessions = sessions.values().stream() - .filter(session -> session.messageSink.tryEmitNext(event).isFailure()) - .map(session -> session.id) - .toList(); + .filter(session -> session.messageSink.tryEmitNext(event).isFailure()) + .map(session -> session.id) + .toList(); if (failedSessions.isEmpty()) { logger.debug("Successfully broadcast message to all sessions"); @@ -407,4 +410,4 @@ void close() { } -} +} \ No newline at end of file diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java new file mode 100644 index 00000000..cf3eeae0 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -0,0 +1,351 @@ +package io.modelcontextprotocol.server.transport; + +import java.io.IOException; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerTransport; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.util.Assert; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.Exceptions; +import reactor.core.publisher.Flux; +import reactor.core.publisher.FluxSink; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; + +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.codec.ServerSentEvent; +import org.springframework.web.reactive.function.server.RouterFunction; +import org.springframework.web.reactive.function.server.RouterFunctions; +import org.springframework.web.reactive.function.server.ServerRequest; +import org.springframework.web.reactive.function.server.ServerResponse; + +/** + * Server-side implementation of the MCP (Model Context Protocol) HTTP transport using + * Server-Sent Events (SSE). This implementation provides a bidirectional communication + * channel between MCP clients and servers using HTTP POST for client-to-server messages + * and SSE for server-to-client messages. + * + *

+ * Key features: + *

    + *
  • Implements the {@link McpServerTransportProvider} interface that allows managing + * {@link McpServerSession} instances and enabling their communication with the + * {@link McpServerTransport} abstraction.
  • + *
  • Uses WebFlux for non-blocking request handling and SSE support
  • + *
  • Maintains client sessions for reliable message delivery
  • + *
  • Supports graceful shutdown with session cleanup
  • + *
  • Thread-safe message broadcasting to multiple clients
  • + *
+ * + *

+ * The transport sets up two main endpoints: + *

    + *
  • SSE endpoint (/sse) - For establishing SSE connections with clients
  • + *
  • Message endpoint (configurable) - For receiving JSON-RPC messages from clients
  • + *
+ * + *

+ * This implementation is thread-safe and can handle multiple concurrent client + * connections. It uses {@link ConcurrentHashMap} for session management and Project + * Reactor's non-blocking APIs for message processing and delivery. + * + * @author Christian Tzolov + * @author Alexandros Pappas + * @author Dariusz Jędrzejczyk + * @see McpServerTransport + * @see ServerSentEvent + */ +public class WebFluxSseServerTransportProvider implements McpServerTransportProvider { + + private static final Logger logger = LoggerFactory.getLogger(WebFluxSseServerTransportProvider.class); + + /** + * Event type for JSON-RPC messages sent through the SSE connection. + */ + public static final String MESSAGE_EVENT_TYPE = "message"; + + /** + * Event type for sending the message endpoint URI to clients. + */ + public static final String ENDPOINT_EVENT_TYPE = "endpoint"; + + /** + * Default SSE endpoint path as specified by the MCP transport specification. + */ + public static final String DEFAULT_SSE_ENDPOINT = "/sse"; + + private final ObjectMapper objectMapper; + + private final String messageEndpoint; + + private final String sseEndpoint; + + private final RouterFunction routerFunction; + + private McpServerSession.Factory sessionFactory; + + /** + * Map of active client sessions, keyed by session ID. + */ + private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + + /** + * Flag indicating if the transport is shutting down. + */ + private volatile boolean isClosing = false; + + /** + * Constructs a new WebFlux SSE server transport provider instance. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of MCP messages. Must not be null. + * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC + * messages. This endpoint will be communicated to clients during SSE connection + * setup. Must not be null. + * @throws IllegalArgumentException if either parameter is null + */ + public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + Assert.notNull(messageEndpoint, "Message endpoint must not be null"); + Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); + + this.objectMapper = objectMapper; + this.messageEndpoint = messageEndpoint; + this.sseEndpoint = sseEndpoint; + this.routerFunction = RouterFunctions.route() + .GET(this.sseEndpoint, this::handleSseConnection) + .POST(this.messageEndpoint, this::handleMessage) + .build(); + } + + /** + * Constructs a new WebFlux SSE server transport provider instance with the default + * SSE endpoint. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of MCP messages. Must not be null. + * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC + * messages. This endpoint will be communicated to clients during SSE connection + * setup. Must not be null. + * @throws IllegalArgumentException if either parameter is null + */ + public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) { + this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); + } + + @Override + public void setSessionFactory(McpServerSession.Factory sessionFactory) { + this.sessionFactory = sessionFactory; + } + + /** + * Broadcasts a JSON-RPC message to all connected clients through their SSE + * connections. The message is serialized to JSON and sent as a server-sent event to + * each active session. + * + *

+ * The method: + *

    + *
  • Serializes the message to JSON
  • + *
  • Creates a server-sent event with the message data
  • + *
  • Attempts to send the event to all active sessions
  • + *
  • Tracks and reports any delivery failures
  • + *
+ * @param method The JSON-RPC method to send to clients + * @param params The method parameters to send to clients + * @return A Mono that completes when the message has been sent to all sessions, or + * errors if any session fails to receive the message + */ + @Override + public Mono notifyClients(String method, Map params) { + if (sessions.isEmpty()) { + logger.debug("No active sessions to broadcast message to"); + return Mono.empty(); + } + + logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); + + return Flux.fromStream(sessions.values().stream()) + .flatMap(session -> session.sendNotification(method, params) + .doOnError(e -> logger.error("Failed to " + "send message to session " + "{}: {}", session.getId(), + e.getMessage())) + .onErrorComplete()) + .then(); + } + + // FIXME: This javadoc makes claims about using isClosing flag but it's not actually + // doing that. + /** + * Initiates a graceful shutdown of all the sessions. This method ensures all active + * sessions are properly closed and cleaned up. + * + *

+ * The shutdown process: + *

    + *
  • Marks the transport as closing to prevent new connections
  • + *
  • Closes each active session
  • + *
  • Removes closed sessions from the sessions map
  • + *
  • Times out after 5 seconds if shutdown takes too long
  • + *
+ * @return A Mono that completes when all sessions have been closed + */ + @Override + public Mono closeGracefully() { + return Flux.fromIterable(sessions.values()) + .doFirst(() -> logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size())) + .flatMap(McpServerSession::closeGracefully) + .then(); + } + + /** + * Returns the WebFlux router function that defines the transport's HTTP endpoints. + * This router function should be integrated into the application's web configuration. + * + *

+ * The router function defines two endpoints: + *

    + *
  • GET {sseEndpoint} - For establishing SSE connections
  • + *
  • POST {messageEndpoint} - For receiving client messages
  • + *
+ * @return The configured {@link RouterFunction} for handling HTTP requests + */ + public RouterFunction getRouterFunction() { + return this.routerFunction; + } + + /** + * Handles new SSE connection requests from clients. Creates a new session for each + * connection and sets up the SSE event stream. + * @param request The incoming server request + * @return A Mono which emits a response with the SSE event stream + */ + private Mono handleSseConnection(ServerRequest request) { + if (isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); + } + + return ServerResponse.ok() + .contentType(MediaType.TEXT_EVENT_STREAM) + .body(Flux.>create(sink -> { + WebFluxMcpSessionTransport sessionTransport = new WebFluxMcpSessionTransport(sink); + + McpServerSession session = sessionFactory.create(sessionTransport); + String sessionId = session.getId(); + + logger.debug("Created new SSE connection for session: {}", sessionId); + sessions.put(sessionId, session); + + // Send initial endpoint event + logger.debug("Sending initial endpoint event to session: {}", sessionId); + sink.next(ServerSentEvent.builder() + .event(ENDPOINT_EVENT_TYPE) + .data(messageEndpoint + "?sessionId=" + sessionId) + .build()); + sink.onCancel(() -> { + logger.debug("Session {} cancelled", sessionId); + sessions.remove(sessionId); + }); + }), ServerSentEvent.class); + } + + /** + * Handles incoming JSON-RPC messages from clients. Deserializes the message and + * processes it through the configured message handler. + * + *

+ * The handler: + *

    + *
  • Deserializes the incoming JSON-RPC message
  • + *
  • Passes it through the message handler chain
  • + *
  • Returns appropriate HTTP responses based on processing results
  • + *
  • Handles various error conditions with appropriate error responses
  • + *
+ * @param request The incoming server request containing the JSON-RPC message + * @return A Mono emitting the response indicating the message processing result + */ + private Mono handleMessage(ServerRequest request) { + if (isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); + } + + if (request.queryParam("sessionId").isEmpty()) { + return ServerResponse.badRequest().bodyValue(new McpError("Session ID missing in message endpoint")); + } + + McpServerSession session = sessions.get(request.queryParam("sessionId").get()); + + return request.bodyToMono(String.class).flatMap(body -> { + try { + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); + return session.handle(message).flatMap(response -> ServerResponse.ok().build()).onErrorResume(error -> { + logger.error("Error processing message: {}", error.getMessage()); + // TODO: instead of signalling the error, just respond with 200 OK + // - the error is signalled on the SSE connection + // return ServerResponse.ok().build(); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR) + .bodyValue(new McpError(error.getMessage())); + }); + } + catch (IllegalArgumentException | IOException e) { + logger.error("Failed to deserialize message: {}", e.getMessage()); + return ServerResponse.badRequest().bodyValue(new McpError("Invalid message format")); + } + }); + } + + private class WebFluxMcpSessionTransport implements McpServerTransport { + + private final FluxSink> sink; + + public WebFluxMcpSessionTransport(FluxSink> sink) { + this.sink = sink; + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + return Mono.fromSupplier(() -> { + try { + return objectMapper.writeValueAsString(message); + } + catch (IOException e) { + throw Exceptions.propagate(e); + } + }).doOnNext(jsonText -> { + ServerSentEvent event = ServerSentEvent.builder() + .event(MESSAGE_EVENT_TYPE) + .data(jsonText) + .build(); + sink.next(event); + }).doOnError(e -> { + // TODO log with sessionid + Throwable exception = Exceptions.unwrap(e); + sink.error(exception); + }).then(); + } + + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return objectMapper.convertValue(data, typeRef); + } + + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(sink::complete); + } + + @Override + public void close() { + sink.complete(); + } + + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 4cd24c62..57bcd191 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -16,7 +16,7 @@ import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; import io.modelcontextprotocol.server.McpServer; import io.modelcontextprotocol.server.McpServerFeatures; -import io.modelcontextprotocol.server.transport.WebFluxSseServerTransport; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; @@ -30,9 +30,9 @@ import io.modelcontextprotocol.spec.McpSchema.Tool; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +import reactor.core.publisher.Mono; import reactor.netty.DisposableServer; import reactor.netty.http.server.HttpServer; import reactor.test.StepVerifier; @@ -44,8 +44,8 @@ import org.springframework.web.reactive.function.server.RouterFunctions; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.awaitility.Awaitility.await; +import static org.mockito.Mockito.mock; public class WebFluxSseIntegrationTests { @@ -55,16 +55,16 @@ public class WebFluxSseIntegrationTests { private DisposableServer httpServer; - private WebFluxSseServerTransport mcpServerTransport; + private WebFluxSseServerTransportProvider mcpServerTransportProvider; ConcurrentHashMap clientBulders = new ConcurrentHashMap<>(); @BeforeEach public void before() { - this.mcpServerTransport = new WebFluxSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); + this.mcpServerTransportProvider = new WebFluxSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT); - HttpHandler httpHandler = RouterFunctions.toHttpHandler(mcpServerTransport.getRouterFunction()); + HttpHandler httpHandler = RouterFunctions.toHttpHandler(mcpServerTransportProvider.getRouterFunction()); ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); @@ -84,57 +84,43 @@ public void after() { // --------------------------------------- // Sampling Tests // --------------------------------------- - @Test - void testCreateMessageWithoutInitialization() { - var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); - - var messages = List - .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))); - var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); - - var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, - McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); - - StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized. Call the initialize method first!"); - }); - } - @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "httpclient", "webflux" }) void testCreateMessageWithoutSamplingCapabilities(String clientType) { - var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); - var clientBuilder = clientBulders.get(clientType); - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build(); + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - InitializeResult initResult = client.initialize(); - assertThat(initResult).isNotNull(); + exchange.createMessage(mock(McpSchema.CreateMessageRequest.class)).block(); - var messages = List - .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))); - var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); + return Mono.just(mock(CallToolResult.class)); + }); + + McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); - var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, - McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); + // Create client without sampling capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")).build(); - StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) + assertThat(client.initialize()).isNotNull(); + + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) .hasMessage("Client must be configured with sampling capabilities"); - }); + } } @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "httpclient", "webflux" }) void testCreateMessageSuccess(String clientType) throws InterruptedException { + // Client var clientBuilder = clientBulders.get(clientType); - var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); - Function samplingHandler = request -> { assertThat(request.messages()).hasSize(1); assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); @@ -143,29 +129,54 @@ void testCreateMessageSuccess(String clientType) throws InterruptedException { CreateMessageResult.StopReason.STOP_SEQUENCE); }; - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) .capabilities(ClientCapabilities.builder().sampling().build()) .sampling(samplingHandler) .build(); - InitializeResult initResult = client.initialize(); + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var messages = List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message"))); + var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); + + var craeteMessageRequest = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, + McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), + Map.of()); + + StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); assertThat(initResult).isNotNull(); - var messages = List - .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))); - var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); - - var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, - McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); - - StepVerifier.create(mcpAsyncServer.createMessage(request)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }).verifyComplete(); + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.close(); + mcpServer.close(); } // --------------------------------------- @@ -179,8 +190,8 @@ void testRootsSuccess(String clientType) { List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); AtomicReference> rootsRef = new AtomicReference<>(); - var mcpServer = McpServer.sync(mcpServerTransport) - .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate)) + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) .build(); var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) @@ -192,8 +203,6 @@ void testRootsSuccess(String clientType) { assertThat(rootsRef.get()).isNull(); - assertThat(mcpServer.listRoots().roots()).containsAll(roots); - mcpClient.rootsListChangedNotification(); await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { @@ -222,23 +231,33 @@ void testRootsSuccess(String clientType) { @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "httpclient", "webflux" }) void testRootsWithoutCapability(String clientType) { + var clientBuilder = clientBulders.get(clientType); - var mcpServer = McpServer.sync(mcpServerTransport).rootsChangeConsumer(rootsUpdate -> { - }).build(); + McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + exchange.listRoots(); // try to list roots + + return mock(CallToolResult.class); + }); + + var mcpServer = McpServer.sync(mcpServerTransportProvider).rootsChangeHandler((exchange, rootsUpdate) -> { + }).tools(tool).build(); // Create client without roots capability - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()) // No - // roots - // capability - .build(); + // No roots capability + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build(); - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + assertThat(mcpClient.initialize()).isNotNull(); // Attempt to list roots should fail - assertThatThrownBy(() -> mcpServer.listRoots().roots()).isInstanceOf(McpError.class) - .hasMessage("Roots not supported"); + try { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); + } mcpClient.close(); mcpServer.close(); @@ -246,12 +265,12 @@ void testRootsWithoutCapability(String clientType) { @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "httpclient", "webflux" }) - void testRootsWithEmptyRootsList(String clientType) { + void testRootsNotifciationWithEmptyRootsList(String clientType) { var clientBuilder = clientBulders.get(clientType); AtomicReference> rootsRef = new AtomicReference<>(); - var mcpServer = McpServer.sync(mcpServerTransport) - .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate)) + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) .build(); var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) @@ -273,7 +292,7 @@ void testRootsWithEmptyRootsList(String clientType) { @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "httpclient", "webflux" }) - void testRootsWithMultipleConsumers(String clientType) { + void testRootsWithMultipleHandlers(String clientType) { var clientBuilder = clientBulders.get(clientType); List roots = List.of(new Root("uri1://", "root1")); @@ -281,9 +300,9 @@ void testRootsWithMultipleConsumers(String clientType) { AtomicReference> rootsRef1 = new AtomicReference<>(); AtomicReference> rootsRef2 = new AtomicReference<>(); - var mcpServer = McpServer.sync(mcpServerTransport) - .rootsChangeConsumer(rootsUpdate -> rootsRef1.set(rootsUpdate)) - .rootsChangeConsumer(rootsUpdate -> rootsRef2.set(rootsUpdate)) + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef1.set(rootsUpdate)) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef2.set(rootsUpdate)) .build(); var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) @@ -313,8 +332,8 @@ void testRootsServerCloseWithActiveSubscription(String clientType) { List roots = List.of(new Root("uri1://", "root1")); AtomicReference> rootsRef = new AtomicReference<>(); - var mcpServer = McpServer.sync(mcpServerTransport) - .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate)) + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) .build(); var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) @@ -356,8 +375,8 @@ void testToolCallSuccess(String clientType) { var clientBuilder = clientBulders.get(clientType); var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - McpServerFeatures.SyncToolRegistration tool1 = new McpServerFeatures.SyncToolRegistration( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), request -> { + McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { // perform a blocking call to a remote service String response = RestClient.create() .get() @@ -368,7 +387,7 @@ void testToolCallSuccess(String clientType) { return callResponse; }); - var mcpServer = McpServer.sync(mcpServerTransport) + var mcpServer = McpServer.sync(mcpServerTransportProvider) .capabilities(ServerCapabilities.builder().tools(true).build()) .tools(tool1) .build(); @@ -396,8 +415,8 @@ void testToolListChangeHandlingSuccess(String clientType) { var clientBuilder = clientBulders.get(clientType); var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - McpServerFeatures.SyncToolRegistration tool1 = new McpServerFeatures.SyncToolRegistration( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), request -> { + McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { // perform a blocking call to a remote service String response = RestClient.create() .get() @@ -408,7 +427,7 @@ void testToolListChangeHandlingSuccess(String clientType) { return callResponse; }); - var mcpServer = McpServer.sync(mcpServerTransport) + var mcpServer = McpServer.sync(mcpServerTransportProvider) .capabilities(ServerCapabilities.builder().tools(true).build()) .tools(tool1) .build(); @@ -446,8 +465,8 @@ void testToolListChangeHandlingSuccess(String clientType) { }); // Add a new tool - McpServerFeatures.SyncToolRegistration tool2 = new McpServerFeatures.SyncToolRegistration( - new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), request -> callResponse); + McpServerFeatures.SyncToolSpecification tool2 = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), (exchange, request) -> callResponse); mcpServer.addTool(tool2); @@ -459,4 +478,21 @@ void testToolListChangeHandlingSuccess(String clientType) { mcpServer.close(); } + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testInitialize(String clientType) { + + var clientBuilder = clientBulders.get(clientType); + + var mcpServer = McpServer.sync(mcpServerTransportProvider).build(); + + var mcpClient = clientBuilder.build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.close(); + mcpServer.close(); + } + } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java index 0dccb27a..2dd587d4 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java @@ -7,7 +7,7 @@ import java.time.Duration; import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import org.junit.jupiter.api.Timeout; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; @@ -32,7 +32,7 @@ class WebFluxSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests { .waitingFor(Wait.forHttp("/").forStatusCode(404)); @Override - protected ClientMcpTransport createMcpTransport() { + protected McpClientTransport createMcpTransport() { return new WebFluxSseClientTransport(WebClient.builder().baseUrl(host)); } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java index f5cab7b7..72b390dd 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java @@ -7,7 +7,7 @@ import java.time.Duration; import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import org.junit.jupiter.api.Timeout; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; @@ -32,7 +32,7 @@ class WebFluxSseMcpSyncClientTests extends AbstractMcpSyncClientTests { .waitingFor(Wait.forHttp("/").forStatusCode(404)); @Override - protected ClientMcpTransport createMcpTransport() { + protected McpClientTransport createMcpTransport() { return new WebFluxSseClientTransport(WebClient.builder().baseUrl(host)); } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerDeprecatedTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerDeprecatedTests.java new file mode 100644 index 00000000..b460284e --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerDeprecatedTests.java @@ -0,0 +1,55 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransport; +import io.modelcontextprotocol.spec.ServerMcpTransport; +import org.junit.jupiter.api.Timeout; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.reactive.function.server.RouterFunctions; + +/** + * Tests for {@link McpAsyncServer} using {@link WebFluxSseServerTransport}. + * + * @author Christian Tzolov + */ +@Deprecated +@Timeout(15) // Giving extra time beyond the client timeout +class WebFluxSseMcpAsyncServerDeprecatedTests extends AbstractMcpAsyncServerDeprecatedTests { + + private static final int PORT = 8181; + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private DisposableServer httpServer; + + @Override + protected ServerMcpTransport createMcpTransport() { + var transport = new WebFluxSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); + + HttpHandler httpHandler = RouterFunctions.toHttpHandler(transport.getRouterFunction()); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + + return transport; + } + + @Override + protected void onStart() { + } + + @Override + protected void onClose() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java index 1ed0d99b..5fa787ab 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java @@ -5,8 +5,8 @@ package io.modelcontextprotocol.server; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.server.transport.WebFluxSseServerTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.Timeout; import reactor.netty.DisposableServer; import reactor.netty.http.server.HttpServer; @@ -16,7 +16,7 @@ import org.springframework.web.reactive.function.server.RouterFunctions; /** - * Tests for {@link McpAsyncServer} using {@link WebFluxSseServerTransport}. + * Tests for {@link McpAsyncServer} using {@link WebFluxSseServerTransportProvider}. * * @author Christian Tzolov */ @@ -30,14 +30,13 @@ class WebFluxSseMcpAsyncServerTests extends AbstractMcpAsyncServerTests { private DisposableServer httpServer; @Override - protected ServerMcpTransport createMcpTransport() { - var transport = new WebFluxSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); + protected McpServerTransportProvider createMcpTransportProvider() { + var transportProvider = new WebFluxSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT); - HttpHandler httpHandler = RouterFunctions.toHttpHandler(transport.getRouterFunction()); + HttpHandler httpHandler = RouterFunctions.toHttpHandler(transportProvider.getRouterFunction()); ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); - - return transport; + return transportProvider; } @Override diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerDeprecatecTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerDeprecatecTests.java new file mode 100644 index 00000000..be2bf6c7 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerDeprecatecTests.java @@ -0,0 +1,55 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransport; +import io.modelcontextprotocol.spec.ServerMcpTransport; +import org.junit.jupiter.api.Timeout; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.reactive.function.server.RouterFunctions; + +/** + * Tests for {@link McpSyncServer} using {@link WebFluxSseServerTransport}. + * + * @author Christian Tzolov + */ +@Deprecated +@Timeout(15) // Giving extra time beyond the client timeout +class WebFluxSseMcpSyncServerDeprecatecTests extends AbstractMcpSyncServerDeprecatedTests { + + private static final int PORT = 8182; + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private DisposableServer httpServer; + + private WebFluxSseServerTransport transport; + + @Override + protected ServerMcpTransport createMcpTransport() { + transport = new WebFluxSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); + return transport; + } + + @Override + protected void onStart() { + HttpHandler httpHandler = RouterFunctions.toHttpHandler(transport.getRouterFunction()); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + } + + @Override + protected void onClose() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java index 4db00dd4..d3672e3f 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java @@ -5,8 +5,8 @@ package io.modelcontextprotocol.server; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.server.transport.WebFluxSseServerTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.Timeout; import reactor.netty.DisposableServer; import reactor.netty.http.server.HttpServer; @@ -16,7 +16,7 @@ import org.springframework.web.reactive.function.server.RouterFunctions; /** - * Tests for {@link McpSyncServer} using {@link WebFluxSseServerTransport}. + * Tests for {@link McpSyncServer} using {@link WebFluxSseServerTransportProvider}. * * @author Christian Tzolov */ @@ -29,17 +29,17 @@ class WebFluxSseMcpSyncServerTests extends AbstractMcpSyncServerTests { private DisposableServer httpServer; - private WebFluxSseServerTransport transport; + private WebFluxSseServerTransportProvider transportProvider; @Override - protected ServerMcpTransport createMcpTransport() { - transport = new WebFluxSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); - return transport; + protected McpServerTransportProvider createMcpTransportProvider() { + transportProvider = new WebFluxSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT); + return transportProvider; } @Override protected void onStart() { - HttpHandler httpHandler = RouterFunctions.toHttpHandler(transport.getRouterFunction()); + HttpHandler httpHandler = RouterFunctions.toHttpHandler(transportProvider.getRouterFunction()); ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/legacy/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/legacy/WebFluxSseIntegrationTests.java new file mode 100644 index 00000000..981e114c --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/legacy/WebFluxSseIntegrationTests.java @@ -0,0 +1,459 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ +package io.modelcontextprotocol.server.legacy; + +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransport; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.Role; +import io.modelcontextprotocol.spec.McpSchema.Root; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; +import reactor.test.StepVerifier; + +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.client.RestClient; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.server.RouterFunctions; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.awaitility.Awaitility.await; + +public class WebFluxSseIntegrationTests { + + private static final int PORT = 8182; + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private DisposableServer httpServer; + + private WebFluxSseServerTransport mcpServerTransport; + + ConcurrentHashMap clientBulders = new ConcurrentHashMap<>(); + + @BeforeEach + public void before() { + + this.mcpServerTransport = new WebFluxSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); + + HttpHandler httpHandler = RouterFunctions.toHttpHandler(mcpServerTransport.getRouterFunction()); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + + clientBulders.put("httpclient", McpClient.sync(new HttpClientSseClientTransport("http://localhost:" + PORT))); + clientBulders.put("webflux", + McpClient.sync(new WebFluxSseClientTransport(WebClient.builder().baseUrl("http://localhost:" + PORT)))); + + } + + @AfterEach + public void after() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + + // --------------------------------------- + // Sampling Tests + // --------------------------------------- + @Test + void testCreateMessageWithoutInitialization() { + var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); + + var messages = List.of(new McpSchema.SamplingMessage(Role.USER, new McpSchema.TextContent("Test message"))); + var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); + + var request = new CreateMessageRequest(messages, modelPrefs, null, + CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); + + StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized. Call the initialize method first!"); + }); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateMessageWithoutSamplingCapabilities(String clientType) { + + var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); + + var clientBuilder = clientBulders.get(clientType); + + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build(); + + InitializeResult initResult = client.initialize(); + assertThat(initResult).isNotNull(); + + var messages = List.of(new McpSchema.SamplingMessage(Role.USER, new McpSchema.TextContent("Test message"))); + var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); + + var request = new CreateMessageRequest(messages, modelPrefs, null, + CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); + + StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with sampling capabilities"); + }); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateMessageSuccess(String clientType) throws InterruptedException { + + var clientBuilder = clientBulders.get(clientType); + + var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build(); + + InitializeResult initResult = client.initialize(); + assertThat(initResult).isNotNull(); + + var messages = List.of(new McpSchema.SamplingMessage(Role.USER, new McpSchema.TextContent("Test message"))); + var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); + + var request = new CreateMessageRequest(messages, modelPrefs, null, + CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); + + StepVerifier.create(mcpAsyncServer.createMessage(request)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }).verifyComplete(); + } + + // --------------------------------------- + // Roots Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testRootsSuccess(String clientType) { + var clientBuilder = clientBulders.get(clientType); + + List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); + + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransport) + .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate)) + .build(); + + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(rootsRef.get()).isNull(); + + assertThat(mcpServer.listRoots().roots()).containsAll(roots); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + + // Remove a root + mcpClient.removeRoot(roots.get(0).uri()); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); + }); + + // Add a new root + var root3 = new Root("uri3://", "root3"); + mcpClient.addRoot(root3); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); + }); + + mcpClient.close(); + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testRootsWithoutCapability(String clientType) { + var clientBuilder = clientBulders.get(clientType); + + var mcpServer = McpServer.sync(mcpServerTransport).rootsChangeConsumer(rootsUpdate -> { + }).build(); + + // Create client without roots capability + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()) // No + // roots + // capability + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Attempt to list roots should fail + assertThatThrownBy(() -> mcpServer.listRoots().roots()).isInstanceOf(McpError.class) + .hasMessage("Roots not supported"); + + mcpClient.close(); + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testRootsWithEmptyRootsList(String clientType) { + var clientBuilder = clientBulders.get(clientType); + + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransport) + .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate)) + .build(); + + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(List.of()) // Empty roots list + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); + + mcpClient.close(); + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testRootsWithMultipleConsumers(String clientType) { + var clientBuilder = clientBulders.get(clientType); + + List roots = List.of(new Root("uri1://", "root1")); + + AtomicReference> rootsRef1 = new AtomicReference<>(); + AtomicReference> rootsRef2 = new AtomicReference<>(); + + var mcpServer = McpServer.sync(mcpServerTransport) + .rootsChangeConsumer(rootsUpdate -> rootsRef1.set(rootsUpdate)) + .rootsChangeConsumer(rootsUpdate -> rootsRef2.set(rootsUpdate)) + .build(); + + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef1.get()).containsAll(roots); + assertThat(rootsRef2.get()).containsAll(roots); + }); + + mcpClient.close(); + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testRootsServerCloseWithActiveSubscription(String clientType) { + + var clientBuilder = clientBulders.get(clientType); + + List roots = List.of(new Root("uri1://", "root1")); + + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransport) + .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate)) + .build(); + + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + + // Close server while subscription is active + mcpServer.close(); + + // Verify client can handle server closure gracefully + mcpClient.close(); + } + + // --------------------------------------- + // Tools Tests + // --------------------------------------- + + String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testToolCallSuccess(String clientType) { + + var clientBuilder = clientBulders.get(clientType); + + var callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + McpServerFeatures.SyncToolRegistration tool1 = new McpServerFeatures.SyncToolRegistration( + new Tool("tool1", "tool1 description", emptyJsonSchema), request -> { + // perform a blocking call to a remote service + String response = RestClient.create() + .get() + .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .retrieve() + .body(String.class); + assertThat(response).isNotBlank(); + return callResponse; + }); + + var mcpServer = McpServer.sync(mcpServerTransport) + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + var mcpClient = clientBuilder.build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.close(); + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testToolListChangeHandlingSuccess(String clientType) { + + var clientBuilder = clientBulders.get(clientType); + + var callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + McpServerFeatures.SyncToolRegistration tool1 = new McpServerFeatures.SyncToolRegistration( + new Tool("tool1", "tool1 description", emptyJsonSchema), request -> { + // perform a blocking call to a remote service + String response = RestClient.create() + .get() + .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .retrieve() + .body(String.class); + assertThat(response).isNotBlank(); + return callResponse; + }); + + var mcpServer = McpServer.sync(mcpServerTransport) + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { + // perform a blocking call to a remote service + String response = RestClient.create() + .get() + .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .retrieve() + .body(String.class); + assertThat(response).isNotBlank(); + rootsRef.set(toolsUpdate); + }).build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(rootsRef.get()).isNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + mcpServer.notifyToolsListChanged(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(tool1.tool())); + }); + + // Remove a tool + mcpServer.removeTool("tool1"); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); + + // Add a new tool + McpServerFeatures.SyncToolRegistration tool2 = new McpServerFeatures.SyncToolRegistration( + new Tool("tool2", "tool2 description", emptyJsonSchema), request -> callResponse); + + mcpServer.addTool(tool2); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(tool2.tool())); + }); + + mcpClient.close(); + mcpServer.close(); + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransport.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransport.java index 00928ec7..23193d10 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransport.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransport.java @@ -33,6 +33,9 @@ * a bridge between synchronous WebMVC operations and reactive programming patterns to * maintain compatibility with the reactive transport interface. * + * @deprecated This class will be removed in 0.9.0. Use + * {@link WebMvcSseServerTransportProvider}. + * *

* Key features: *

    @@ -57,12 +60,12 @@ * This implementation uses {@link ConcurrentHashMap} to safely manage multiple client * sessions in a thread-safe manner. Each client session is assigned a unique ID and * maintains its own SSE connection. - * * @author Christian Tzolov * @author Alexandros Pappas * @see ServerMcpTransport * @see RouterFunction */ +@Deprecated public class WebMvcSseServerTransport implements ServerMcpTransport { private static final Logger logger = LoggerFactory.getLogger(WebMvcSseServerTransport.class); diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java new file mode 100644 index 00000000..65416b25 --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java @@ -0,0 +1,399 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.io.IOException; +import java.time.Duration; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerTransport; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.util.Assert; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.http.HttpStatus; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.RouterFunctions; +import org.springframework.web.servlet.function.ServerRequest; +import org.springframework.web.servlet.function.ServerResponse; +import org.springframework.web.servlet.function.ServerResponse.SseBuilder; + +/** + * Server-side implementation of the Model Context Protocol (MCP) transport layer using + * HTTP with Server-Sent Events (SSE) through Spring WebMVC. This implementation provides + * a bridge between synchronous WebMVC operations and reactive programming patterns to + * maintain compatibility with the reactive transport interface. + * + *

    + * Key features: + *

      + *
    • Implements bidirectional communication using HTTP POST for client-to-server + * messages and SSE for server-to-client messages
    • + *
    • Manages client sessions with unique IDs for reliable message delivery
    • + *
    • Supports graceful shutdown with proper session cleanup
    • + *
    • Provides JSON-RPC message handling through configured endpoints
    • + *
    • Includes built-in error handling and logging
    • + *
    + * + *

    + * The transport operates on two main endpoints: + *

      + *
    • {@code /sse} - The SSE endpoint where clients establish their event stream + * connection
    • + *
    • A configurable message endpoint where clients send their JSON-RPC messages via HTTP + * POST
    • + *
    + * + *

    + * This implementation uses {@link ConcurrentHashMap} to safely manage multiple client + * sessions in a thread-safe manner. Each client session is assigned a unique ID and + * maintains its own SSE connection. + * + * @author Christian Tzolov + * @author Alexandros Pappas + * @see McpServerTransportProvider + * @see RouterFunction + */ +public class WebMvcSseServerTransportProvider implements McpServerTransportProvider { + + private static final Logger logger = LoggerFactory.getLogger(WebMvcSseServerTransportProvider.class); + + /** + * Event type for JSON-RPC messages sent through the SSE connection. + */ + public static final String MESSAGE_EVENT_TYPE = "message"; + + /** + * Event type for sending the message endpoint URI to clients. + */ + public static final String ENDPOINT_EVENT_TYPE = "endpoint"; + + /** + * Default SSE endpoint path as specified by the MCP transport specification. + */ + public static final String DEFAULT_SSE_ENDPOINT = "/sse"; + + private final ObjectMapper objectMapper; + + private final String messageEndpoint; + + private final String sseEndpoint; + + private final RouterFunction routerFunction; + + private McpServerSession.Factory sessionFactory; + + /** + * Map of active client sessions, keyed by session ID. + */ + private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + + /** + * Flag indicating if the transport is shutting down. + */ + private volatile boolean isClosing = false; + + /** + * Constructs a new WebMvcSseServerTransportProvider instance. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of messages. + * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC + * messages via HTTP POST. This endpoint will be communicated to clients through the + * SSE connection's initial endpoint event. + * @param sseEndpoint The endpoint URI where clients establish their SSE connections. + * @throws IllegalArgumentException if any parameter is null + */ + public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + Assert.notNull(messageEndpoint, "Message endpoint must not be null"); + Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); + + this.objectMapper = objectMapper; + this.messageEndpoint = messageEndpoint; + this.sseEndpoint = sseEndpoint; + this.routerFunction = RouterFunctions.route() + .GET(this.sseEndpoint, this::handleSseConnection) + .POST(this.messageEndpoint, this::handleMessage) + .build(); + } + + /** + * Constructs a new WebMvcSseServerTransportProvider instance with the default SSE + * endpoint. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of messages. + * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC + * messages via HTTP POST. This endpoint will be communicated to clients through the + * SSE connection's initial endpoint event. + * @throws IllegalArgumentException if either objectMapper or messageEndpoint is null + */ + public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) { + this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); + } + + @Override + public void setSessionFactory(McpServerSession.Factory sessionFactory) { + this.sessionFactory = sessionFactory; + } + + /** + * Broadcasts a notification to all connected clients through their SSE connections. + * The message is serialized to JSON and sent as an SSE event with type "message". If + * any errors occur during sending to a particular client, they are logged but don't + * prevent sending to other clients. + * @param method The method name for the notification + * @param params The parameters for the notification + * @return A Mono that completes when the broadcast attempt is finished + */ + @Override + public Mono notifyClients(String method, Map params) { + if (sessions.isEmpty()) { + logger.debug("No active sessions to broadcast message to"); + return Mono.empty(); + } + + logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); + + return Flux.fromIterable(sessions.values()) + .flatMap(session -> session.sendNotification(method, params) + .doOnError( + e -> logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage())) + .onErrorComplete()) + .then(); + } + + /** + * Initiates a graceful shutdown of the transport. This method: + *

      + *
    • Sets the closing flag to prevent new connections
    • + *
    • Closes all active SSE connections
    • + *
    • Removes all session records
    • + *
    + * @return A Mono that completes when all cleanup operations are finished + */ + @Override + public Mono closeGracefully() { + return Flux.fromIterable(sessions.values()).doFirst(() -> { + this.isClosing = true; + logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size()); + }) + .flatMap(McpServerSession::closeGracefully) + .then() + .doOnSuccess(v -> logger.debug("Graceful shutdown completed")); + } + + /** + * Returns the RouterFunction that defines the HTTP endpoints for this transport. The + * router function handles two endpoints: + *
      + *
    • GET /sse - For establishing SSE connections
    • + *
    • POST [messageEndpoint] - For receiving JSON-RPC messages from clients
    • + *
    + * @return The configured RouterFunction for handling HTTP requests + */ + public RouterFunction getRouterFunction() { + return this.routerFunction; + } + + /** + * Handles new SSE connection requests from clients by creating a new session and + * establishing an SSE connection. This method: + *
      + *
    • Generates a unique session ID
    • + *
    • Creates a new session with a WebMvcMcpSessionTransport
    • + *
    • Sends an initial endpoint event to inform the client where to send + * messages
    • + *
    • Maintains the session in the sessions map
    • + *
    + * @param request The incoming server request + * @return A ServerResponse configured for SSE communication, or an error response if + * the server is shutting down or the connection fails + */ + private ServerResponse handleSseConnection(ServerRequest request) { + if (this.isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); + } + + String sessionId = UUID.randomUUID().toString(); + logger.debug("Creating new SSE connection for session: {}", sessionId); + + // Send initial endpoint event + try { + return ServerResponse.sse(sseBuilder -> { + sseBuilder.onComplete(() -> { + logger.debug("SSE connection completed for session: {}", sessionId); + sessions.remove(sessionId); + }); + sseBuilder.onTimeout(() -> { + logger.debug("SSE connection timed out for session: {}", sessionId); + sessions.remove(sessionId); + }); + + WebMvcMcpSessionTransport sessionTransport = new WebMvcMcpSessionTransport(sessionId, sseBuilder); + McpServerSession session = sessionFactory.create(sessionTransport); + this.sessions.put(sessionId, session); + + try { + sseBuilder.id(sessionId) + .event(ENDPOINT_EVENT_TYPE) + .data(messageEndpoint + "?sessionId=" + sessionId); + } + catch (Exception e) { + logger.error("Failed to send initial endpoint event: {}", e.getMessage()); + sseBuilder.error(e); + } + }, Duration.ZERO); + } + catch (Exception e) { + logger.error("Failed to send initial endpoint event to session {}: {}", sessionId, e.getMessage()); + sessions.remove(sessionId); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).build(); + } + } + + /** + * Handles incoming JSON-RPC messages from clients. This method: + *
      + *
    • Deserializes the request body into a JSON-RPC message
    • + *
    • Processes the message through the session's handle method
    • + *
    • Returns appropriate HTTP responses based on the processing result
    • + *
    + * @param request The incoming server request containing the JSON-RPC message + * @return A ServerResponse indicating success (200 OK) or appropriate error status + * with error details in case of failures + */ + private ServerResponse handleMessage(ServerRequest request) { + if (this.isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); + } + + if (!request.param("sessionId").isPresent()) { + return ServerResponse.badRequest().body(new McpError("Session ID missing in message endpoint")); + } + + String sessionId = request.param("sessionId").get(); + McpServerSession session = sessions.get(sessionId); + + if (session == null) { + return ServerResponse.status(HttpStatus.NOT_FOUND).body(new McpError("Session not found: " + sessionId)); + } + + try { + String body = request.body(String.class); + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); + + // Process the message through the session's handle method + session.handle(message).block(); // Block for WebMVC compatibility + + return ServerResponse.ok().build(); + } + catch (IllegalArgumentException | IOException e) { + logger.error("Failed to deserialize message: {}", e.getMessage()); + return ServerResponse.badRequest().body(new McpError("Invalid message format")); + } + catch (Exception e) { + logger.error("Error handling message: {}", e.getMessage()); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).body(new McpError(e.getMessage())); + } + } + + /** + * Implementation of McpServerTransport for WebMVC SSE sessions. This class handles + * the transport-level communication for a specific client session. + */ + private class WebMvcMcpSessionTransport implements McpServerTransport { + + private final String sessionId; + + private final SseBuilder sseBuilder; + + /** + * Creates a new session transport with the specified ID and SSE builder. + * @param sessionId The unique identifier for this session + * @param sseBuilder The SSE builder for sending server events to the client + */ + WebMvcMcpSessionTransport(String sessionId, SseBuilder sseBuilder) { + this.sessionId = sessionId; + this.sseBuilder = sseBuilder; + logger.debug("Session transport {} initialized with SSE builder", sessionId); + } + + /** + * Sends a JSON-RPC message to the client through the SSE connection. + * @param message The JSON-RPC message to send + * @return A Mono that completes when the message has been sent + */ + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + return Mono.fromRunnable(() -> { + try { + String jsonText = objectMapper.writeValueAsString(message); + sseBuilder.id(sessionId).event(MESSAGE_EVENT_TYPE).data(jsonText); + logger.debug("Message sent to session {}", sessionId); + } + catch (Exception e) { + logger.error("Failed to send message to session {}: {}", sessionId, e.getMessage()); + sseBuilder.error(e); + } + }); + } + + /** + * Converts data from one type to another using the configured ObjectMapper. + * @param data The source data object to convert + * @param typeRef The target type reference + * @return The converted object of type T + * @param The target type + */ + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return objectMapper.convertValue(data, typeRef); + } + + /** + * Initiates a graceful shutdown of the transport. + * @return A Mono that completes when the shutdown is complete + */ + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + logger.debug("Closing session transport: {}", sessionId); + try { + sseBuilder.complete(); + logger.debug("Successfully completed SSE builder for session {}", sessionId); + } + catch (Exception e) { + logger.warn("Failed to complete SSE builder for session {}: {}", sessionId, e.getMessage()); + } + }); + } + + /** + * Closes the transport immediately. + */ + @Override + public void close() { + try { + sseBuilder.complete(); + logger.debug("Successfully completed SSE builder for session {}", sessionId); + } + catch (Exception e) { + logger.warn("Failed to complete SSE builder for session {}: {}", sessionId, e.getMessage()); + } + } + + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportDeprecatedTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportDeprecatedTests.java new file mode 100644 index 00000000..c3f0e322 --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportDeprecatedTests.java @@ -0,0 +1,118 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.transport.WebMvcSseServerTransport; +import io.modelcontextprotocol.spec.ServerMcpTransport; +import org.apache.catalina.Context; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.Timeout; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; +import org.springframework.web.servlet.DispatcherServlet; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.ServerResponse; + +@Deprecated +@Timeout(15) +class WebMvcSseAsyncServerTransportDeprecatedTests extends AbstractMcpAsyncServerDeprecatedTests { + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private static final int PORT = 8181; + + private Tomcat tomcat; + + private WebMvcSseServerTransport transport; + + @Configuration + @EnableWebMvc + static class TestConfig { + + @Bean + public WebMvcSseServerTransport webMvcSseServerTransport() { + return new WebMvcSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); + } + + @Bean + public RouterFunction routerFunction(WebMvcSseServerTransport transport) { + return transport.getRouterFunction(); + } + + } + + private AnnotationConfigWebApplicationContext appContext; + + @Override + protected ServerMcpTransport createMcpTransport() { + // Set up Tomcat first + tomcat = new Tomcat(); + tomcat.setPort(PORT); + + // Set Tomcat base directory to java.io.tmpdir to avoid permission issues + String baseDir = System.getProperty("java.io.tmpdir"); + tomcat.setBaseDir(baseDir); + + // Use the same directory for document base + Context context = tomcat.addContext("", baseDir); + + // Create and configure Spring WebMvc context + appContext = new AnnotationConfigWebApplicationContext(); + appContext.register(TestConfig.class); + appContext.setServletContext(context.getServletContext()); + appContext.refresh(); + + // Get the transport from Spring context + transport = appContext.getBean(WebMvcSseServerTransport.class); + + // Create DispatcherServlet with our Spring context + DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext); + // dispatcherServlet.setThrowExceptionIfNoHandlerFound(true); + + // Add servlet to Tomcat and get the wrapper + var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet); + wrapper.setLoadOnStartup(1); + context.addServletMappingDecoded("/*", "dispatcherServlet"); + + try { + tomcat.start(); + tomcat.getConnector(); // Create and start the connector + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + return transport; + } + + @Override + protected void onStart() { + } + + @Override + protected void onClose() { + if (transport != null) { + transport.closeGracefully().block(); + } + if (appContext != null) { + appContext.close(); + } + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java index a819920c..08d5de67 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java @@ -5,8 +5,8 @@ package io.modelcontextprotocol.server; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.server.transport.WebMvcSseServerTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.apache.catalina.Context; import org.apache.catalina.LifecycleException; import org.apache.catalina.startup.Tomcat; @@ -29,20 +29,20 @@ class WebMvcSseAsyncServerTransportTests extends AbstractMcpAsyncServerTests { private Tomcat tomcat; - private WebMvcSseServerTransport transport; + private McpServerTransportProvider transportProvider; @Configuration @EnableWebMvc static class TestConfig { @Bean - public WebMvcSseServerTransport webMvcSseServerTransport() { - return new WebMvcSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); + public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() { + return new WebMvcSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT); } @Bean - public RouterFunction routerFunction(WebMvcSseServerTransport transport) { - return transport.getRouterFunction(); + public RouterFunction routerFunction(WebMvcSseServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); } } @@ -50,7 +50,7 @@ public RouterFunction routerFunction(WebMvcSseServerTransport tr private AnnotationConfigWebApplicationContext appContext; @Override - protected ServerMcpTransport createMcpTransport() { + protected McpServerTransportProvider createMcpTransportProvider() { // Set up Tomcat first tomcat = new Tomcat(); tomcat.setPort(PORT); @@ -69,7 +69,7 @@ protected ServerMcpTransport createMcpTransport() { appContext.refresh(); // Get the transport from Spring context - transport = appContext.getBean(WebMvcSseServerTransport.class); + transportProvider = appContext.getBean(WebMvcSseServerTransportProvider.class); // Create DispatcherServlet with our Spring context DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext); @@ -88,7 +88,7 @@ protected ServerMcpTransport createMcpTransport() { throw new RuntimeException("Failed to start Tomcat", e); } - return transport; + return transportProvider; } @Override @@ -97,8 +97,8 @@ protected void onStart() { @Override protected void onClose() { - if (transport != null) { - transport.closeGracefully().block(); + if (transportProvider != null) { + transportProvider.closeGracefully().block(); } if (appContext != null) { appContext.close(); diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationDeprecatedTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationDeprecatedTests.java new file mode 100644 index 00000000..f2b593d8 --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationDeprecatedTests.java @@ -0,0 +1,508 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ +package io.modelcontextprotocol.server; + +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.server.transport.WebMvcSseServerTransport; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.Role; +import io.modelcontextprotocol.spec.McpSchema.Root; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import org.apache.catalina.Context; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import reactor.test.StepVerifier; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.client.RestClient; +import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; +import org.springframework.web.servlet.DispatcherServlet; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.ServerResponse; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.awaitility.Awaitility.await; + +@Deprecated +public class WebMvcSseIntegrationDeprecatedTests { + + private static final int PORT = 8183; + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private WebMvcSseServerTransport mcpServerTransport; + + McpClient.SyncSpec clientBuilder; + + @Configuration + @EnableWebMvc + static class TestConfig { + + @Bean + public WebMvcSseServerTransport webMvcSseServerTransport() { + return new WebMvcSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); + } + + @Bean + public RouterFunction routerFunction(WebMvcSseServerTransport transport) { + return transport.getRouterFunction(); + } + + } + + private Tomcat tomcat; + + private AnnotationConfigWebApplicationContext appContext; + + @BeforeEach + public void before() { + + // Set up Tomcat first + tomcat = new Tomcat(); + tomcat.setPort(PORT); + + // Set Tomcat base directory to java.io.tmpdir to avoid permission issues + String baseDir = System.getProperty("java.io.tmpdir"); + tomcat.setBaseDir(baseDir); + + // Use the same directory for document base + Context context = tomcat.addContext("", baseDir); + + // Create and configure Spring WebMvc context + appContext = new AnnotationConfigWebApplicationContext(); + appContext.register(TestConfig.class); + appContext.setServletContext(context.getServletContext()); + appContext.refresh(); + + // Get the transport from Spring context + mcpServerTransport = appContext.getBean(WebMvcSseServerTransport.class); + + // Create DispatcherServlet with our Spring context + DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext); + // dispatcherServlet.setThrowExceptionIfNoHandlerFound(true); + + // Add servlet to Tomcat and get the wrapper + var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet); + wrapper.setLoadOnStartup(1); + wrapper.setAsyncSupported(true); + context.addServletMappingDecoded("/*", "dispatcherServlet"); + + try { + // Configure and start the connector with async support + var connector = tomcat.getConnector(); + connector.setAsyncTimeout(3000); // 3 seconds timeout for async requests + tomcat.start(); + assertThat(tomcat.getServer().getState() == LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + this.clientBuilder = McpClient.sync(new HttpClientSseClientTransport("http://localhost:" + PORT)); + } + + @AfterEach + public void after() { + if (mcpServerTransport != null) { + mcpServerTransport.closeGracefully().block(); + } + if (appContext != null) { + appContext.close(); + } + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + // --------------------------------------- + // Sampling Tests + // --------------------------------------- + @Test + void testCreateMessageWithoutInitialization() { + var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); + + var messages = List + .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))); + var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); + + var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, + McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); + + StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Client must be initialized. Call the initialize method first!"); + }); + } + + @Test + void testCreateMessageWithoutSamplingCapabilities() { + + var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); + + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")).build(); + + InitializeResult initResult = client.initialize(); + assertThat(initResult).isNotNull(); + + var messages = List + .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))); + var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); + + var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, + McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); + + StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with sampling capabilities"); + }); + } + + @Test + void testCreateMessageSuccess() throws InterruptedException { + + var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build(); + + InitializeResult initResult = client.initialize(); + assertThat(initResult).isNotNull(); + + var messages = List + .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))); + var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); + + var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, + McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); + + StepVerifier.create(mcpAsyncServer.createMessage(request)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }).verifyComplete(); + } + + // --------------------------------------- + // Roots Tests + // --------------------------------------- + @Test + void testRootsSuccess() { + List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); + + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransport) + .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate)) + .build(); + + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(rootsRef.get()).isNull(); + + assertThat(mcpServer.listRoots().roots()).containsAll(roots); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + + // Remove a root + mcpClient.removeRoot(roots.get(0).uri()); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); + }); + + // Add a new root + var root3 = new Root("uri3://", "root3"); + mcpClient.addRoot(root3); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); + }); + + mcpClient.close(); + mcpServer.close(); + } + + @Test + void testRootsWithoutCapability() { + var mcpServer = McpServer.sync(mcpServerTransport).rootsChangeConsumer(rootsUpdate -> { + }).build(); + + // Create client without roots capability + // No roots capability + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Attempt to list roots should fail + assertThatThrownBy(() -> mcpServer.listRoots().roots()).isInstanceOf(McpError.class) + .hasMessage("Roots not supported"); + + mcpClient.close(); + mcpServer.close(); + } + + @Test + void testRootsWithEmptyRootsList() { + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransport) + .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate)) + .build(); + + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(List.of()) // Empty roots list + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); + + mcpClient.close(); + mcpServer.close(); + } + + @Test + void testRootsWithMultipleConsumers() { + List roots = List.of(new Root("uri1://", "root1")); + + AtomicReference> rootsRef1 = new AtomicReference<>(); + AtomicReference> rootsRef2 = new AtomicReference<>(); + + var mcpServer = McpServer.sync(mcpServerTransport) + .rootsChangeConsumer(rootsUpdate -> rootsRef1.set(rootsUpdate)) + .rootsChangeConsumer(rootsUpdate -> rootsRef2.set(rootsUpdate)) + .build(); + + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef1.get()).containsAll(roots); + assertThat(rootsRef2.get()).containsAll(roots); + }); + + mcpClient.close(); + mcpServer.close(); + } + + @Test + void testRootsServerCloseWithActiveSubscription() { + List roots = List.of(new Root("uri1://", "root1")); + + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransport) + .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate)) + .build(); + + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + + // Close server while subscription is active + mcpServer.close(); + + // Verify client can handle server closure gracefully + mcpClient.close(); + } + + // --------------------------------------- + // Tools Tests + // --------------------------------------- + + String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + + @Test + void testToolCallSuccess() { + + var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + McpServerFeatures.SyncToolRegistration tool1 = new McpServerFeatures.SyncToolRegistration( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), request -> { + // perform a blocking call to a remote service + String response = RestClient.create() + .get() + .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .retrieve() + .body(String.class); + assertThat(response).isNotBlank(); + return callResponse; + }); + + var mcpServer = McpServer.sync(mcpServerTransport) + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + var mcpClient = clientBuilder.build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.close(); + mcpServer.close(); + } + + @Test + void testToolListChangeHandlingSuccess() { + + var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + McpServerFeatures.SyncToolRegistration tool1 = new McpServerFeatures.SyncToolRegistration( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), request -> { + // perform a blocking call to a remote service + String response = RestClient.create() + .get() + .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .retrieve() + .body(String.class); + assertThat(response).isNotBlank(); + return callResponse; + }); + + var mcpServer = McpServer.sync(mcpServerTransport) + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { + // perform a blocking call to a remote service + String response = RestClient.create() + .get() + .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .retrieve() + .body(String.class); + assertThat(response).isNotBlank(); + rootsRef.set(toolsUpdate); + }).build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(rootsRef.get()).isNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + mcpServer.notifyToolsListChanged(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(tool1.tool())); + }); + + // Remove a tool + mcpServer.removeTool("tool1"); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); + + // Add a new tool + McpServerFeatures.SyncToolRegistration tool2 = new McpServerFeatures.SyncToolRegistration( + new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), request -> callResponse); + + mcpServer.addTool(tool2); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(tool2.tool())); + }); + + mcpClient.close(); + mcpServer.close(); + } + + @Test + void testInitialize() { + + var mcpServer = McpServer.sync(mcpServerTransport).build(); + + var mcpClient = clientBuilder.build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.close(); + mcpServer.close(); + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java index 62f69637..7ba9ccc1 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java @@ -12,7 +12,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.client.McpClient; import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; -import io.modelcontextprotocol.server.transport.WebMvcSseServerTransport; +import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; @@ -31,6 +31,7 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; import reactor.test.StepVerifier; import org.springframework.context.annotation.Bean; @@ -43,8 +44,8 @@ import org.springframework.web.servlet.function.ServerResponse; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.awaitility.Awaitility.await; +import static org.mockito.Mockito.mock; public class WebMvcSseIntegrationTests { @@ -52,7 +53,7 @@ public class WebMvcSseIntegrationTests { private static final String MESSAGE_ENDPOINT = "/mcp/message"; - private WebMvcSseServerTransport mcpServerTransport; + private WebMvcSseServerTransportProvider mcpServerTransportProvider; McpClient.SyncSpec clientBuilder; @@ -61,13 +62,13 @@ public class WebMvcSseIntegrationTests { static class TestConfig { @Bean - public WebMvcSseServerTransport webMvcSseServerTransport() { - return new WebMvcSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); + public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() { + return new WebMvcSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT); } @Bean - public RouterFunction routerFunction(WebMvcSseServerTransport transport) { - return transport.getRouterFunction(); + public RouterFunction routerFunction(WebMvcSseServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); } } @@ -97,7 +98,7 @@ public void before() { appContext.refresh(); // Get the transport from Spring context - mcpServerTransport = appContext.getBean(WebMvcSseServerTransport.class); + mcpServerTransportProvider = appContext.getBean(WebMvcSseServerTransportProvider.class); // Create DispatcherServlet with our Spring context DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext); @@ -125,8 +126,8 @@ public void before() { @AfterEach public void after() { - if (mcpServerTransport != null) { - mcpServerTransport.closeGracefully().block(); + if (mcpServerTransportProvider != null) { + mcpServerTransportProvider.closeGracefully().block(); } if (appContext != null) { appContext.close(); @@ -146,49 +147,36 @@ public void after() { // Sampling Tests // --------------------------------------- @Test - void testCreateMessageWithoutInitialization() { - var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); - - var messages = List - .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))); - var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); + void testCreateMessageWithoutSamplingCapabilities() { - var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, - McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) - .hasMessage("Client must be initialized. Call the initialize method first!"); - }); - } + exchange.createMessage(mock(McpSchema.CreateMessageRequest.class)).block(); - @Test - void testCreateMessageWithoutSamplingCapabilities() { + return Mono.just(mock(CallToolResult.class)); + }); - var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); + McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); + // Create client without sampling capabilities var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")).build(); - InitializeResult initResult = client.initialize(); - assertThat(initResult).isNotNull(); - - var messages = List - .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))); - var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); + assertThat(client.initialize()).isNotNull(); - var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, - McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); - - StepVerifier.create(mcpAsyncServer.createMessage(request)).verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class) + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) .hasMessage("Client must be configured with sampling capabilities"); - }); + } } @Test void testCreateMessageSuccess() throws InterruptedException { - var mcpAsyncServer = McpServer.async(mcpServerTransport).serverInfo("test-server", "1.0.0").build(); + // Client Function samplingHandler = request -> { assertThat(request.messages()).hasSize(1); @@ -198,29 +186,54 @@ void testCreateMessageSuccess() throws InterruptedException { CreateMessageResult.StopReason.STOP_SEQUENCE); }; - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) .capabilities(ClientCapabilities.builder().sampling().build()) .sampling(samplingHandler) .build(); - InitializeResult initResult = client.initialize(); + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var messages = List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message"))); + var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); + + var craeteMessageRequest = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, + McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), + Map.of()); + + StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); assertThat(initResult).isNotNull(); - var messages = List - .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Test message"))); - var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); - - var request = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, - McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), Map.of()); - - StepVerifier.create(mcpAsyncServer.createMessage(request)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }).verifyComplete(); + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.close(); + mcpServer.close(); } // --------------------------------------- @@ -231,8 +244,8 @@ void testRootsSuccess() { List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); AtomicReference> rootsRef = new AtomicReference<>(); - var mcpServer = McpServer.sync(mcpServerTransport) - .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate)) + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) .build(); var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) @@ -244,8 +257,6 @@ void testRootsSuccess() { assertThat(rootsRef.get()).isNull(); - assertThat(mcpServer.listRoots().roots()).containsAll(roots); - mcpClient.rootsListChangedNotification(); await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { @@ -273,29 +284,42 @@ void testRootsSuccess() { @Test void testRootsWithoutCapability() { - var mcpServer = McpServer.sync(mcpServerTransport).rootsChangeConsumer(rootsUpdate -> { - }).build(); + + McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + exchange.listRoots(); // try to list roots + + return mock(CallToolResult.class); + }); + + var mcpServer = McpServer.sync(mcpServerTransportProvider).rootsChangeHandler((exchange, rootsUpdate) -> { + }).tools(tool).build(); // Create client without roots capability // No roots capability var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build(); - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + assertThat(mcpClient.initialize()).isNotNull(); // Attempt to list roots should fail - assertThatThrownBy(() -> mcpServer.listRoots().roots()).isInstanceOf(McpError.class) - .hasMessage("Roots not supported"); + try { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); + } mcpClient.close(); mcpServer.close(); } @Test - void testRootsWithEmptyRootsList() { + void testRootsNotifciationWithEmptyRootsList() { AtomicReference> rootsRef = new AtomicReference<>(); - var mcpServer = McpServer.sync(mcpServerTransport) - .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate)) + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) .build(); var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) @@ -316,23 +340,22 @@ void testRootsWithEmptyRootsList() { } @Test - void testRootsWithMultipleConsumers() { + void testRootsWithMultipleHandlers() { List roots = List.of(new Root("uri1://", "root1")); AtomicReference> rootsRef1 = new AtomicReference<>(); AtomicReference> rootsRef2 = new AtomicReference<>(); - var mcpServer = McpServer.sync(mcpServerTransport) - .rootsChangeConsumer(rootsUpdate -> rootsRef1.set(rootsUpdate)) - .rootsChangeConsumer(rootsUpdate -> rootsRef2.set(rootsUpdate)) + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef1.set(rootsUpdate)) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef2.set(rootsUpdate)) .build(); var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) .roots(roots) .build(); - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); + assertThat(mcpClient.initialize()).isNotNull(); mcpClient.rootsListChangedNotification(); @@ -350,8 +373,8 @@ void testRootsServerCloseWithActiveSubscription() { List roots = List.of(new Root("uri1://", "root1")); AtomicReference> rootsRef = new AtomicReference<>(); - var mcpServer = McpServer.sync(mcpServerTransport) - .rootsChangeConsumer(rootsUpdate -> rootsRef.set(rootsUpdate)) + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) .build(); var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) @@ -390,8 +413,8 @@ void testRootsServerCloseWithActiveSubscription() { void testToolCallSuccess() { var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - McpServerFeatures.SyncToolRegistration tool1 = new McpServerFeatures.SyncToolRegistration( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), request -> { + McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { // perform a blocking call to a remote service String response = RestClient.create() .get() @@ -402,7 +425,7 @@ void testToolCallSuccess() { return callResponse; }); - var mcpServer = McpServer.sync(mcpServerTransport) + var mcpServer = McpServer.sync(mcpServerTransportProvider) .capabilities(ServerCapabilities.builder().tools(true).build()) .tools(tool1) .build(); @@ -427,8 +450,8 @@ void testToolCallSuccess() { void testToolListChangeHandlingSuccess() { var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - McpServerFeatures.SyncToolRegistration tool1 = new McpServerFeatures.SyncToolRegistration( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), request -> { + McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { // perform a blocking call to a remote service String response = RestClient.create() .get() @@ -439,7 +462,7 @@ void testToolListChangeHandlingSuccess() { return callResponse; }); - var mcpServer = McpServer.sync(mcpServerTransport) + var mcpServer = McpServer.sync(mcpServerTransportProvider) .capabilities(ServerCapabilities.builder().tools(true).build()) .tools(tool1) .build(); @@ -477,8 +500,8 @@ void testToolListChangeHandlingSuccess() { }); // Add a new tool - McpServerFeatures.SyncToolRegistration tool2 = new McpServerFeatures.SyncToolRegistration( - new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), request -> callResponse); + McpServerFeatures.SyncToolSpecification tool2 = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), (exchange, request) -> callResponse); mcpServer.addTool(tool2); @@ -493,7 +516,7 @@ void testToolListChangeHandlingSuccess() { @Test void testInitialize() { - var mcpServer = McpServer.sync(mcpServerTransport).build(); + var mcpServer = McpServer.sync(mcpServerTransportProvider).build(); var mcpClient = clientBuilder.build(); diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportDeprecatedTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportDeprecatedTests.java new file mode 100644 index 00000000..8656665e --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportDeprecatedTests.java @@ -0,0 +1,118 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.transport.WebMvcSseServerTransport; +import io.modelcontextprotocol.spec.ServerMcpTransport; +import org.apache.catalina.Context; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.Timeout; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; +import org.springframework.web.servlet.DispatcherServlet; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.ServerResponse; + +@Deprecated +@Timeout(15) +class WebMvcSseSyncServerTransportDeprecatedTests extends AbstractMcpSyncServerDeprecatedTests { + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private static final int PORT = 8181; + + private Tomcat tomcat; + + private WebMvcSseServerTransport transport; + + @Configuration + @EnableWebMvc + static class TestConfig { + + @Bean + public WebMvcSseServerTransport webMvcSseServerTransport() { + return new WebMvcSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); + } + + @Bean + public RouterFunction routerFunction(WebMvcSseServerTransport transport) { + return transport.getRouterFunction(); + } + + } + + private AnnotationConfigWebApplicationContext appContext; + + @Override + protected ServerMcpTransport createMcpTransport() { + // Set up Tomcat first + tomcat = new Tomcat(); + tomcat.setPort(PORT); + + // Set Tomcat base directory to java.io.tmpdir to avoid permission issues + String baseDir = System.getProperty("java.io.tmpdir"); + tomcat.setBaseDir(baseDir); + + // Use the same directory for document base + Context context = tomcat.addContext("", baseDir); + + // Create and configure Spring WebMvc context + appContext = new AnnotationConfigWebApplicationContext(); + appContext.register(TestConfig.class); + appContext.setServletContext(context.getServletContext()); + appContext.refresh(); + + // Get the transport from Spring context + transport = appContext.getBean(WebMvcSseServerTransport.class); + + // Create DispatcherServlet with our Spring context + DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext); + // dispatcherServlet.setThrowExceptionIfNoHandlerFound(true); + + // Add servlet to Tomcat and get the wrapper + var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet); + wrapper.setLoadOnStartup(1); + context.addServletMappingDecoded("/*", "dispatcherServlet"); + + try { + tomcat.start(); + tomcat.getConnector(); // Create and start the connector + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + return transport; + } + + @Override + protected void onStart() { + } + + @Override + protected void onClose() { + if (transport != null) { + transport.closeGracefully().block(); + } + if (appContext != null) { + appContext.close(); + } + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java index 249b4dea..b85bed37 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java @@ -5,8 +5,7 @@ package io.modelcontextprotocol.server; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.server.transport.WebMvcSseServerTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider; import org.apache.catalina.Context; import org.apache.catalina.LifecycleException; import org.apache.catalina.startup.Tomcat; @@ -29,20 +28,20 @@ class WebMvcSseSyncServerTransportTests extends AbstractMcpSyncServerTests { private Tomcat tomcat; - private WebMvcSseServerTransport transport; + private WebMvcSseServerTransportProvider transportProvider; @Configuration @EnableWebMvc static class TestConfig { @Bean - public WebMvcSseServerTransport webMvcSseServerTransport() { - return new WebMvcSseServerTransport(new ObjectMapper(), MESSAGE_ENDPOINT); + public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() { + return new WebMvcSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT); } @Bean - public RouterFunction routerFunction(WebMvcSseServerTransport transport) { - return transport.getRouterFunction(); + public RouterFunction routerFunction(WebMvcSseServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); } } @@ -50,7 +49,7 @@ public RouterFunction routerFunction(WebMvcSseServerTransport tr private AnnotationConfigWebApplicationContext appContext; @Override - protected ServerMcpTransport createMcpTransport() { + protected WebMvcSseServerTransportProvider createMcpTransportProvider() { // Set up Tomcat first tomcat = new Tomcat(); tomcat.setPort(PORT); @@ -69,7 +68,7 @@ protected ServerMcpTransport createMcpTransport() { appContext.refresh(); // Get the transport from Spring context - transport = appContext.getBean(WebMvcSseServerTransport.class); + transportProvider = appContext.getBean(WebMvcSseServerTransportProvider.class); // Create DispatcherServlet with our Spring context DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext); @@ -88,7 +87,7 @@ protected ServerMcpTransport createMcpTransport() { throw new RuntimeException("Failed to start Tomcat", e); } - return transport; + return transportProvider; } @Override @@ -97,8 +96,8 @@ protected void onStart() { @Override protected void onClose() { - if (transport != null) { - transport.closeGracefully().block(); + if (transportProvider != null) { + transportProvider.closeGracefully().block(); } if (appContext != null) { appContext.close(); diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java b/mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java index d4e48ea7..cef3fb9f 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/MockMcpTransport.java @@ -11,19 +11,19 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.ServerMcpTransport; import io.modelcontextprotocol.spec.McpSchema.JSONRPCNotification; import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; +import io.modelcontextprotocol.spec.ServerMcpTransport; import reactor.core.publisher.Mono; import reactor.core.publisher.Sinks; /** - * A mock implementation of the {@link ClientMcpTransport} and {@link ServerMcpTransport} + * A mock implementation of the {@link McpClientTransport} and {@link ServerMcpTransport} * interfaces. */ -public class MockMcpTransport implements ClientMcpTransport, ServerMcpTransport { +public class MockMcpTransport implements McpClientTransport, ServerMcpTransport { private final Sinks.Many inbound = Sinks.many().unicast().onBackpressureBuffer(); diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index 02aa23d8..71356351 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -12,7 +12,7 @@ import java.util.function.Consumer; import java.util.function.Function; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; @@ -49,7 +49,7 @@ public abstract class AbstractMcpAsyncClientTests { private static final String ECHO_TEST_MESSAGE = "Hello MCP Spring AI!"; - abstract protected ClientMcpTransport createMcpTransport(); + abstract protected McpClientTransport createMcpTransport(); protected void onStart() { } @@ -65,11 +65,11 @@ protected Duration getInitializationTimeout() { return Duration.ofSeconds(2); } - McpAsyncClient client(ClientMcpTransport transport) { + McpAsyncClient client(McpClientTransport transport) { return client(transport, Function.identity()); } - McpAsyncClient client(ClientMcpTransport transport, Function customizer) { + McpAsyncClient client(McpClientTransport transport, Function customizer) { AtomicReference client = new AtomicReference<>(); assertThatCode(() -> { @@ -84,11 +84,11 @@ McpAsyncClient client(ClientMcpTransport transport, Function c) { + void withClient(McpClientTransport transport, Consumer c) { withClient(transport, Function.identity(), c); } - void withClient(ClientMcpTransport transport, Function customizer, + void withClient(McpClientTransport transport, Function customizer, Consumer c) { var client = client(transport, customizer); try { diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java index 191de23b..128441f8 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -11,7 +11,7 @@ import java.util.function.Consumer; import java.util.function.Function; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; @@ -49,7 +49,7 @@ public abstract class AbstractMcpSyncClientTests { private static final String TEST_MESSAGE = "Hello MCP Spring AI!"; - abstract protected ClientMcpTransport createMcpTransport(); + abstract protected McpClientTransport createMcpTransport(); protected void onStart() { } @@ -65,11 +65,11 @@ protected Duration getInitializationTimeout() { return Duration.ofSeconds(2); } - McpSyncClient client(ClientMcpTransport transport) { + McpSyncClient client(McpClientTransport transport) { return client(transport, Function.identity()); } - McpSyncClient client(ClientMcpTransport transport, Function customizer) { + McpSyncClient client(McpClientTransport transport, Function customizer) { AtomicReference client = new AtomicReference<>(); assertThatCode(() -> { @@ -84,11 +84,11 @@ McpSyncClient client(ClientMcpTransport transport, Function c) { + void withClient(McpClientTransport transport, Consumer c) { withClient(transport, Function.identity(), c); } - void withClient(ClientMcpTransport transport, Function customizer, + void withClient(McpClientTransport transport, Function customizer, Consumer c) { var client = client(transport, customizer); try { diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerDeprecatedTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerDeprecatedTests.java new file mode 100644 index 00000000..005d78f2 --- /dev/null +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerDeprecatedTests.java @@ -0,0 +1,465 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.time.Duration; +import java.util.List; + +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptMessage; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.Resource; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpTransport; +import io.modelcontextprotocol.spec.ServerMcpTransport; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Test suite for the {@link McpAsyncServer} that can be used with different + * {@link McpTransport} implementations. + * + * @author Christian Tzolov + */ +@Deprecated +public abstract class AbstractMcpAsyncServerDeprecatedTests { + + private static final String TEST_TOOL_NAME = "test-tool"; + + private static final String TEST_RESOURCE_URI = "test://resource"; + + private static final String TEST_PROMPT_NAME = "test-prompt"; + + abstract protected ServerMcpTransport createMcpTransport(); + + protected void onStart() { + } + + protected void onClose() { + } + + @BeforeEach + void setUp() { + } + + @AfterEach + void tearDown() { + onClose(); + } + + // --------------------------------------- + // Server Lifecycle Tests + // --------------------------------------- + + @Test + void testConstructorWithInvalidArguments() { + assertThatThrownBy(() -> McpServer.async((ServerMcpTransport) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Transport must not be null"); + + assertThatThrownBy(() -> McpServer.async(createMcpTransport()).serverInfo((McpSchema.Implementation) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Server info must not be null"); + } + + @Test + void testGracefulShutdown() { + var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + StepVerifier.create(mcpAsyncServer.closeGracefully()).verifyComplete(); + } + + @Test + void testImmediateClose() { + var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + assertThatCode(() -> mcpAsyncServer.close()).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Tools Tests + // --------------------------------------- + String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + + @Test + void testAddTool() { + Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolRegistration(newTool, + args -> Mono.just(new CallToolResult(List.of(), false))))) + .verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddDuplicateTool() { + Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tool(duplicateTool, args -> Mono.just(new CallToolResult(List.of(), false))) + .build(); + + StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolRegistration(duplicateTool, + args -> Mono.just(new CallToolResult(List.of(), false))))) + .verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); + }); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testRemoveTool() { + Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tool(too, args -> Mono.just(new CallToolResult(List.of(), false))) + .build(); + + StepVerifier.create(mcpAsyncServer.removeTool(TEST_TOOL_NAME)).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentTool() { + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + StepVerifier.create(mcpAsyncServer.removeTool("nonexistent-tool")).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class).hasMessage("Tool with name 'nonexistent-tool' not found"); + }); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testNotifyToolsListChanged() { + Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tool(too, args -> Mono.just(new CallToolResult(List.of(), false))) + .build(); + + StepVerifier.create(mcpAsyncServer.notifyToolsListChanged()).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Resources Tests + // --------------------------------------- + + @Test + void testNotifyResourcesListChanged() { + var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + StepVerifier.create(mcpAsyncServer.notifyResourcesListChanged()).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddResource() { + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", + null); + McpServerFeatures.AsyncResourceRegistration registration = new McpServerFeatures.AsyncResourceRegistration( + resource, req -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier.create(mcpAsyncServer.addResource(registration)).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddResourceWithNullRegistration() { + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + StepVerifier.create(mcpAsyncServer.addResource((McpServerFeatures.AsyncResourceRegistration) null)) + .verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class).hasMessage("Resource must not be null"); + }); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddResourceWithoutCapability() { + // Create a server without resource capabilities + McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .build(); + + Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", + null); + McpServerFeatures.AsyncResourceRegistration registration = new McpServerFeatures.AsyncResourceRegistration( + resource, req -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier.create(serverWithoutResources.addResource(registration)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with resource capabilities"); + }); + } + + @Test + void testRemoveResourceWithoutCapability() { + // Create a server without resource capabilities + McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .build(); + + StepVerifier.create(serverWithoutResources.removeResource(TEST_RESOURCE_URI)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with resource capabilities"); + }); + } + + // --------------------------------------- + // Prompts Tests + // --------------------------------------- + + @Test + void testNotifyPromptsListChanged() { + var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + StepVerifier.create(mcpAsyncServer.notifyPromptsListChanged()).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddPromptWithNullRegistration() { + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(false).build()) + .build(); + + StepVerifier.create(mcpAsyncServer.addPrompt((McpServerFeatures.AsyncPromptRegistration) null)) + .verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class).hasMessage("Prompt registration must not be null"); + }); + } + + @Test + void testAddPromptWithoutCapability() { + // Create a server without prompt capabilities + McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .build(); + + Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); + McpServerFeatures.AsyncPromptRegistration registration = new McpServerFeatures.AsyncPromptRegistration(prompt, + req -> Mono.just(new GetPromptResult("Test prompt description", List + .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); + + StepVerifier.create(serverWithoutPrompts.addPrompt(registration)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with prompt capabilities"); + }); + } + + @Test + void testRemovePromptWithoutCapability() { + // Create a server without prompt capabilities + McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .build(); + + StepVerifier.create(serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with prompt capabilities"); + }); + } + + @Test + void testRemovePrompt() { + String TEST_PROMPT_NAME_TO_REMOVE = "TEST_PROMPT_NAME678"; + + Prompt prompt = new Prompt(TEST_PROMPT_NAME_TO_REMOVE, "Test Prompt", List.of()); + McpServerFeatures.AsyncPromptRegistration registration = new McpServerFeatures.AsyncPromptRegistration(prompt, + req -> Mono.just(new GetPromptResult("Test prompt description", List + .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); + + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(true).build()) + .prompts(registration) + .build(); + + StepVerifier.create(mcpAsyncServer.removePrompt(TEST_PROMPT_NAME_TO_REMOVE)).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentPrompt() { + var mcpAsyncServer2 = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(true).build()) + .build(); + + StepVerifier.create(mcpAsyncServer2.removePrompt("nonexistent-prompt")).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Prompt with name 'nonexistent-prompt' not found"); + }); + + assertThatCode(() -> mcpAsyncServer2.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + } + + // --------------------------------------- + // Roots Tests + // --------------------------------------- + + @Test + void testRootsChangeConsumers() { + // Test with single consumer + var rootsReceived = new McpSchema.Root[1]; + var consumerCalled = new boolean[1]; + + var singleConsumerServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .rootsChangeConsumers(List.of(roots -> Mono.fromRunnable(() -> { + consumerCalled[0] = true; + if (!roots.isEmpty()) { + rootsReceived[0] = roots.get(0); + } + }))) + .build(); + + assertThat(singleConsumerServer).isNotNull(); + assertThatCode(() -> singleConsumerServer.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + onClose(); + + // Test with multiple consumers + var consumer1Called = new boolean[1]; + var consumer2Called = new boolean[1]; + var rootsContent = new List[1]; + + var multipleConsumersServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .rootsChangeConsumers(List.of(roots -> Mono.fromRunnable(() -> { + consumer1Called[0] = true; + rootsContent[0] = roots; + }), roots -> Mono.fromRunnable(() -> consumer2Called[0] = true))) + .build(); + + assertThat(multipleConsumersServer).isNotNull(); + assertThatCode(() -> multipleConsumersServer.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + onClose(); + + // Test error handling + var errorHandlingServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .rootsChangeConsumers(List.of(roots -> { + throw new RuntimeException("Test error"); + })) + .build(); + + assertThat(errorHandlingServer).isNotNull(); + assertThatCode(() -> errorHandlingServer.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + onClose(); + + // Test without consumers + var noConsumersServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + assertThat(noConsumersServer).isNotNull(); + assertThatCode(() -> noConsumersServer.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + } + + // --------------------------------------- + // Logging Tests + // --------------------------------------- + + @Test + void testLoggingLevels() { + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().logging().build()) + .build(); + + // Test all logging levels + for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { + var notification = McpSchema.LoggingMessageNotification.builder() + .level(level) + .logger("test-logger") + .data("Test message with level " + level) + .build(); + + StepVerifier.create(mcpAsyncServer.loggingNotification(notification)).verifyComplete(); + } + } + + @Test + void testLoggingWithoutCapability() { + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().build()) // No logging capability + .build(); + + var notification = McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.INFO) + .logger("test-logger") + .data("Test log message") + .build(); + + StepVerifier.create(mcpAsyncServer.loggingNotification(notification)).verifyComplete(); + } + + @Test + void testLoggingWithNullNotification() { + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().logging().build()) + .build(); + + StepVerifier.create(mcpAsyncServer.loggingNotification(null)).verifyError(McpError.class); + } + +} diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java index ca5783d0..7bcb9a8b 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java @@ -17,8 +17,7 @@ import io.modelcontextprotocol.spec.McpSchema.Resource; import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; import io.modelcontextprotocol.spec.McpSchema.Tool; -import io.modelcontextprotocol.spec.McpTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -31,10 +30,11 @@ /** * Test suite for the {@link McpAsyncServer} that can be used with different - * {@link McpTransport} implementations. + * {@link McpTransportProvider} implementations. * * @author Christian Tzolov */ +// KEEP IN SYNC with the class in mcp-test module public abstract class AbstractMcpAsyncServerTests { private static final String TEST_TOOL_NAME = "test-tool"; @@ -43,7 +43,7 @@ public abstract class AbstractMcpAsyncServerTests { private static final String TEST_PROMPT_NAME = "test-prompt"; - abstract protected ServerMcpTransport createMcpTransport(); + abstract protected McpServerTransportProvider createMcpTransportProvider(); protected void onStart() { } @@ -66,24 +66,26 @@ void tearDown() { @Test void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> McpServer.async(null)).isInstanceOf(IllegalArgumentException.class) - .hasMessage("Transport must not be null"); + assertThatThrownBy(() -> McpServer.async((McpServerTransportProvider) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Transport provider must not be null"); - assertThatThrownBy(() -> McpServer.async(createMcpTransport()).serverInfo((McpSchema.Implementation) null)) + assertThatThrownBy( + () -> McpServer.async(createMcpTransportProvider()).serverInfo((McpSchema.Implementation) null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Server info must not be null"); } @Test void testGracefulShutdown() { - var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); StepVerifier.create(mcpAsyncServer.closeGracefully()).verifyComplete(); } @Test void testImmediateClose() { - var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpAsyncServer.close()).doesNotThrowAnyException(); } @@ -102,13 +104,13 @@ void testImmediateClose() { @Test void testAddTool() { Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransport()) + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); - StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolRegistration(newTool, - args -> Mono.just(new CallToolResult(List.of(), false))))) + StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(newTool, + (excnage, args) -> Mono.just(new CallToolResult(List.of(), false))))) .verifyComplete(); assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); @@ -118,14 +120,15 @@ void testAddTool() { void testAddDuplicateTool() { Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransport()) + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(duplicateTool, args -> Mono.just(new CallToolResult(List.of(), false))) + .tool(duplicateTool, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) .build(); - StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolRegistration(duplicateTool, - args -> Mono.just(new CallToolResult(List.of(), false))))) + StepVerifier + .create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(duplicateTool, + (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))))) .verifyErrorSatisfies(error -> { assertThat(error).isInstanceOf(McpError.class) .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); @@ -138,10 +141,10 @@ void testAddDuplicateTool() { void testRemoveTool() { Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransport()) + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(too, args -> Mono.just(new CallToolResult(List.of(), false))) + .tool(too, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) .build(); StepVerifier.create(mcpAsyncServer.removeTool(TEST_TOOL_NAME)).verifyComplete(); @@ -151,7 +154,7 @@ void testRemoveTool() { @Test void testRemoveNonexistentTool() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); @@ -167,10 +170,10 @@ void testRemoveNonexistentTool() { void testNotifyToolsListChanged() { Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransport()) + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(too, args -> Mono.just(new CallToolResult(List.of(), false))) + .tool(too, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) .build(); StepVerifier.create(mcpAsyncServer.notifyToolsListChanged()).verifyComplete(); @@ -184,7 +187,7 @@ void testNotifyToolsListChanged() { @Test void testNotifyResourcesListChanged() { - var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); StepVerifier.create(mcpAsyncServer.notifyResourcesListChanged()).verifyComplete(); @@ -193,29 +196,29 @@ void testNotifyResourcesListChanged() { @Test void testAddResource() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", null); - McpServerFeatures.AsyncResourceRegistration registration = new McpServerFeatures.AsyncResourceRegistration( - resource, req -> Mono.just(new ReadResourceResult(List.of()))); + McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification( + resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); - StepVerifier.create(mcpAsyncServer.addResource(registration)).verifyComplete(); + StepVerifier.create(mcpAsyncServer.addResource(specification)).verifyComplete(); assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); } @Test - void testAddResourceWithNullRegistration() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) + void testAddResourceWithNullSpecification() { + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); - StepVerifier.create(mcpAsyncServer.addResource((McpServerFeatures.AsyncResourceRegistration) null)) + StepVerifier.create(mcpAsyncServer.addResource((McpServerFeatures.AsyncResourceSpecification) null)) .verifyErrorSatisfies(error -> { assertThat(error).isInstanceOf(McpError.class).hasMessage("Resource must not be null"); }); @@ -226,16 +229,16 @@ void testAddResourceWithNullRegistration() { @Test void testAddResourceWithoutCapability() { // Create a server without resource capabilities - McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransport()) + McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .build(); Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", null); - McpServerFeatures.AsyncResourceRegistration registration = new McpServerFeatures.AsyncResourceRegistration( - resource, req -> Mono.just(new ReadResourceResult(List.of()))); + McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification( + resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); - StepVerifier.create(serverWithoutResources.addResource(registration)).verifyErrorSatisfies(error -> { + StepVerifier.create(serverWithoutResources.addResource(specification)).verifyErrorSatisfies(error -> { assertThat(error).isInstanceOf(McpError.class) .hasMessage("Server must be configured with resource capabilities"); }); @@ -244,7 +247,7 @@ void testAddResourceWithoutCapability() { @Test void testRemoveResourceWithoutCapability() { // Create a server without resource capabilities - McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransport()) + McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .build(); @@ -260,7 +263,7 @@ void testRemoveResourceWithoutCapability() { @Test void testNotifyPromptsListChanged() { - var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); StepVerifier.create(mcpAsyncServer.notifyPromptsListChanged()).verifyComplete(); @@ -268,31 +271,31 @@ void testNotifyPromptsListChanged() { } @Test - void testAddPromptWithNullRegistration() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) + void testAddPromptWithNullSpecification() { + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(false).build()) .build(); - StepVerifier.create(mcpAsyncServer.addPrompt((McpServerFeatures.AsyncPromptRegistration) null)) + StepVerifier.create(mcpAsyncServer.addPrompt((McpServerFeatures.AsyncPromptSpecification) null)) .verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class).hasMessage("Prompt registration must not be null"); + assertThat(error).isInstanceOf(McpError.class).hasMessage("Prompt specification must not be null"); }); } @Test void testAddPromptWithoutCapability() { // Create a server without prompt capabilities - McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransport()) + McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .build(); Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); - McpServerFeatures.AsyncPromptRegistration registration = new McpServerFeatures.AsyncPromptRegistration(prompt, - req -> Mono.just(new GetPromptResult("Test prompt description", List + McpServerFeatures.AsyncPromptSpecification specification = new McpServerFeatures.AsyncPromptSpecification( + prompt, (exchange, req) -> Mono.just(new GetPromptResult("Test prompt description", List .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); - StepVerifier.create(serverWithoutPrompts.addPrompt(registration)).verifyErrorSatisfies(error -> { + StepVerifier.create(serverWithoutPrompts.addPrompt(specification)).verifyErrorSatisfies(error -> { assertThat(error).isInstanceOf(McpError.class) .hasMessage("Server must be configured with prompt capabilities"); }); @@ -301,7 +304,7 @@ void testAddPromptWithoutCapability() { @Test void testRemovePromptWithoutCapability() { // Create a server without prompt capabilities - McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransport()) + McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .build(); @@ -316,14 +319,14 @@ void testRemovePrompt() { String TEST_PROMPT_NAME_TO_REMOVE = "TEST_PROMPT_NAME678"; Prompt prompt = new Prompt(TEST_PROMPT_NAME_TO_REMOVE, "Test Prompt", List.of()); - McpServerFeatures.AsyncPromptRegistration registration = new McpServerFeatures.AsyncPromptRegistration(prompt, - req -> Mono.just(new GetPromptResult("Test prompt description", List + McpServerFeatures.AsyncPromptSpecification specification = new McpServerFeatures.AsyncPromptSpecification( + prompt, (exchange, req) -> Mono.just(new GetPromptResult("Test prompt description", List .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); - var mcpAsyncServer = McpServer.async(createMcpTransport()) + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) - .prompts(registration) + .prompts(specification) .build(); StepVerifier.create(mcpAsyncServer.removePrompt(TEST_PROMPT_NAME_TO_REMOVE)).verifyComplete(); @@ -333,7 +336,7 @@ void testRemovePrompt() { @Test void testRemoveNonexistentPrompt() { - var mcpAsyncServer2 = McpServer.async(createMcpTransport()) + var mcpAsyncServer2 = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) .build(); @@ -352,14 +355,14 @@ void testRemoveNonexistentPrompt() { // --------------------------------------- @Test - void testRootsChangeConsumers() { + void testRootsChangeHandlers() { // Test with single consumer var rootsReceived = new McpSchema.Root[1]; var consumerCalled = new boolean[1]; - var singleConsumerServer = McpServer.async(createMcpTransport()) + var singleConsumerServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> Mono.fromRunnable(() -> { + .rootsChangeHandlers(List.of((exchange, roots) -> Mono.fromRunnable(() -> { consumerCalled[0] = true; if (!roots.isEmpty()) { rootsReceived[0] = roots.get(0); @@ -377,12 +380,12 @@ void testRootsChangeConsumers() { var consumer2Called = new boolean[1]; var rootsContent = new List[1]; - var multipleConsumersServer = McpServer.async(createMcpTransport()) + var multipleConsumersServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> Mono.fromRunnable(() -> { + .rootsChangeHandlers(List.of((exchange, roots) -> Mono.fromRunnable(() -> { consumer1Called[0] = true; rootsContent[0] = roots; - }), roots -> Mono.fromRunnable(() -> consumer2Called[0] = true))) + }), (exchange, roots) -> Mono.fromRunnable(() -> consumer2Called[0] = true))) .build(); assertThat(multipleConsumersServer).isNotNull(); @@ -391,9 +394,9 @@ void testRootsChangeConsumers() { onClose(); // Test error handling - var errorHandlingServer = McpServer.async(createMcpTransport()) + var errorHandlingServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> { + .rootsChangeHandlers(List.of((exchange, roots) -> { throw new RuntimeException("Test error"); })) .build(); @@ -404,7 +407,9 @@ void testRootsChangeConsumers() { onClose(); // Test without consumers - var noConsumersServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var noConsumersServer = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .build(); assertThat(noConsumersServer).isNotNull(); assertThatCode(() -> noConsumersServer.closeGracefully().block(Duration.ofSeconds(10))) @@ -417,7 +422,7 @@ void testRootsChangeConsumers() { @Test void testLoggingLevels() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().logging().build()) .build(); @@ -436,7 +441,7 @@ void testLoggingLevels() { @Test void testLoggingWithoutCapability() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().build()) // No logging capability .build(); @@ -452,7 +457,7 @@ void testLoggingWithoutCapability() { @Test void testLoggingWithNullNotification() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().logging().build()) .build(); diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerDeprecatedTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerDeprecatedTests.java new file mode 100644 index 00000000..c6625aca --- /dev/null +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerDeprecatedTests.java @@ -0,0 +1,431 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.util.List; + +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptMessage; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.Resource; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpTransport; +import io.modelcontextprotocol.spec.ServerMcpTransport; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Test suite for the {@link McpSyncServer} that can be used with different + * {@link McpTransport} implementations. + * + * @author Christian Tzolov + */ +public abstract class AbstractMcpSyncServerDeprecatedTests { + + private static final String TEST_TOOL_NAME = "test-tool"; + + private static final String TEST_RESOURCE_URI = "test://resource"; + + private static final String TEST_PROMPT_NAME = "test-prompt"; + + abstract protected ServerMcpTransport createMcpTransport(); + + protected void onStart() { + } + + protected void onClose() { + } + + @BeforeEach + void setUp() { + // onStart(); + } + + @AfterEach + void tearDown() { + onClose(); + } + + // --------------------------------------- + // Server Lifecycle Tests + // --------------------------------------- + + @Test + void testConstructorWithInvalidArguments() { + assertThatThrownBy(() -> McpServer.sync((ServerMcpTransport) null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Transport must not be null"); + + assertThatThrownBy(() -> McpServer.sync(createMcpTransport()).serverInfo(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Server info must not be null"); + } + + @Test + void testGracefulShutdown() { + var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testImmediateClose() { + var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + assertThatCode(() -> mcpSyncServer.close()).doesNotThrowAnyException(); + } + + @Test + void testGetAsyncServer() { + var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + assertThat(mcpSyncServer.getAsyncServer()).isNotNull(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Tools Tests + // --------------------------------------- + + String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + + @Test + void testAddTool() { + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); + assertThatCode(() -> mcpSyncServer + .addTool(new McpServerFeatures.SyncToolRegistration(newTool, args -> new CallToolResult(List.of(), false)))) + .doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testAddDuplicateTool() { + Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tool(duplicateTool, args -> new CallToolResult(List.of(), false)) + .build(); + + assertThatThrownBy(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolRegistration(duplicateTool, + args -> new CallToolResult(List.of(), false)))) + .isInstanceOf(McpError.class) + .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testRemoveTool() { + Tool tool = new McpSchema.Tool(TEST_TOOL_NAME, "Test tool", emptyJsonSchema); + + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tool(tool, args -> new CallToolResult(List.of(), false)) + .build(); + + assertThatCode(() -> mcpSyncServer.removeTool(TEST_TOOL_NAME)).doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentTool() { + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + assertThatThrownBy(() -> mcpSyncServer.removeTool("nonexistent-tool")).isInstanceOf(McpError.class) + .hasMessage("Tool with name 'nonexistent-tool' not found"); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testNotifyToolsListChanged() { + var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + assertThatCode(() -> mcpSyncServer.notifyToolsListChanged()).doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Resources Tests + // --------------------------------------- + + @Test + void testNotifyResourcesListChanged() { + var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + assertThatCode(() -> mcpSyncServer.notifyResourcesListChanged()).doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testAddResource() { + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", + null); + McpServerFeatures.SyncResourceRegistration registration = new McpServerFeatures.SyncResourceRegistration( + resource, req -> new ReadResourceResult(List.of())); + + assertThatCode(() -> mcpSyncServer.addResource(registration)).doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testAddResourceWithNullRegistration() { + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + assertThatThrownBy(() -> mcpSyncServer.addResource((McpServerFeatures.SyncResourceRegistration) null)) + .isInstanceOf(McpError.class) + .hasMessage("Resource must not be null"); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testAddResourceWithoutCapability() { + var serverWithoutResources = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", + null); + McpServerFeatures.SyncResourceRegistration registration = new McpServerFeatures.SyncResourceRegistration( + resource, req -> new ReadResourceResult(List.of())); + + assertThatThrownBy(() -> serverWithoutResources.addResource(registration)).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with resource capabilities"); + } + + @Test + void testRemoveResourceWithoutCapability() { + var serverWithoutResources = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + assertThatThrownBy(() -> serverWithoutResources.removeResource(TEST_RESOURCE_URI)).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with resource capabilities"); + } + + // --------------------------------------- + // Prompts Tests + // --------------------------------------- + + @Test + void testNotifyPromptsListChanged() { + var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + assertThatCode(() -> mcpSyncServer.notifyPromptsListChanged()).doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testAddPromptWithNullRegistration() { + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(false).build()) + .build(); + + assertThatThrownBy(() -> mcpSyncServer.addPrompt((McpServerFeatures.SyncPromptRegistration) null)) + .isInstanceOf(McpError.class) + .hasMessage("Prompt registration must not be null"); + } + + @Test + void testAddPromptWithoutCapability() { + var serverWithoutPrompts = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); + McpServerFeatures.SyncPromptRegistration registration = new McpServerFeatures.SyncPromptRegistration(prompt, + req -> new GetPromptResult("Test prompt description", List + .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); + + assertThatThrownBy(() -> serverWithoutPrompts.addPrompt(registration)).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with prompt capabilities"); + } + + @Test + void testRemovePromptWithoutCapability() { + var serverWithoutPrompts = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + assertThatThrownBy(() -> serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with prompt capabilities"); + } + + @Test + void testRemovePrompt() { + Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); + McpServerFeatures.SyncPromptRegistration registration = new McpServerFeatures.SyncPromptRegistration(prompt, + req -> new GetPromptResult("Test prompt description", List + .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); + + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(true).build()) + .prompts(registration) + .build(); + + assertThatCode(() -> mcpSyncServer.removePrompt(TEST_PROMPT_NAME)).doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentPrompt() { + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(true).build()) + .build(); + + assertThatThrownBy(() -> mcpSyncServer.removePrompt("nonexistent-prompt")).isInstanceOf(McpError.class) + .hasMessage("Prompt with name 'nonexistent-prompt' not found"); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Roots Tests + // --------------------------------------- + + @Test + void testRootsChangeConsumers() { + // Test with single consumer + var rootsReceived = new McpSchema.Root[1]; + var consumerCalled = new boolean[1]; + + var singleConsumerServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .rootsChangeConsumers(List.of(roots -> { + consumerCalled[0] = true; + if (!roots.isEmpty()) { + rootsReceived[0] = roots.get(0); + } + })) + .build(); + + assertThat(singleConsumerServer).isNotNull(); + assertThatCode(() -> singleConsumerServer.closeGracefully()).doesNotThrowAnyException(); + onClose(); + + // Test with multiple consumers + var consumer1Called = new boolean[1]; + var consumer2Called = new boolean[1]; + var rootsContent = new List[1]; + + var multipleConsumersServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .rootsChangeConsumers(List.of(roots -> { + consumer1Called[0] = true; + rootsContent[0] = roots; + }, roots -> consumer2Called[0] = true)) + .build(); + + assertThat(multipleConsumersServer).isNotNull(); + assertThatCode(() -> multipleConsumersServer.closeGracefully()).doesNotThrowAnyException(); + onClose(); + + // Test error handling + var errorHandlingServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .rootsChangeConsumers(List.of(roots -> { + throw new RuntimeException("Test error"); + })) + .build(); + + assertThat(errorHandlingServer).isNotNull(); + assertThatCode(() -> errorHandlingServer.closeGracefully()).doesNotThrowAnyException(); + onClose(); + + // Test without consumers + var noConsumersServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + assertThat(noConsumersServer).isNotNull(); + assertThatCode(() -> noConsumersServer.closeGracefully()).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Logging Tests + // --------------------------------------- + + @Test + void testLoggingLevels() { + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().logging().build()) + .build(); + + // Test all logging levels + for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { + var notification = McpSchema.LoggingMessageNotification.builder() + .level(level) + .logger("test-logger") + .data("Test message with level " + level) + .build(); + + assertThatCode(() -> mcpSyncServer.loggingNotification(notification)).doesNotThrowAnyException(); + } + } + + @Test + void testLoggingWithoutCapability() { + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().build()) // No logging capability + .build(); + + var notification = McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.INFO) + .logger("test-logger") + .data("Test log message") + .build(); + + assertThatCode(() -> mcpSyncServer.loggingNotification(notification)).doesNotThrowAnyException(); + } + + @Test + void testLoggingWithNullNotification() { + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().logging().build()) + .build(); + + assertThatThrownBy(() -> mcpSyncServer.loggingNotification(null)).isInstanceOf(McpError.class); + } + +} diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java index f8b95750..7846e053 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java @@ -16,8 +16,7 @@ import io.modelcontextprotocol.spec.McpSchema.Resource; import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; import io.modelcontextprotocol.spec.McpSchema.Tool; -import io.modelcontextprotocol.spec.McpTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -28,10 +27,11 @@ /** * Test suite for the {@link McpSyncServer} that can be used with different - * {@link McpTransport} implementations. + * {@link McpTransportProvider} implementations. * * @author Christian Tzolov */ +// KEEP IN SYNC with the class in mcp-test module public abstract class AbstractMcpSyncServerTests { private static final String TEST_TOOL_NAME = "test-tool"; @@ -40,7 +40,7 @@ public abstract class AbstractMcpSyncServerTests { private static final String TEST_PROMPT_NAME = "test-prompt"; - abstract protected ServerMcpTransport createMcpTransport(); + abstract protected McpServerTransportProvider createMcpTransportProvider(); protected void onStart() { } @@ -64,31 +64,32 @@ void tearDown() { @Test void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> McpServer.sync(null)).isInstanceOf(IllegalArgumentException.class) - .hasMessage("Transport must not be null"); + assertThatThrownBy(() -> McpServer.sync((McpServerTransportProvider) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Transport provider must not be null"); - assertThatThrownBy(() -> McpServer.sync(createMcpTransport()).serverInfo(null)) + assertThatThrownBy(() -> McpServer.sync(createMcpTransportProvider()).serverInfo(null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Server info must not be null"); } @Test void testGracefulShutdown() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); } @Test void testImmediateClose() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpSyncServer.close()).doesNotThrowAnyException(); } @Test void testGetAsyncServer() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); assertThat(mcpSyncServer.getAsyncServer()).isNotNull(); @@ -109,14 +110,14 @@ void testGetAsyncServer() { @Test void testAddTool() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); - assertThatCode(() -> mcpSyncServer - .addTool(new McpServerFeatures.SyncToolRegistration(newTool, args -> new CallToolResult(List.of(), false)))) + assertThatCode(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(newTool, + (exchange, args) -> new CallToolResult(List.of(), false)))) .doesNotThrowAnyException(); assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); @@ -126,14 +127,14 @@ void testAddTool() { void testAddDuplicateTool() { Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - var mcpSyncServer = McpServer.sync(createMcpTransport()) + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(duplicateTool, args -> new CallToolResult(List.of(), false)) + .tool(duplicateTool, (exchange, args) -> new CallToolResult(List.of(), false)) .build(); - assertThatThrownBy(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolRegistration(duplicateTool, - args -> new CallToolResult(List.of(), false)))) + assertThatThrownBy(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(duplicateTool, + (exchange, args) -> new CallToolResult(List.of(), false)))) .isInstanceOf(McpError.class) .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); @@ -144,10 +145,10 @@ void testAddDuplicateTool() { void testRemoveTool() { Tool tool = new McpSchema.Tool(TEST_TOOL_NAME, "Test tool", emptyJsonSchema); - var mcpSyncServer = McpServer.sync(createMcpTransport()) + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(tool, args -> new CallToolResult(List.of(), false)) + .tool(tool, (exchange, args) -> new CallToolResult(List.of(), false)) .build(); assertThatCode(() -> mcpSyncServer.removeTool(TEST_TOOL_NAME)).doesNotThrowAnyException(); @@ -157,7 +158,7 @@ void testRemoveTool() { @Test void testRemoveNonexistentTool() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); @@ -170,7 +171,7 @@ void testRemoveNonexistentTool() { @Test void testNotifyToolsListChanged() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpSyncServer.notifyToolsListChanged()).doesNotThrowAnyException(); @@ -183,7 +184,7 @@ void testNotifyToolsListChanged() { @Test void testNotifyResourcesListChanged() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpSyncServer.notifyResourcesListChanged()).doesNotThrowAnyException(); @@ -192,29 +193,29 @@ void testNotifyResourcesListChanged() { @Test void testAddResource() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", null); - McpServerFeatures.SyncResourceRegistration registration = new McpServerFeatures.SyncResourceRegistration( - resource, req -> new ReadResourceResult(List.of())); + McpServerFeatures.SyncResourceSpecification specificaiton = new McpServerFeatures.SyncResourceSpecification( + resource, (exchange, req) -> new ReadResourceResult(List.of())); - assertThatCode(() -> mcpSyncServer.addResource(registration)).doesNotThrowAnyException(); + assertThatCode(() -> mcpSyncServer.addResource(specificaiton)).doesNotThrowAnyException(); assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); } @Test - void testAddResourceWithNullRegistration() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) + void testAddResourceWithNullSpecifiation() { + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); - assertThatThrownBy(() -> mcpSyncServer.addResource((McpServerFeatures.SyncResourceRegistration) null)) + assertThatThrownBy(() -> mcpSyncServer.addResource((McpServerFeatures.SyncResourceSpecification) null)) .isInstanceOf(McpError.class) .hasMessage("Resource must not be null"); @@ -223,20 +224,24 @@ void testAddResourceWithNullRegistration() { @Test void testAddResourceWithoutCapability() { - var serverWithoutResources = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var serverWithoutResources = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .build(); Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", null); - McpServerFeatures.SyncResourceRegistration registration = new McpServerFeatures.SyncResourceRegistration( - resource, req -> new ReadResourceResult(List.of())); + McpServerFeatures.SyncResourceSpecification specification = new McpServerFeatures.SyncResourceSpecification( + resource, (exchange, req) -> new ReadResourceResult(List.of())); - assertThatThrownBy(() -> serverWithoutResources.addResource(registration)).isInstanceOf(McpError.class) + assertThatThrownBy(() -> serverWithoutResources.addResource(specification)).isInstanceOf(McpError.class) .hasMessage("Server must be configured with resource capabilities"); } @Test void testRemoveResourceWithoutCapability() { - var serverWithoutResources = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var serverWithoutResources = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .build(); assertThatThrownBy(() -> serverWithoutResources.removeResource(TEST_RESOURCE_URI)).isInstanceOf(McpError.class) .hasMessage("Server must be configured with resource capabilities"); @@ -248,7 +253,7 @@ void testRemoveResourceWithoutCapability() { @Test void testNotifyPromptsListChanged() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpSyncServer.notifyPromptsListChanged()).doesNotThrowAnyException(); @@ -256,33 +261,37 @@ void testNotifyPromptsListChanged() { } @Test - void testAddPromptWithNullRegistration() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) + void testAddPromptWithNullSpecification() { + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(false).build()) .build(); - assertThatThrownBy(() -> mcpSyncServer.addPrompt((McpServerFeatures.SyncPromptRegistration) null)) + assertThatThrownBy(() -> mcpSyncServer.addPrompt((McpServerFeatures.SyncPromptSpecification) null)) .isInstanceOf(McpError.class) - .hasMessage("Prompt registration must not be null"); + .hasMessage("Prompt specification must not be null"); } @Test void testAddPromptWithoutCapability() { - var serverWithoutPrompts = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var serverWithoutPrompts = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .build(); Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); - McpServerFeatures.SyncPromptRegistration registration = new McpServerFeatures.SyncPromptRegistration(prompt, - req -> new GetPromptResult("Test prompt description", List + McpServerFeatures.SyncPromptSpecification specificaiton = new McpServerFeatures.SyncPromptSpecification(prompt, + (exchange, req) -> new GetPromptResult("Test prompt description", List .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); - assertThatThrownBy(() -> serverWithoutPrompts.addPrompt(registration)).isInstanceOf(McpError.class) + assertThatThrownBy(() -> serverWithoutPrompts.addPrompt(specificaiton)).isInstanceOf(McpError.class) .hasMessage("Server must be configured with prompt capabilities"); } @Test void testRemovePromptWithoutCapability() { - var serverWithoutPrompts = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var serverWithoutPrompts = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .build(); assertThatThrownBy(() -> serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).isInstanceOf(McpError.class) .hasMessage("Server must be configured with prompt capabilities"); @@ -291,14 +300,14 @@ void testRemovePromptWithoutCapability() { @Test void testRemovePrompt() { Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); - McpServerFeatures.SyncPromptRegistration registration = new McpServerFeatures.SyncPromptRegistration(prompt, - req -> new GetPromptResult("Test prompt description", List + McpServerFeatures.SyncPromptSpecification specificaiton = new McpServerFeatures.SyncPromptSpecification(prompt, + (exchange, req) -> new GetPromptResult("Test prompt description", List .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); - var mcpSyncServer = McpServer.sync(createMcpTransport()) + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) - .prompts(registration) + .prompts(specificaiton) .build(); assertThatCode(() -> mcpSyncServer.removePrompt(TEST_PROMPT_NAME)).doesNotThrowAnyException(); @@ -308,7 +317,7 @@ void testRemovePrompt() { @Test void testRemoveNonexistentPrompt() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) .build(); @@ -324,14 +333,14 @@ void testRemoveNonexistentPrompt() { // --------------------------------------- @Test - void testRootsChangeConsumers() { + void testRootsChangeHandlers() { // Test with single consumer var rootsReceived = new McpSchema.Root[1]; var consumerCalled = new boolean[1]; - var singleConsumerServer = McpServer.sync(createMcpTransport()) + var singleConsumerServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> { + .rootsChangeHandlers(List.of((exchage, roots) -> { consumerCalled[0] = true; if (!roots.isEmpty()) { rootsReceived[0] = roots.get(0); @@ -348,12 +357,12 @@ void testRootsChangeConsumers() { var consumer2Called = new boolean[1]; var rootsContent = new List[1]; - var multipleConsumersServer = McpServer.sync(createMcpTransport()) + var multipleConsumersServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> { + .rootsChangeHandlers(List.of((exchange, roots) -> { consumer1Called[0] = true; rootsContent[0] = roots; - }, roots -> consumer2Called[0] = true)) + }, (exchange, roots) -> consumer2Called[0] = true)) .build(); assertThat(multipleConsumersServer).isNotNull(); @@ -361,9 +370,9 @@ void testRootsChangeConsumers() { onClose(); // Test error handling - var errorHandlingServer = McpServer.sync(createMcpTransport()) + var errorHandlingServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> { + .rootsChangeHandlers(List.of((exchange, roots) -> { throw new RuntimeException("Test error"); })) .build(); @@ -373,7 +382,7 @@ void testRootsChangeConsumers() { onClose(); // Test without consumers - var noConsumersServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var noConsumersServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); assertThat(noConsumersServer).isNotNull(); assertThatCode(() -> noConsumersServer.closeGracefully()).doesNotThrowAnyException(); @@ -385,7 +394,7 @@ void testRootsChangeConsumers() { @Test void testLoggingLevels() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().logging().build()) .build(); @@ -404,7 +413,7 @@ void testLoggingLevels() { @Test void testLoggingWithoutCapability() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().build()) // No logging capability .build(); @@ -420,7 +429,7 @@ void testLoggingWithoutCapability() { @Test void testLoggingWithNullNotification() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().logging().build()) .build(); diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index 278e360d..9cbef050 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -15,9 +15,9 @@ import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.spec.ClientMcpTransport; -import io.modelcontextprotocol.spec.DefaultMcpSession; -import io.modelcontextprotocol.spec.DefaultMcpSession.NotificationHandler; -import io.modelcontextprotocol.spec.DefaultMcpSession.RequestHandler; +import io.modelcontextprotocol.spec.McpClientSession; +import io.modelcontextprotocol.spec.McpClientSession.NotificationHandler; +import io.modelcontextprotocol.spec.McpClientSession.RequestHandler; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; @@ -73,7 +73,7 @@ * @author Christian Tzolov * @see McpClient * @see McpSchema - * @see DefaultMcpSession + * @see McpClientSession */ public class McpAsyncClient { @@ -95,7 +95,7 @@ public class McpAsyncClient { * The MCP session implementation that manages bidirectional JSON-RPC communication * between clients and servers. */ - private final DefaultMcpSession mcpSession; + private final McpClientSession mcpSession; /** * Client capabilities. @@ -228,7 +228,7 @@ public class McpAsyncClient { notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_MESSAGE, asyncLoggingNotificationHandler(loggingConsumersFinal)); - this.mcpSession = new DefaultMcpSession(requestTimeout, transport, requestHandlers, notificationHandlers); + this.mcpSession = new McpClientSession(requestTimeout, transport, requestHandlers, notificationHandlers); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java index fa2690dc..9c5f7b01 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java @@ -13,6 +13,7 @@ import java.util.function.Function; import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpTransport; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; @@ -113,11 +114,31 @@ public interface McpClient { * and {@code SseClientTransport} for SSE-based communication. * @return A new builder instance for configuring the client * @throws IllegalArgumentException if transport is null + * @deprecated This method will be removed in 0.9.0. Use + * {@link #sync(McpClientTransport)} */ + @Deprecated static SyncSpec sync(ClientMcpTransport transport) { return new SyncSpec(transport); } + /** + * Start building a synchronous MCP client with the specified transport layer. The + * synchronous MCP client provides blocking operations. Synchronous clients wait for + * each operation to complete before returning, making them simpler to use but + * potentially less performant for concurrent operations. The transport layer handles + * the low-level communication between client and server using protocols like stdio or + * Server-Sent Events (SSE). + * @param transport The transport layer implementation for MCP communication. Common + * implementations include {@code StdioClientTransport} for stdio-based communication + * and {@code SseClientTransport} for SSE-based communication. + * @return A new builder instance for configuring the client + * @throws IllegalArgumentException if transport is null + */ + static SyncSpec sync(McpClientTransport transport) { + return new SyncSpec(transport); + } + /** * Start building an asynchronous MCP client with the specified transport layer. The * asynchronous MCP client provides non-blocking operations. Asynchronous clients @@ -130,11 +151,31 @@ static SyncSpec sync(ClientMcpTransport transport) { * and {@code SseClientTransport} for SSE-based communication. * @return A new builder instance for configuring the client * @throws IllegalArgumentException if transport is null + * @deprecated This method will be removed in 0.9.0. Use + * {@link #async(McpClientTransport)} */ + @Deprecated static AsyncSpec async(ClientMcpTransport transport) { return new AsyncSpec(transport); } + /** + * Start building an asynchronous MCP client with the specified transport layer. The + * asynchronous MCP client provides non-blocking operations. Asynchronous clients + * return reactive primitives (Mono/Flux) immediately, allowing for concurrent + * operations and reactive programming patterns. The transport layer handles the + * low-level communication between client and server using protocols like stdio or + * Server-Sent Events (SSE). + * @param transport The transport layer implementation for MCP communication. Common + * implementations include {@code StdioClientTransport} for stdio-based communication + * and {@code SseClientTransport} for SSE-based communication. + * @return A new builder instance for configuring the client + * @throws IllegalArgumentException if transport is null + */ + static AsyncSpec async(McpClientTransport transport) { + return new AsyncSpec(transport); + } + /** * Synchronous client specification. This class follows the builder pattern to provide * a fluent API for setting up clients with custom configurations. diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java index e5d964b7..ec0a0dfd 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java @@ -6,7 +6,7 @@ import java.time.Duration; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest; @@ -66,7 +66,8 @@ public class McpSyncClient implements AutoCloseable { * Create a new McpSyncClient with the given delegate. * @param delegate the asynchronous kernel on top of which this synchronous client * provides a blocking API. - * @deprecated Use {@link McpClient#sync(ClientMcpTransport)} to obtain an instance. + * @deprecated This method will be removed in 0.9.0. Use + * {@link McpClient#sync(McpClientTransport)} to obtain an instance. */ @Deprecated // TODO make the constructor package private post-deprecation diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java index 35da5197..ca1b0e87 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -3,18 +3,6 @@ */ package io.modelcontextprotocol.client.transport; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.client.transport.FlowSseClient.SseEvent; -import io.modelcontextprotocol.spec.ClientMcpTransport; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; -import io.modelcontextprotocol.util.Assert; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import reactor.core.publisher.Mono; - import java.io.IOException; import java.net.URI; import java.net.http.HttpClient; @@ -27,6 +15,18 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.transport.FlowSseClient.SseEvent; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; +import io.modelcontextprotocol.util.Assert; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Mono; + /** * Server-Sent Events (SSE) implementation of the * {@link io.modelcontextprotocol.spec.McpTransport} that follows the MCP HTTP with SSE @@ -52,9 +52,9 @@ * * @author Christian Tzolov * @see io.modelcontextprotocol.spec.McpTransport - * @see io.modelcontextprotocol.spec.ClientMcpTransport + * @see io.modelcontextprotocol.spec.McpClientTransport */ -public class HttpClientSseClientTransport implements ClientMcpTransport { +public class HttpClientSseClientTransport implements McpClientTransport { private static final Logger logger = LoggerFactory.getLogger(HttpClientSseClientTransport.class); diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java index d35db3f8..f9a97849 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java @@ -11,14 +11,13 @@ import java.time.Duration; import java.util.ArrayList; import java.util.List; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executors; import java.util.function.Consumer; import java.util.function.Function; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; import io.modelcontextprotocol.util.Assert; @@ -38,7 +37,7 @@ * @author Christian Tzolov * @author Dariusz Jędrzejczyk */ -public class StdioClientTransport implements ClientMcpTransport { +public class StdioClientTransport implements McpClientTransport { private static final Logger logger = LoggerFactory.getLogger(StdioClientTransport.class); diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index 7b691678..07a9f154 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -9,21 +9,25 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CopyOnWriteArrayList; +import java.util.function.BiFunction; import java.util.function.Function; import com.fasterxml.jackson.core.type.TypeReference; -import io.modelcontextprotocol.spec.DefaultMcpSession; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpClientSession; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.ServerMcpTransport; -import io.modelcontextprotocol.spec.DefaultMcpSession.NotificationHandler; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.ServerMcpTransport; import io.modelcontextprotocol.util.Utils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -69,143 +73,41 @@ * @author Dariusz Jędrzejczyk * @see McpServer * @see McpSchema - * @see DefaultMcpSession + * @see McpClientSession */ public class McpAsyncServer { private static final Logger logger = LoggerFactory.getLogger(McpAsyncServer.class); - /** - * The MCP session implementation that manages bidirectional JSON-RPC communication - * between clients and servers. - */ - private final DefaultMcpSession mcpSession; - - private final ServerMcpTransport transport; - - private final McpSchema.ServerCapabilities serverCapabilities; - - private final McpSchema.Implementation serverInfo; - - private McpSchema.ClientCapabilities clientCapabilities; - - private McpSchema.Implementation clientInfo; - - /** - * Thread-safe list of tool handlers that can be modified at runtime. - */ - private final CopyOnWriteArrayList tools = new CopyOnWriteArrayList<>(); - - private final CopyOnWriteArrayList resourceTemplates = new CopyOnWriteArrayList<>(); - - private final ConcurrentHashMap resources = new ConcurrentHashMap<>(); - - private final ConcurrentHashMap prompts = new ConcurrentHashMap<>(); - - private LoggingLevel minLoggingLevel = LoggingLevel.DEBUG; + private final McpAsyncServer delegate; - /** - * Supported protocol versions. - */ - private List protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION); + McpAsyncServer() { + this.delegate = null; + } /** * Create a new McpAsyncServer with the given transport and capabilities. * @param mcpTransport The transport layer implementation for MCP communication. * @param features The MCP server supported features. + * @deprecated This constructor will beremoved in 0.9.0. Use + * {@link #McpAsyncServer(McpServerTransportProvider, ObjectMapper, McpServerFeatures.Async)} + * instead. */ + @Deprecated McpAsyncServer(ServerMcpTransport mcpTransport, McpServerFeatures.Async features) { - - this.serverInfo = features.serverInfo(); - this.serverCapabilities = features.serverCapabilities(); - this.tools.addAll(features.tools()); - this.resources.putAll(features.resources()); - this.resourceTemplates.addAll(features.resourceTemplates()); - this.prompts.putAll(features.prompts()); - - Map> requestHandlers = new HashMap<>(); - - // Initialize request handlers for standard MCP methods - requestHandlers.put(McpSchema.METHOD_INITIALIZE, asyncInitializeRequestHandler()); - - // Ping MUST respond with an empty data, but not NULL response. - requestHandlers.put(McpSchema.METHOD_PING, (params) -> Mono.just("")); - - // Add tools API handlers if the tool capability is enabled - if (this.serverCapabilities.tools() != null) { - requestHandlers.put(McpSchema.METHOD_TOOLS_LIST, toolsListRequestHandler()); - requestHandlers.put(McpSchema.METHOD_TOOLS_CALL, toolsCallRequestHandler()); - } - - // Add resources API handlers if provided - if (this.serverCapabilities.resources() != null) { - requestHandlers.put(McpSchema.METHOD_RESOURCES_LIST, resourcesListRequestHandler()); - requestHandlers.put(McpSchema.METHOD_RESOURCES_READ, resourcesReadRequestHandler()); - requestHandlers.put(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, resourceTemplateListRequestHandler()); - } - - // Add prompts API handlers if provider exists - if (this.serverCapabilities.prompts() != null) { - requestHandlers.put(McpSchema.METHOD_PROMPT_LIST, promptsListRequestHandler()); - requestHandlers.put(McpSchema.METHOD_PROMPT_GET, promptsGetRequestHandler()); - } - - // Add logging API handlers if the logging capability is enabled - if (this.serverCapabilities.logging() != null) { - requestHandlers.put(McpSchema.METHOD_LOGGING_SET_LEVEL, setLoggerRequestHandler()); - } - - Map notificationHandlers = new HashMap<>(); - - notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (params) -> Mono.empty()); - - List, Mono>> rootsChangeConsumers = features.rootsChangeConsumers(); - - if (Utils.isEmpty(rootsChangeConsumers)) { - rootsChangeConsumers = List.of((roots) -> Mono.fromRunnable(() -> logger - .warn("Roots list changed notification, but no consumers provided. Roots list changed: {}", roots))); - } - - notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, - asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); - - this.transport = mcpTransport; - this.mcpSession = new DefaultMcpSession(Duration.ofSeconds(10), mcpTransport, requestHandlers, - notificationHandlers); + this.delegate = new LegacyAsyncServer(mcpTransport, features); } - // --------------------------------------- - // Lifecycle Management - // --------------------------------------- - private DefaultMcpSession.RequestHandler asyncInitializeRequestHandler() { - return params -> { - McpSchema.InitializeRequest initializeRequest = transport.unmarshalFrom(params, - new TypeReference() { - }); - this.clientCapabilities = initializeRequest.capabilities(); - this.clientInfo = initializeRequest.clientInfo(); - logger.info("Client initialize request - Protocol: {}, Capabilities: {}, Info: {}", - initializeRequest.protocolVersion(), initializeRequest.capabilities(), - initializeRequest.clientInfo()); - - // The server MUST respond with the highest protocol version it supports if - // it does not support the requested (e.g. Client) version. - String serverProtocolVersion = this.protocolVersions.get(this.protocolVersions.size() - 1); - - if (this.protocolVersions.contains(initializeRequest.protocolVersion())) { - // If the server supports the requested protocol version, it MUST respond - // with the same version. - serverProtocolVersion = initializeRequest.protocolVersion(); - } - else { - logger.warn( - "Client requested unsupported protocol version: {}, so the server will sugggest the {} version instead", - initializeRequest.protocolVersion(), serverProtocolVersion); - } - - return Mono.just(new McpSchema.InitializeResult(serverProtocolVersion, this.serverCapabilities, - this.serverInfo, null)); - }; + /** + * Create a new McpAsyncServer with the given transport provider and capabilities. + * @param mcpTransportProvider The transport layer implementation for MCP + * communication. + * @param features The MCP server supported features. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + */ + McpAsyncServer(McpServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper, + McpServerFeatures.Async features) { + this.delegate = new AsyncServerImpl(mcpTransportProvider, objectMapper, features); } /** @@ -213,7 +115,7 @@ private DefaultMcpSession.RequestHandler asyncInitia * @return The server capabilities */ public McpSchema.ServerCapabilities getServerCapabilities() { - return this.serverCapabilities; + return this.delegate.getServerCapabilities(); } /** @@ -221,23 +123,29 @@ public McpSchema.ServerCapabilities getServerCapabilities() { * @return The server implementation details */ public McpSchema.Implementation getServerInfo() { - return this.serverInfo; + return this.delegate.getServerInfo(); } /** * Get the client capabilities that define the supported features and functionality. * @return The client capabilities + * @deprecated This will be removed in 0.9.0. Use + * {@link McpAsyncServerExchange#getClientCapabilities()}. */ + @Deprecated public ClientCapabilities getClientCapabilities() { - return this.clientCapabilities; + return this.delegate.getClientCapabilities(); } /** * Get the client implementation information. * @return The client implementation details + * @deprecated This will be removed in 0.9.0. Use + * {@link McpAsyncServerExchange#getClientInfo()}. */ + @Deprecated public McpSchema.Implementation getClientInfo() { - return this.clientInfo; + return this.delegate.getClientInfo(); } /** @@ -245,46 +153,37 @@ public McpSchema.Implementation getClientInfo() { * @return A Mono that completes when the server has been closed */ public Mono closeGracefully() { - return this.mcpSession.closeGracefully(); + return this.delegate.closeGracefully(); } /** * Close the server immediately. */ public void close() { - this.mcpSession.close(); + this.delegate.close(); } - private static final TypeReference LIST_ROOTS_RESULT_TYPE_REF = new TypeReference<>() { - }; - /** * Retrieves the list of all roots provided by the client. * @return A Mono that emits the list of roots result. + * @deprecated This will be removed in 0.9.0. Use + * {@link McpAsyncServerExchange#listRoots()}. */ + @Deprecated public Mono listRoots() { - return this.listRoots(null); + return this.delegate.listRoots(null); } /** * Retrieves a paginated list of roots provided by the server. * @param cursor Optional pagination cursor from a previous list request * @return A Mono that emits the list of roots result containing + * @deprecated This will be removed in 0.9.0. Use + * {@link McpAsyncServerExchange#listRoots(String)}. */ + @Deprecated public Mono listRoots(String cursor) { - return this.mcpSession.sendRequest(McpSchema.METHOD_ROOTS_LIST, new McpSchema.PaginatedRequest(cursor), - LIST_ROOTS_RESULT_TYPE_REF); - } - - private NotificationHandler asyncRootsListChangedNotificationHandler( - List, Mono>> rootsChangeConsumers) { - return params -> listRoots().flatMap(listRootsResult -> Flux.fromIterable(rootsChangeConsumers) - .flatMap(consumer -> consumer.apply(listRootsResult.roots())) - .onErrorResume(error -> { - logger.error("Error handling roots list change notification", error); - return Mono.empty(); - }) - .then()); + return this.delegate.listRoots(cursor); } // --------------------------------------- @@ -295,36 +194,21 @@ private NotificationHandler asyncRootsListChangedNotificationHandler( * Add a new tool registration at runtime. * @param toolRegistration The tool registration to add * @return Mono that completes when clients have been notified of the change + * @deprecated This method will be removed in 0.9.0. Use + * {@link #addTool(McpServerFeatures.AsyncToolSpecification)}. */ + @Deprecated public Mono addTool(McpServerFeatures.AsyncToolRegistration toolRegistration) { - if (toolRegistration == null) { - return Mono.error(new McpError("Tool registration must not be null")); - } - if (toolRegistration.tool() == null) { - return Mono.error(new McpError("Tool must not be null")); - } - if (toolRegistration.call() == null) { - return Mono.error(new McpError("Tool call handler must not be null")); - } - if (this.serverCapabilities.tools() == null) { - return Mono.error(new McpError("Server must be configured with tool capabilities")); - } - - return Mono.defer(() -> { - // Check for duplicate tool names - if (this.tools.stream().anyMatch(th -> th.tool().name().equals(toolRegistration.tool().name()))) { - return Mono - .error(new McpError("Tool with name '" + toolRegistration.tool().name() + "' already exists")); - } - - this.tools.add(toolRegistration); - logger.debug("Added tool handler: {}", toolRegistration.tool().name()); + return this.delegate.addTool(toolRegistration); + } - if (this.serverCapabilities.tools().listChanged()) { - return notifyToolsListChanged(); - } - return Mono.empty(); - }); + /** + * Add a new tool specification at runtime. + * @param toolSpecification The tool specification to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addTool(McpServerFeatures.AsyncToolSpecification toolSpecification) { + return this.delegate.addTool(toolSpecification); } /** @@ -333,24 +217,7 @@ public Mono addTool(McpServerFeatures.AsyncToolRegistration toolRegistrati * @return Mono that completes when clients have been notified of the change */ public Mono removeTool(String toolName) { - if (toolName == null) { - return Mono.error(new McpError("Tool name must not be null")); - } - if (this.serverCapabilities.tools() == null) { - return Mono.error(new McpError("Server must be configured with tool capabilities")); - } - - return Mono.defer(() -> { - boolean removed = this.tools.removeIf(toolRegistration -> toolRegistration.tool().name().equals(toolName)); - if (removed) { - logger.debug("Removed tool handler: {}", toolName); - if (this.serverCapabilities.tools().listChanged()) { - return notifyToolsListChanged(); - } - return Mono.empty(); - } - return Mono.error(new McpError("Tool with name '" + toolName + "' not found")); - }); + return this.delegate.removeTool(toolName); } /** @@ -358,34 +225,7 @@ public Mono removeTool(String toolName) { * @return A Mono that completes when all clients have been notified */ public Mono notifyToolsListChanged() { - return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_TOOLS_LIST_CHANGED, null); - } - - private DefaultMcpSession.RequestHandler toolsListRequestHandler() { - return params -> { - List tools = this.tools.stream().map(McpServerFeatures.AsyncToolRegistration::tool).toList(); - - return Mono.just(new McpSchema.ListToolsResult(tools, null)); - }; - } - - private DefaultMcpSession.RequestHandler toolsCallRequestHandler() { - return params -> { - McpSchema.CallToolRequest callToolRequest = transport.unmarshalFrom(params, - new TypeReference() { - }); - - Optional toolRegistration = this.tools.stream() - .filter(tr -> callToolRequest.name().equals(tr.tool().name())) - .findAny(); - - if (toolRegistration.isEmpty()) { - return Mono.error(new McpError("Tool not found: " + callToolRequest.name())); - } - - return toolRegistration.map(tool -> tool.call().apply(callToolRequest.arguments())) - .orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name()))); - }; + return this.delegate.notifyToolsListChanged(); } // --------------------------------------- @@ -396,27 +236,21 @@ private DefaultMcpSession.RequestHandler toolsCallRequestHandler * Add a new resource handler at runtime. * @param resourceHandler The resource handler to add * @return Mono that completes when clients have been notified of the change + * @deprecated This method will be removed in 0.9.0. Use + * {@link #addResource(McpServerFeatures.AsyncResourceSpecification)}. */ + @Deprecated public Mono addResource(McpServerFeatures.AsyncResourceRegistration resourceHandler) { - if (resourceHandler == null || resourceHandler.resource() == null) { - return Mono.error(new McpError("Resource must not be null")); - } - - if (this.serverCapabilities.resources() == null) { - return Mono.error(new McpError("Server must be configured with resource capabilities")); - } + return this.delegate.addResource(resourceHandler); + } - return Mono.defer(() -> { - if (this.resources.putIfAbsent(resourceHandler.resource().uri(), resourceHandler) != null) { - return Mono - .error(new McpError("Resource with URI '" + resourceHandler.resource().uri() + "' already exists")); - } - logger.debug("Added resource handler: {}", resourceHandler.resource().uri()); - if (this.serverCapabilities.resources().listChanged()) { - return notifyResourcesListChanged(); - } - return Mono.empty(); - }); + /** + * Add a new resource handler at runtime. + * @param resourceHandler The resource handler to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addResource(McpServerFeatures.AsyncResourceSpecification resourceHandler) { + return this.delegate.addResource(resourceHandler); } /** @@ -425,24 +259,7 @@ public Mono addResource(McpServerFeatures.AsyncResourceRegistration resour * @return Mono that completes when clients have been notified of the change */ public Mono removeResource(String resourceUri) { - if (resourceUri == null) { - return Mono.error(new McpError("Resource URI must not be null")); - } - if (this.serverCapabilities.resources() == null) { - return Mono.error(new McpError("Server must be configured with resource capabilities")); - } - - return Mono.defer(() -> { - McpServerFeatures.AsyncResourceRegistration removed = this.resources.remove(resourceUri); - if (removed != null) { - logger.debug("Removed resource handler: {}", resourceUri); - if (this.serverCapabilities.resources().listChanged()) { - return notifyResourcesListChanged(); - } - return Mono.empty(); - } - return Mono.error(new McpError("Resource with URI '" + resourceUri + "' not found")); - }); + return this.delegate.removeResource(resourceUri); } /** @@ -450,36 +267,7 @@ public Mono removeResource(String resourceUri) { * @return A Mono that completes when all clients have been notified */ public Mono notifyResourcesListChanged() { - return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED, null); - } - - private DefaultMcpSession.RequestHandler resourcesListRequestHandler() { - return params -> { - var resourceList = this.resources.values() - .stream() - .map(McpServerFeatures.AsyncResourceRegistration::resource) - .toList(); - return Mono.just(new McpSchema.ListResourcesResult(resourceList, null)); - }; - } - - private DefaultMcpSession.RequestHandler resourceTemplateListRequestHandler() { - return params -> Mono.just(new McpSchema.ListResourceTemplatesResult(this.resourceTemplates, null)); - - } - - private DefaultMcpSession.RequestHandler resourcesReadRequestHandler() { - return params -> { - McpSchema.ReadResourceRequest resourceRequest = transport.unmarshalFrom(params, - new TypeReference() { - }); - var resourceUri = resourceRequest.uri(); - McpServerFeatures.AsyncResourceRegistration registration = this.resources.get(resourceUri); - if (registration != null) { - return registration.readHandler().apply(resourceRequest); - } - return Mono.error(new McpError("Resource not found: " + resourceUri)); - }; + return this.delegate.notifyResourcesListChanged(); } // --------------------------------------- @@ -490,33 +278,21 @@ private DefaultMcpSession.RequestHandler resources * Add a new prompt handler at runtime. * @param promptRegistration The prompt handler to add * @return Mono that completes when clients have been notified of the change + * @deprecated This method will be removed in 0.9.0. Use + * {@link #addPrompt(McpServerFeatures.AsyncPromptSpecification)}. */ + @Deprecated public Mono addPrompt(McpServerFeatures.AsyncPromptRegistration promptRegistration) { - if (promptRegistration == null) { - return Mono.error(new McpError("Prompt registration must not be null")); - } - if (this.serverCapabilities.prompts() == null) { - return Mono.error(new McpError("Server must be configured with prompt capabilities")); - } - - return Mono.defer(() -> { - McpServerFeatures.AsyncPromptRegistration registration = this.prompts - .putIfAbsent(promptRegistration.prompt().name(), promptRegistration); - if (registration != null) { - return Mono.error( - new McpError("Prompt with name '" + promptRegistration.prompt().name() + "' already exists")); - } - - logger.debug("Added prompt handler: {}", promptRegistration.prompt().name()); + return this.delegate.addPrompt(promptRegistration); + } - // Servers that declared the listChanged capability SHOULD send a - // notification, - // when the list of available prompts changes - if (this.serverCapabilities.prompts().listChanged()) { - return notifyPromptsListChanged(); - } - return Mono.empty(); - }); + /** + * Add a new prompt handler at runtime. + * @param promptSpecification The prompt handler to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addPrompt(McpServerFeatures.AsyncPromptSpecification promptSpecification) { + return this.delegate.addPrompt(promptSpecification); } /** @@ -525,27 +301,7 @@ public Mono addPrompt(McpServerFeatures.AsyncPromptRegistration promptRegi * @return Mono that completes when clients have been notified of the change */ public Mono removePrompt(String promptName) { - if (promptName == null) { - return Mono.error(new McpError("Prompt name must not be null")); - } - if (this.serverCapabilities.prompts() == null) { - return Mono.error(new McpError("Server must be configured with prompt capabilities")); - } - - return Mono.defer(() -> { - McpServerFeatures.AsyncPromptRegistration removed = this.prompts.remove(promptName); - - if (removed != null) { - logger.debug("Removed prompt handler: {}", promptName); - // Servers that declared the listChanged capability SHOULD send a - // notification, when the list of available prompts changes - if (this.serverCapabilities.prompts().listChanged()) { - return this.notifyPromptsListChanged(); - } - return Mono.empty(); - } - return Mono.error(new McpError("Prompt with name '" + promptName + "' not found")); - }); + return this.delegate.removePrompt(promptName); } /** @@ -553,39 +309,7 @@ public Mono removePrompt(String promptName) { * @return A Mono that completes when all clients have been notified */ public Mono notifyPromptsListChanged() { - return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED, null); - } - - private DefaultMcpSession.RequestHandler promptsListRequestHandler() { - return params -> { - // TODO: Implement pagination - // McpSchema.PaginatedRequest request = transport.unmarshalFrom(params, - // new TypeReference() { - // }); - - var promptList = this.prompts.values() - .stream() - .map(McpServerFeatures.AsyncPromptRegistration::prompt) - .toList(); - - return Mono.just(new McpSchema.ListPromptsResult(promptList, null)); - }; - } - - private DefaultMcpSession.RequestHandler promptsGetRequestHandler() { - return params -> { - McpSchema.GetPromptRequest promptRequest = transport.unmarshalFrom(params, - new TypeReference() { - }); - - // Implement prompt retrieval logic here - McpServerFeatures.AsyncPromptRegistration registration = this.prompts.get(promptRequest.name()); - if (registration == null) { - return Mono.error(new McpError("Prompt not found: " + promptRequest.name())); - } - - return registration.promptHandler().apply(promptRequest); - }; + return this.delegate.notifyPromptsListChanged(); } // --------------------------------------- @@ -599,41 +323,12 @@ private DefaultMcpSession.RequestHandler promptsGetRe * @return A Mono that completes when the notification has been sent */ public Mono loggingNotification(LoggingMessageNotification loggingMessageNotification) { - - if (loggingMessageNotification == null) { - return Mono.error(new McpError("Logging message must not be null")); - } - - Map params = this.transport.unmarshalFrom(loggingMessageNotification, - new TypeReference>() { - }); - - if (loggingMessageNotification.level().level() < minLoggingLevel.level()) { - return Mono.empty(); - } - - return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_MESSAGE, params); - } - - /** - * Handles requests to set the minimum logging level. Messages below this level will - * not be sent. - * @return A handler that processes logging level change requests - */ - private DefaultMcpSession.RequestHandler setLoggerRequestHandler() { - return params -> { - this.minLoggingLevel = transport.unmarshalFrom(params, new TypeReference() { - }); - - return Mono.empty(); - }; + return this.delegate.loggingNotification(loggingMessageNotification); } // --------------------------------------- // Sampling // --------------------------------------- - private static final TypeReference CREATE_MESSAGE_RESULT_TYPE_REF = new TypeReference<>() { - }; /** * Create a new message using the sampling capabilities of the client. The Model @@ -653,17 +348,12 @@ private DefaultMcpSession.RequestHandler setLoggerRequestHandler() { * @see Sampling * Specification + * @deprecated This will be removed in 0.9.0. Use + * {@link McpAsyncServerExchange#createMessage(McpSchema.CreateMessageRequest)}. */ + @Deprecated public Mono createMessage(McpSchema.CreateMessageRequest createMessageRequest) { - - if (this.clientCapabilities == null) { - return Mono.error(new McpError("Client must be initialized. Call the initialize method first!")); - } - if (this.clientCapabilities.sampling() == null) { - return Mono.error(new McpError("Client must be configured with sampling capabilities")); - } - return this.mcpSession.sendRequest(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, createMessageRequest, - CREATE_MESSAGE_RESULT_TYPE_REF); + return this.delegate.createMessage(createMessageRequest); } /** @@ -672,7 +362,1148 @@ public Mono createMessage(McpSchema.CreateMessage * @param protocolVersions the Client supported protocol versions. */ void setProtocolVersions(List protocolVersions) { - this.protocolVersions = protocolVersions; + this.delegate.setProtocolVersions(protocolVersions); + } + + private static class AsyncServerImpl extends McpAsyncServer { + + private final McpServerTransportProvider mcpTransportProvider; + + private final ObjectMapper objectMapper; + + private final McpSchema.ServerCapabilities serverCapabilities; + + private final McpSchema.Implementation serverInfo; + + private final CopyOnWriteArrayList tools = new CopyOnWriteArrayList<>(); + + private final CopyOnWriteArrayList resourceTemplates = new CopyOnWriteArrayList<>(); + + private final ConcurrentHashMap resources = new ConcurrentHashMap<>(); + + private final ConcurrentHashMap prompts = new ConcurrentHashMap<>(); + + private LoggingLevel minLoggingLevel = LoggingLevel.DEBUG; + + private List protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION); + + AsyncServerImpl(McpServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper, + McpServerFeatures.Async features) { + this.mcpTransportProvider = mcpTransportProvider; + this.objectMapper = objectMapper; + this.serverInfo = features.serverInfo(); + this.serverCapabilities = features.serverCapabilities(); + this.tools.addAll(features.tools()); + this.resources.putAll(features.resources()); + this.resourceTemplates.addAll(features.resourceTemplates()); + this.prompts.putAll(features.prompts()); + + Map> requestHandlers = new HashMap<>(); + + // Initialize request handlers for standard MCP methods + + // Ping MUST respond with an empty data, but not NULL response. + requestHandlers.put(McpSchema.METHOD_PING, (exchange, params) -> Mono.just("")); + + // Add tools API handlers if the tool capability is enabled + if (this.serverCapabilities.tools() != null) { + requestHandlers.put(McpSchema.METHOD_TOOLS_LIST, toolsListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_TOOLS_CALL, toolsCallRequestHandler()); + } + + // Add resources API handlers if provided + if (this.serverCapabilities.resources() != null) { + requestHandlers.put(McpSchema.METHOD_RESOURCES_LIST, resourcesListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_RESOURCES_READ, resourcesReadRequestHandler()); + requestHandlers.put(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, resourceTemplateListRequestHandler()); + } + + // Add prompts API handlers if provider exists + if (this.serverCapabilities.prompts() != null) { + requestHandlers.put(McpSchema.METHOD_PROMPT_LIST, promptsListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_PROMPT_GET, promptsGetRequestHandler()); + } + + // Add logging API handlers if the logging capability is enabled + if (this.serverCapabilities.logging() != null) { + requestHandlers.put(McpSchema.METHOD_LOGGING_SET_LEVEL, setLoggerRequestHandler()); + } + + Map notificationHandlers = new HashMap<>(); + + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (exchange, params) -> Mono.empty()); + + List, Mono>> rootsChangeConsumers = features + .rootsChangeConsumers(); + + if (Utils.isEmpty(rootsChangeConsumers)) { + rootsChangeConsumers = List.of((exchange, + roots) -> Mono.fromRunnable(() -> logger.warn( + "Roots list changed notification, but no consumers provided. Roots list changed: {}", + roots))); + } + + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, + asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); + + mcpTransportProvider + .setSessionFactory(transport -> new McpServerSession(UUID.randomUUID().toString(), transport, + this::asyncInitializeRequestHandler, Mono::empty, requestHandlers, notificationHandlers)); + } + + // --------------------------------------- + // Lifecycle Management + // --------------------------------------- + private Mono asyncInitializeRequestHandler( + McpSchema.InitializeRequest initializeRequest) { + return Mono.defer(() -> { + logger.info("Client initialize request - Protocol: {}, Capabilities: {}, Info: {}", + initializeRequest.protocolVersion(), initializeRequest.capabilities(), + initializeRequest.clientInfo()); + + // The server MUST respond with the highest protocol version it supports + // if + // it does not support the requested (e.g. Client) version. + String serverProtocolVersion = this.protocolVersions.get(this.protocolVersions.size() - 1); + + if (this.protocolVersions.contains(initializeRequest.protocolVersion())) { + // If the server supports the requested protocol version, it MUST + // respond + // with the same version. + serverProtocolVersion = initializeRequest.protocolVersion(); + } + else { + logger.warn( + "Client requested unsupported protocol version: {}, so the server will sugggest the {} version instead", + initializeRequest.protocolVersion(), serverProtocolVersion); + } + + return Mono.just(new McpSchema.InitializeResult(serverProtocolVersion, this.serverCapabilities, + this.serverInfo, null)); + }); + } + + public McpSchema.ServerCapabilities getServerCapabilities() { + return this.serverCapabilities; + } + + public McpSchema.Implementation getServerInfo() { + return this.serverInfo; + } + + @Override + @Deprecated + public ClientCapabilities getClientCapabilities() { + throw new IllegalStateException("This method is deprecated and should not be called"); + } + + @Override + @Deprecated + public McpSchema.Implementation getClientInfo() { + throw new IllegalStateException("This method is deprecated and should not be called"); + } + + @Override + public Mono closeGracefully() { + return this.mcpTransportProvider.closeGracefully(); + } + + @Override + public void close() { + this.mcpTransportProvider.close(); + } + + @Override + @Deprecated + public Mono listRoots() { + return this.listRoots(null); + } + + @Override + @Deprecated + public Mono listRoots(String cursor) { + return Mono.error(new RuntimeException("Not implemented")); + } + + private McpServerSession.NotificationHandler asyncRootsListChangedNotificationHandler( + List, Mono>> rootsChangeConsumers) { + return (exchange, params) -> exchange.listRoots() + .flatMap(listRootsResult -> Flux.fromIterable(rootsChangeConsumers) + .flatMap(consumer -> consumer.apply(exchange, listRootsResult.roots())) + .onErrorResume(error -> { + logger.error("Error handling roots list change notification", error); + return Mono.empty(); + }) + .then()); + } + + // --------------------------------------- + // Tool Management + // --------------------------------------- + + @Override + public Mono addTool(McpServerFeatures.AsyncToolSpecification toolSpecification) { + if (toolSpecification == null) { + return Mono.error(new McpError("Tool specification must not be null")); + } + if (toolSpecification.tool() == null) { + return Mono.error(new McpError("Tool must not be null")); + } + if (toolSpecification.call() == null) { + return Mono.error(new McpError("Tool call handler must not be null")); + } + if (this.serverCapabilities.tools() == null) { + return Mono.error(new McpError("Server must be configured with tool capabilities")); + } + + return Mono.defer(() -> { + // Check for duplicate tool names + if (this.tools.stream().anyMatch(th -> th.tool().name().equals(toolSpecification.tool().name()))) { + return Mono + .error(new McpError("Tool with name '" + toolSpecification.tool().name() + "' already exists")); + } + + this.tools.add(toolSpecification); + logger.debug("Added tool handler: {}", toolSpecification.tool().name()); + + if (this.serverCapabilities.tools().listChanged()) { + return notifyToolsListChanged(); + } + return Mono.empty(); + }); + } + + @Override + public Mono addTool(McpServerFeatures.AsyncToolRegistration toolRegistration) { + return this.addTool(toolRegistration.toSpecification()); + } + + @Override + public Mono removeTool(String toolName) { + if (toolName == null) { + return Mono.error(new McpError("Tool name must not be null")); + } + if (this.serverCapabilities.tools() == null) { + return Mono.error(new McpError("Server must be configured with tool capabilities")); + } + + return Mono.defer(() -> { + boolean removed = this.tools + .removeIf(toolSpecification -> toolSpecification.tool().name().equals(toolName)); + if (removed) { + logger.debug("Removed tool handler: {}", toolName); + if (this.serverCapabilities.tools().listChanged()) { + return notifyToolsListChanged(); + } + return Mono.empty(); + } + return Mono.error(new McpError("Tool with name '" + toolName + "' not found")); + }); + } + + @Override + public Mono notifyToolsListChanged() { + return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_TOOLS_LIST_CHANGED, null); + } + + private McpServerSession.RequestHandler toolsListRequestHandler() { + return (exchange, params) -> { + List tools = this.tools.stream().map(McpServerFeatures.AsyncToolSpecification::tool).toList(); + + return Mono.just(new McpSchema.ListToolsResult(tools, null)); + }; + } + + private McpServerSession.RequestHandler toolsCallRequestHandler() { + return (exchange, params) -> { + McpSchema.CallToolRequest callToolRequest = objectMapper.convertValue(params, + new TypeReference() { + }); + + Optional toolSpecification = this.tools.stream() + .filter(tr -> callToolRequest.name().equals(tr.tool().name())) + .findAny(); + + if (toolSpecification.isEmpty()) { + return Mono.error(new McpError("Tool not found: " + callToolRequest.name())); + } + + return toolSpecification.map(tool -> tool.call().apply(exchange, callToolRequest.arguments())) + .orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name()))); + }; + } + + // --------------------------------------- + // Resource Management + // --------------------------------------- + + @Override + public Mono addResource(McpServerFeatures.AsyncResourceSpecification resourceSpecification) { + if (resourceSpecification == null || resourceSpecification.resource() == null) { + return Mono.error(new McpError("Resource must not be null")); + } + + if (this.serverCapabilities.resources() == null) { + return Mono.error(new McpError("Server must be configured with resource capabilities")); + } + + return Mono.defer(() -> { + if (this.resources.putIfAbsent(resourceSpecification.resource().uri(), resourceSpecification) != null) { + return Mono.error(new McpError( + "Resource with URI '" + resourceSpecification.resource().uri() + "' already exists")); + } + logger.debug("Added resource handler: {}", resourceSpecification.resource().uri()); + if (this.serverCapabilities.resources().listChanged()) { + return notifyResourcesListChanged(); + } + return Mono.empty(); + }); + } + + @Override + public Mono addResource(McpServerFeatures.AsyncResourceRegistration resourceHandler) { + return this.addResource(resourceHandler.toSpecification()); + } + + @Override + public Mono removeResource(String resourceUri) { + if (resourceUri == null) { + return Mono.error(new McpError("Resource URI must not be null")); + } + if (this.serverCapabilities.resources() == null) { + return Mono.error(new McpError("Server must be configured with resource capabilities")); + } + + return Mono.defer(() -> { + McpServerFeatures.AsyncResourceSpecification removed = this.resources.remove(resourceUri); + if (removed != null) { + logger.debug("Removed resource handler: {}", resourceUri); + if (this.serverCapabilities.resources().listChanged()) { + return notifyResourcesListChanged(); + } + return Mono.empty(); + } + return Mono.error(new McpError("Resource with URI '" + resourceUri + "' not found")); + }); + } + + @Override + public Mono notifyResourcesListChanged() { + return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED, null); + } + + private McpServerSession.RequestHandler resourcesListRequestHandler() { + return (exchange, params) -> { + var resourceList = this.resources.values() + .stream() + .map(McpServerFeatures.AsyncResourceSpecification::resource) + .toList(); + return Mono.just(new McpSchema.ListResourcesResult(resourceList, null)); + }; + } + + private McpServerSession.RequestHandler resourceTemplateListRequestHandler() { + return (exchange, params) -> Mono + .just(new McpSchema.ListResourceTemplatesResult(this.resourceTemplates, null)); + + } + + private McpServerSession.RequestHandler resourcesReadRequestHandler() { + return (exchange, params) -> { + McpSchema.ReadResourceRequest resourceRequest = objectMapper.convertValue(params, + new TypeReference() { + }); + var resourceUri = resourceRequest.uri(); + McpServerFeatures.AsyncResourceSpecification specification = this.resources.get(resourceUri); + if (specification != null) { + return specification.readHandler().apply(exchange, resourceRequest); + } + return Mono.error(new McpError("Resource not found: " + resourceUri)); + }; + } + + // --------------------------------------- + // Prompt Management + // --------------------------------------- + + @Override + public Mono addPrompt(McpServerFeatures.AsyncPromptSpecification promptSpecification) { + if (promptSpecification == null) { + return Mono.error(new McpError("Prompt specification must not be null")); + } + if (this.serverCapabilities.prompts() == null) { + return Mono.error(new McpError("Server must be configured with prompt capabilities")); + } + + return Mono.defer(() -> { + McpServerFeatures.AsyncPromptSpecification specification = this.prompts + .putIfAbsent(promptSpecification.prompt().name(), promptSpecification); + if (specification != null) { + return Mono.error(new McpError( + "Prompt with name '" + promptSpecification.prompt().name() + "' already exists")); + } + + logger.debug("Added prompt handler: {}", promptSpecification.prompt().name()); + + // Servers that declared the listChanged capability SHOULD send a + // notification, + // when the list of available prompts changes + if (this.serverCapabilities.prompts().listChanged()) { + return notifyPromptsListChanged(); + } + return Mono.empty(); + }); + } + + @Override + public Mono addPrompt(McpServerFeatures.AsyncPromptRegistration promptRegistration) { + return this.addPrompt(promptRegistration.toSpecification()); + } + + @Override + public Mono removePrompt(String promptName) { + if (promptName == null) { + return Mono.error(new McpError("Prompt name must not be null")); + } + if (this.serverCapabilities.prompts() == null) { + return Mono.error(new McpError("Server must be configured with prompt capabilities")); + } + + return Mono.defer(() -> { + McpServerFeatures.AsyncPromptSpecification removed = this.prompts.remove(promptName); + + if (removed != null) { + logger.debug("Removed prompt handler: {}", promptName); + // Servers that declared the listChanged capability SHOULD send a + // notification, when the list of available prompts changes + if (this.serverCapabilities.prompts().listChanged()) { + return this.notifyPromptsListChanged(); + } + return Mono.empty(); + } + return Mono.error(new McpError("Prompt with name '" + promptName + "' not found")); + }); + } + + @Override + public Mono notifyPromptsListChanged() { + return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED, null); + } + + private McpServerSession.RequestHandler promptsListRequestHandler() { + return (exchange, params) -> { + // TODO: Implement pagination + // McpSchema.PaginatedRequest request = objectMapper.convertValue(params, + // new TypeReference() { + // }); + + var promptList = this.prompts.values() + .stream() + .map(McpServerFeatures.AsyncPromptSpecification::prompt) + .toList(); + + return Mono.just(new McpSchema.ListPromptsResult(promptList, null)); + }; + } + + private McpServerSession.RequestHandler promptsGetRequestHandler() { + return (exchange, params) -> { + McpSchema.GetPromptRequest promptRequest = objectMapper.convertValue(params, + new TypeReference() { + }); + + // Implement prompt retrieval logic here + McpServerFeatures.AsyncPromptSpecification specification = this.prompts.get(promptRequest.name()); + if (specification == null) { + return Mono.error(new McpError("Prompt not found: " + promptRequest.name())); + } + + return specification.promptHandler().apply(exchange, promptRequest); + }; + } + + // --------------------------------------- + // Logging Management + // --------------------------------------- + + @Override + public Mono loggingNotification(LoggingMessageNotification loggingMessageNotification) { + + if (loggingMessageNotification == null) { + return Mono.error(new McpError("Logging message must not be null")); + } + + Map params = this.objectMapper.convertValue(loggingMessageNotification, + new TypeReference>() { + }); + + if (loggingMessageNotification.level().level() < minLoggingLevel.level()) { + return Mono.empty(); + } + + return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_MESSAGE, params); + } + + private McpServerSession.RequestHandler setLoggerRequestHandler() { + return (exchange, params) -> { + this.minLoggingLevel = objectMapper.convertValue(params, new TypeReference() { + }); + + return Mono.empty(); + }; + } + + // --------------------------------------- + // Sampling + // --------------------------------------- + + @Override + @Deprecated + public Mono createMessage(McpSchema.CreateMessageRequest createMessageRequest) { + return Mono.error(new RuntimeException("Not implemented")); + } + + @Override + void setProtocolVersions(List protocolVersions) { + this.protocolVersions = protocolVersions; + } + + } + + private static final class LegacyAsyncServer extends McpAsyncServer { + + /** + * The MCP session implementation that manages bidirectional JSON-RPC + * communication between clients and servers. + */ + private final McpClientSession mcpSession; + + private final ServerMcpTransport transport; + + private final McpSchema.ServerCapabilities serverCapabilities; + + private final McpSchema.Implementation serverInfo; + + private McpSchema.ClientCapabilities clientCapabilities; + + private McpSchema.Implementation clientInfo; + + /** + * Thread-safe list of tool handlers that can be modified at runtime. + */ + private final CopyOnWriteArrayList tools = new CopyOnWriteArrayList<>(); + + private final CopyOnWriteArrayList resourceTemplates = new CopyOnWriteArrayList<>(); + + private final ConcurrentHashMap resources = new ConcurrentHashMap<>(); + + private final ConcurrentHashMap prompts = new ConcurrentHashMap<>(); + + private LoggingLevel minLoggingLevel = LoggingLevel.DEBUG; + + /** + * Supported protocol versions. + */ + private List protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION); + + /** + * Create a new McpAsyncServer with the given transport and capabilities. + * @param mcpTransport The transport layer implementation for MCP communication. + * @param features The MCP server supported features. + */ + LegacyAsyncServer(ServerMcpTransport mcpTransport, McpServerFeatures.Async features) { + + this.serverInfo = features.serverInfo(); + this.serverCapabilities = features.serverCapabilities(); + this.tools.addAll(features.tools()); + this.resources.putAll(features.resources()); + this.resourceTemplates.addAll(features.resourceTemplates()); + this.prompts.putAll(features.prompts()); + + Map> requestHandlers = new HashMap<>(); + + // Initialize request handlers for standard MCP methods + requestHandlers.put(McpSchema.METHOD_INITIALIZE, asyncInitializeRequestHandler()); + + // Ping MUST respond with an empty data, but not NULL response. + requestHandlers.put(McpSchema.METHOD_PING, (params) -> Mono.just("")); + + // Add tools API handlers if the tool capability is enabled + if (this.serverCapabilities.tools() != null) { + requestHandlers.put(McpSchema.METHOD_TOOLS_LIST, toolsListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_TOOLS_CALL, toolsCallRequestHandler()); + } + + // Add resources API handlers if provided + if (this.serverCapabilities.resources() != null) { + requestHandlers.put(McpSchema.METHOD_RESOURCES_LIST, resourcesListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_RESOURCES_READ, resourcesReadRequestHandler()); + requestHandlers.put(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, resourceTemplateListRequestHandler()); + } + + // Add prompts API handlers if provider exists + if (this.serverCapabilities.prompts() != null) { + requestHandlers.put(McpSchema.METHOD_PROMPT_LIST, promptsListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_PROMPT_GET, promptsGetRequestHandler()); + } + + // Add logging API handlers if the logging capability is enabled + if (this.serverCapabilities.logging() != null) { + requestHandlers.put(McpSchema.METHOD_LOGGING_SET_LEVEL, setLoggerRequestHandler()); + } + + Map notificationHandlers = new HashMap<>(); + + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (params) -> Mono.empty()); + + List, Mono>> rootsChangeHandlers = features + .rootsChangeConsumers(); + + List, Mono>> rootsChangeConsumers = rootsChangeHandlers.stream() + .map(handler -> (Function, Mono>) (roots) -> handler.apply(null, roots)) + .toList(); + + if (Utils.isEmpty(rootsChangeConsumers)) { + rootsChangeConsumers = List.of((roots) -> Mono.fromRunnable(() -> logger.warn( + "Roots list changed notification, but no consumers provided. Roots list changed: {}", roots))); + } + + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, + asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); + + this.transport = mcpTransport; + this.mcpSession = new McpClientSession(Duration.ofSeconds(10), mcpTransport, requestHandlers, + notificationHandlers); + } + + @Override + public Mono addTool(McpServerFeatures.AsyncToolSpecification toolSpecification) { + throw new IllegalArgumentException( + "McpAsyncServer configured with legacy " + "transport. Use McpServerTransportProvider instead."); + } + + @Override + public Mono addResource(McpServerFeatures.AsyncResourceSpecification resourceHandler) { + throw new IllegalArgumentException( + "McpAsyncServer configured with legacy " + "transport. Use McpServerTransportProvider instead."); + } + + @Override + public Mono addPrompt(McpServerFeatures.AsyncPromptSpecification promptSpecification) { + throw new IllegalArgumentException( + "McpAsyncServer configured with legacy " + "transport. Use McpServerTransportProvider instead."); + } + + // --------------------------------------- + // Lifecycle Management + // --------------------------------------- + private McpClientSession.RequestHandler asyncInitializeRequestHandler() { + return params -> { + McpSchema.InitializeRequest initializeRequest = transport.unmarshalFrom(params, + new TypeReference() { + }); + this.clientCapabilities = initializeRequest.capabilities(); + this.clientInfo = initializeRequest.clientInfo(); + logger.info("Client initialize request - Protocol: {}, Capabilities: {}, Info: {}", + initializeRequest.protocolVersion(), initializeRequest.capabilities(), + initializeRequest.clientInfo()); + + // The server MUST respond with the highest protocol version it supports + // if + // it does not support the requested (e.g. Client) version. + String serverProtocolVersion = this.protocolVersions.get(this.protocolVersions.size() - 1); + + if (this.protocolVersions.contains(initializeRequest.protocolVersion())) { + // If the server supports the requested protocol version, it MUST + // respond + // with the same version. + serverProtocolVersion = initializeRequest.protocolVersion(); + } + else { + logger.warn( + "Client requested unsupported protocol version: {}, so the server will sugggest the {} version instead", + initializeRequest.protocolVersion(), serverProtocolVersion); + } + + return Mono.just(new McpSchema.InitializeResult(serverProtocolVersion, this.serverCapabilities, + this.serverInfo, null)); + }; + } + + /** + * Get the server capabilities that define the supported features and + * functionality. + * @return The server capabilities + */ + public McpSchema.ServerCapabilities getServerCapabilities() { + return this.serverCapabilities; + } + + /** + * Get the server implementation information. + * @return The server implementation details + */ + public McpSchema.Implementation getServerInfo() { + return this.serverInfo; + } + + /** + * Get the client capabilities that define the supported features and + * functionality. + * @return The client capabilities + */ + public ClientCapabilities getClientCapabilities() { + return this.clientCapabilities; + } + + /** + * Get the client implementation information. + * @return The client implementation details + */ + public McpSchema.Implementation getClientInfo() { + return this.clientInfo; + } + + /** + * Gracefully closes the server, allowing any in-progress operations to complete. + * @return A Mono that completes when the server has been closed + */ + public Mono closeGracefully() { + return this.mcpSession.closeGracefully(); + } + + /** + * Close the server immediately. + */ + public void close() { + this.mcpSession.close(); + } + + private static final TypeReference LIST_ROOTS_RESULT_TYPE_REF = new TypeReference<>() { + }; + + /** + * Retrieves the list of all roots provided by the client. + * @return A Mono that emits the list of roots result. + */ + public Mono listRoots() { + return this.listRoots(null); + } + + /** + * Retrieves a paginated list of roots provided by the server. + * @param cursor Optional pagination cursor from a previous list request + * @return A Mono that emits the list of roots result containing + */ + public Mono listRoots(String cursor) { + return this.mcpSession.sendRequest(McpSchema.METHOD_ROOTS_LIST, new McpSchema.PaginatedRequest(cursor), + LIST_ROOTS_RESULT_TYPE_REF); + } + + private McpClientSession.NotificationHandler asyncRootsListChangedNotificationHandler( + List, Mono>> rootsChangeConsumers) { + return params -> listRoots().flatMap(listRootsResult -> Flux.fromIterable(rootsChangeConsumers) + .flatMap(consumer -> consumer.apply(listRootsResult.roots())) + .onErrorResume(error -> { + logger.error("Error handling roots list change notification", error); + return Mono.empty(); + }) + .then()); + } + + // --------------------------------------- + // Tool Management + // --------------------------------------- + + /** + * Add a new tool registration at runtime. + * @param toolRegistration The tool registration to add + * @return Mono that completes when clients have been notified of the change + */ + @Override + public Mono addTool(McpServerFeatures.AsyncToolRegistration toolRegistration) { + if (toolRegistration == null) { + return Mono.error(new McpError("Tool registration must not be null")); + } + if (toolRegistration.tool() == null) { + return Mono.error(new McpError("Tool must not be null")); + } + if (toolRegistration.call() == null) { + return Mono.error(new McpError("Tool call handler must not be null")); + } + if (this.serverCapabilities.tools() == null) { + return Mono.error(new McpError("Server must be configured with tool capabilities")); + } + + return Mono.defer(() -> { + // Check for duplicate tool names + if (this.tools.stream().anyMatch(th -> th.tool().name().equals(toolRegistration.tool().name()))) { + return Mono + .error(new McpError("Tool with name '" + toolRegistration.tool().name() + "' already exists")); + } + + this.tools.add(toolRegistration.toSpecification()); + logger.debug("Added tool handler: {}", toolRegistration.tool().name()); + + if (this.serverCapabilities.tools().listChanged()) { + return notifyToolsListChanged(); + } + return Mono.empty(); + }); + } + + /** + * Remove a tool handler at runtime. + * @param toolName The name of the tool handler to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removeTool(String toolName) { + if (toolName == null) { + return Mono.error(new McpError("Tool name must not be null")); + } + if (this.serverCapabilities.tools() == null) { + return Mono.error(new McpError("Server must be configured with tool capabilities")); + } + + return Mono.defer(() -> { + boolean removed = this.tools + .removeIf(toolRegistration -> toolRegistration.tool().name().equals(toolName)); + if (removed) { + logger.debug("Removed tool handler: {}", toolName); + if (this.serverCapabilities.tools().listChanged()) { + return notifyToolsListChanged(); + } + return Mono.empty(); + } + return Mono.error(new McpError("Tool with name '" + toolName + "' not found")); + }); + } + + /** + * Notifies clients that the list of available tools has changed. + * @return A Mono that completes when all clients have been notified + */ + public Mono notifyToolsListChanged() { + return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_TOOLS_LIST_CHANGED, null); + } + + private McpClientSession.RequestHandler toolsListRequestHandler() { + return params -> { + List tools = this.tools.stream().map(McpServerFeatures.AsyncToolSpecification::tool).toList(); + + return Mono.just(new McpSchema.ListToolsResult(tools, null)); + }; + } + + private McpClientSession.RequestHandler toolsCallRequestHandler() { + return params -> { + McpSchema.CallToolRequest callToolRequest = transport.unmarshalFrom(params, + new TypeReference() { + }); + + Optional toolRegistration = this.tools.stream() + .filter(tr -> callToolRequest.name().equals(tr.tool().name())) + .findAny(); + + if (toolRegistration.isEmpty()) { + return Mono.error(new McpError("Tool not found: " + callToolRequest.name())); + } + + return toolRegistration.map(tool -> tool.call().apply(null, callToolRequest.arguments())) + .orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name()))); + }; + } + + // --------------------------------------- + // Resource Management + // --------------------------------------- + + /** + * Add a new resource handler at runtime. + * @param resourceHandler The resource handler to add + * @return Mono that completes when clients have been notified of the change + */ + @Override + public Mono addResource(McpServerFeatures.AsyncResourceRegistration resourceHandler) { + if (resourceHandler == null || resourceHandler.resource() == null) { + return Mono.error(new McpError("Resource must not be null")); + } + + if (this.serverCapabilities.resources() == null) { + return Mono.error(new McpError("Server must be configured with resource capabilities")); + } + + return Mono.defer(() -> { + if (this.resources.putIfAbsent(resourceHandler.resource().uri(), + resourceHandler.toSpecification()) != null) { + return Mono.error(new McpError( + "Resource with URI '" + resourceHandler.resource().uri() + "' already exists")); + } + logger.debug("Added resource handler: {}", resourceHandler.resource().uri()); + if (this.serverCapabilities.resources().listChanged()) { + return notifyResourcesListChanged(); + } + return Mono.empty(); + }); + } + + /** + * Remove a resource handler at runtime. + * @param resourceUri The URI of the resource handler to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removeResource(String resourceUri) { + if (resourceUri == null) { + return Mono.error(new McpError("Resource URI must not be null")); + } + if (this.serverCapabilities.resources() == null) { + return Mono.error(new McpError("Server must be configured with resource capabilities")); + } + + return Mono.defer(() -> { + McpServerFeatures.AsyncResourceSpecification removed = this.resources.remove(resourceUri); + if (removed != null) { + logger.debug("Removed resource handler: {}", resourceUri); + if (this.serverCapabilities.resources().listChanged()) { + return notifyResourcesListChanged(); + } + return Mono.empty(); + } + return Mono.error(new McpError("Resource with URI '" + resourceUri + "' not found")); + }); + } + + /** + * Notifies clients that the list of available resources has changed. + * @return A Mono that completes when all clients have been notified + */ + public Mono notifyResourcesListChanged() { + return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED, null); + } + + private McpClientSession.RequestHandler resourcesListRequestHandler() { + return params -> { + var resourceList = this.resources.values() + .stream() + .map(McpServerFeatures.AsyncResourceSpecification::resource) + .toList(); + return Mono.just(new McpSchema.ListResourcesResult(resourceList, null)); + }; + } + + private McpClientSession.RequestHandler resourceTemplateListRequestHandler() { + return params -> Mono.just(new McpSchema.ListResourceTemplatesResult(this.resourceTemplates, null)); + + } + + private McpClientSession.RequestHandler resourcesReadRequestHandler() { + return params -> { + McpSchema.ReadResourceRequest resourceRequest = transport.unmarshalFrom(params, + new TypeReference() { + }); + var resourceUri = resourceRequest.uri(); + McpServerFeatures.AsyncResourceSpecification registration = this.resources.get(resourceUri); + if (registration != null) { + return registration.readHandler().apply(null, resourceRequest); + } + return Mono.error(new McpError("Resource not found: " + resourceUri)); + }; + } + + // --------------------------------------- + // Prompt Management + // --------------------------------------- + + /** + * Add a new prompt handler at runtime. + * @param promptRegistration The prompt handler to add + * @return Mono that completes when clients have been notified of the change + */ + @Override + public Mono addPrompt(McpServerFeatures.AsyncPromptRegistration promptRegistration) { + if (promptRegistration == null) { + return Mono.error(new McpError("Prompt registration must not be null")); + } + if (this.serverCapabilities.prompts() == null) { + return Mono.error(new McpError("Server must be configured with prompt capabilities")); + } + + return Mono.defer(() -> { + McpServerFeatures.AsyncPromptSpecification registration = this.prompts + .putIfAbsent(promptRegistration.prompt().name(), promptRegistration.toSpecification()); + if (registration != null) { + return Mono.error(new McpError( + "Prompt with name '" + promptRegistration.prompt().name() + "' already exists")); + } + + logger.debug("Added prompt handler: {}", promptRegistration.prompt().name()); + + // Servers that declared the listChanged capability SHOULD send a + // notification, + // when the list of available prompts changes + if (this.serverCapabilities.prompts().listChanged()) { + return notifyPromptsListChanged(); + } + return Mono.empty(); + }); + } + + /** + * Remove a prompt handler at runtime. + * @param promptName The name of the prompt handler to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removePrompt(String promptName) { + if (promptName == null) { + return Mono.error(new McpError("Prompt name must not be null")); + } + if (this.serverCapabilities.prompts() == null) { + return Mono.error(new McpError("Server must be configured with prompt capabilities")); + } + + return Mono.defer(() -> { + McpServerFeatures.AsyncPromptSpecification removed = this.prompts.remove(promptName); + + if (removed != null) { + logger.debug("Removed prompt handler: {}", promptName); + // Servers that declared the listChanged capability SHOULD send a + // notification, when the list of available prompts changes + if (this.serverCapabilities.prompts().listChanged()) { + return this.notifyPromptsListChanged(); + } + return Mono.empty(); + } + return Mono.error(new McpError("Prompt with name '" + promptName + "' not found")); + }); + } + + /** + * Notifies clients that the list of available prompts has changed. + * @return A Mono that completes when all clients have been notified + */ + public Mono notifyPromptsListChanged() { + return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED, null); + } + + private McpClientSession.RequestHandler promptsListRequestHandler() { + return params -> { + // TODO: Implement pagination + // McpSchema.PaginatedRequest request = transport.unmarshalFrom(params, + // new TypeReference() { + // }); + + var promptList = this.prompts.values() + .stream() + .map(McpServerFeatures.AsyncPromptSpecification::prompt) + .toList(); + + return Mono.just(new McpSchema.ListPromptsResult(promptList, null)); + }; + } + + private McpClientSession.RequestHandler promptsGetRequestHandler() { + return params -> { + McpSchema.GetPromptRequest promptRequest = transport.unmarshalFrom(params, + new TypeReference() { + }); + + // Implement prompt retrieval logic here + McpServerFeatures.AsyncPromptSpecification registration = this.prompts.get(promptRequest.name()); + if (registration == null) { + return Mono.error(new McpError("Prompt not found: " + promptRequest.name())); + } + + return registration.promptHandler().apply(null, promptRequest); + }; + } + + // --------------------------------------- + // Logging Management + // --------------------------------------- + + /** + * Send a logging message notification to all connected clients. Messages below + * the current minimum logging level will be filtered out. + * @param loggingMessageNotification The logging message to send + * @return A Mono that completes when the notification has been sent + */ + public Mono loggingNotification(LoggingMessageNotification loggingMessageNotification) { + + if (loggingMessageNotification == null) { + return Mono.error(new McpError("Logging message must not be null")); + } + + Map params = this.transport.unmarshalFrom(loggingMessageNotification, + new TypeReference>() { + }); + + if (loggingMessageNotification.level().level() < minLoggingLevel.level()) { + return Mono.empty(); + } + + return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_MESSAGE, params); + } + + /** + * Handles requests to set the minimum logging level. Messages below this level + * will not be sent. + * @return A handler that processes logging level change requests + */ + private McpClientSession.RequestHandler setLoggerRequestHandler() { + return params -> { + this.minLoggingLevel = transport.unmarshalFrom(params, new TypeReference() { + }); + + return Mono.empty(); + }; + } + + // --------------------------------------- + // Sampling + // --------------------------------------- + private static final TypeReference CREATE_MESSAGE_RESULT_TYPE_REF = new TypeReference<>() { + }; + + /** + * Create a new message using the sampling capabilities of the client. The Model + * Context Protocol (MCP) provides a standardized way for servers to request LLM + * sampling (“completions” or “generations”) from language models via clients. + * This flow allows clients to maintain control over model access, selection, and + * permissions while enabling servers to leverage AI capabilities—with no server + * API keys necessary. Servers can request text or image-based interactions and + * optionally include context from MCP servers in their prompts. + * @param createMessageRequest The request to create a new message + * @return A Mono that completes when the message has been created + * @throws McpError if the client has not been initialized or does not support + * sampling capabilities + * @throws McpError if the client does not support the createMessage method + * @see McpSchema.CreateMessageRequest + * @see McpSchema.CreateMessageResult + * @see Sampling + * Specification + */ + public Mono createMessage(McpSchema.CreateMessageRequest createMessageRequest) { + + if (this.clientCapabilities == null) { + return Mono.error(new McpError("Client must be initialized. Call the initialize method first!")); + } + if (this.clientCapabilities.sampling() == null) { + return Mono.error(new McpError("Client must be configured with sampling capabilities")); + } + return this.mcpSession.sendRequest(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, createMessageRequest, + CREATE_MESSAGE_RESULT_TYPE_REF); + } + + /** + * This method is package-private and used for test only. Should not be called by + * user code. + * @param protocolVersions the Client supported protocol versions. + */ + void setProtocolVersions(List protocolVersions) { + this.protocolVersions = protocolVersions; + } + } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java new file mode 100644 index 00000000..65862844 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java @@ -0,0 +1,104 @@ +package io.modelcontextprotocol.server; + +import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerSession; +import reactor.core.publisher.Mono; + +/** + * Represents an asynchronous exchange with a Model Context Protocol (MCP) client. The + * exchange provides methods to interact with the client and query its capabilities. + * + * @author Dariusz Jędrzejczyk + */ +public class McpAsyncServerExchange { + + private final McpServerSession session; + + private final McpSchema.ClientCapabilities clientCapabilities; + + private final McpSchema.Implementation clientInfo; + + private static final TypeReference CREATE_MESSAGE_RESULT_TYPE_REF = new TypeReference<>() { + }; + + private static final TypeReference LIST_ROOTS_RESULT_TYPE_REF = new TypeReference<>() { + }; + + /** + * Create a new asynchronous exchange with the client. + * @param session The server session representing a 1-1 interaction. + * @param clientCapabilities The client capabilities that define the supported + * features and functionality. + * @param clientInfo The client implementation information. + */ + public McpAsyncServerExchange(McpServerSession session, McpSchema.ClientCapabilities clientCapabilities, + McpSchema.Implementation clientInfo) { + this.session = session; + this.clientCapabilities = clientCapabilities; + this.clientInfo = clientInfo; + } + + /** + * Get the client capabilities that define the supported features and functionality. + * @return The client capabilities + */ + public McpSchema.ClientCapabilities getClientCapabilities() { + return this.clientCapabilities; + } + + /** + * Get the client implementation information. + * @return The client implementation details + */ + public McpSchema.Implementation getClientInfo() { + return this.clientInfo; + } + + /** + * Create a new message using the sampling capabilities of the client. The Model + * Context Protocol (MCP) provides a standardized way for servers to request LLM + * sampling (“completions” or “generations”) from language models via clients. This + * flow allows clients to maintain control over model access, selection, and + * permissions while enabling servers to leverage AI capabilities—with no server API + * keys necessary. Servers can request text or image-based interactions and optionally + * include context from MCP servers in their prompts. + * @param createMessageRequest The request to create a new message + * @return A Mono that completes when the message has been created + * @see McpSchema.CreateMessageRequest + * @see McpSchema.CreateMessageResult + * @see Sampling + * Specification + */ + public Mono createMessage(McpSchema.CreateMessageRequest createMessageRequest) { + if (this.clientCapabilities == null) { + return Mono.error(new McpError("Client must be initialized. Call the initialize method first!")); + } + if (this.clientCapabilities.sampling() == null) { + return Mono.error(new McpError("Client must be configured with sampling capabilities")); + } + return this.session.sendRequest(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, createMessageRequest, + CREATE_MESSAGE_RESULT_TYPE_REF); + } + + /** + * Retrieves the list of all roots provided by the client. + * @return A Mono that emits the list of roots result. + */ + public Mono listRoots() { + return this.listRoots(null); + } + + /** + * Retrieves a paginated list of roots provided by the client. + * @param cursor Optional pagination cursor from a previous list request + * @return A Mono that emits the list of roots result containing + */ + public Mono listRoots(String cursor) { + return this.session.sendRequest(McpSchema.METHOD_ROOTS_LIST, new McpSchema.PaginatedRequest(cursor), + LIST_ROOTS_RESULT_TYPE_REF); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java index 54c7a28f..d8dfcb01 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java @@ -5,14 +5,19 @@ package io.modelcontextprotocol.server; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.BiConsumer; +import java.util.function.BiFunction; import java.util.function.Consumer; import java.util.function.Function; +import java.util.stream.Collectors; +import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpTransport; +import io.modelcontextprotocol.spec.McpServerTransportProvider; import io.modelcontextprotocol.spec.ServerMcpTransport; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate; @@ -49,45 +54,50 @@ *

    * The class provides factory methods to create either: *

      - *
    • {@link McpAsyncServer} for non-blocking operations with CompletableFuture responses + *
    • {@link McpAsyncServer} for non-blocking operations with reactive responses *
    • {@link McpSyncServer} for blocking operations with direct responses *
    * *

    * Example of creating a basic synchronous server:

    {@code
    - * McpServer.sync(transport)
    + * McpServer.sync(transportProvider)
      *     .serverInfo("my-server", "1.0.0")
      *     .tool(new Tool("calculator", "Performs calculations", schema),
    - *           args -> new CallToolResult("Result: " + calculate(args)))
    + *           (exchange, args) -> new CallToolResult("Result: " + calculate(args)))
      *     .build();
      * }
    * * Example of creating a basic asynchronous server:
    {@code
    - * McpServer.async(transport)
    + * McpServer.async(transportProvider)
      *     .serverInfo("my-server", "1.0.0")
      *     .tool(new Tool("calculator", "Performs calculations", schema),
    - *           args -> Mono.just(new CallToolResult("Result: " + calculate(args))))
    + *           (exchange, args) -> Mono.fromSupplier(() -> calculate(args))
    + *               .map(result -> new CallToolResult("Result: " + result)))
      *     .build();
      * }
    * *

    * Example with comprehensive asynchronous configuration:

    {@code
    - * McpServer.async(transport)
    + * McpServer.async(transportProvider)
      *     .serverInfo("advanced-server", "2.0.0")
      *     .capabilities(new ServerCapabilities(...))
      *     // Register tools
      *     .tools(
    - *         new McpServerFeatures.AsyncToolRegistration(calculatorTool,
    - *             args -> Mono.just(new CallToolResult("Result: " + calculate(args)))),
    - *         new McpServerFeatures.AsyncToolRegistration(weatherTool,
    - *             args -> Mono.just(new CallToolResult("Weather: " + getWeather(args))))
    + *         new McpServerFeatures.AsyncToolSpecification(calculatorTool,
    + *             (exchange, args) -> Mono.fromSupplier(() -> calculate(args))
    + *                 .map(result -> new CallToolResult("Result: " + result))),
    + *         new McpServerFeatures.AsyncToolSpecification(weatherTool,
    + *             (exchange, args) -> Mono.fromSupplier(() -> getWeather(args))
    + *                 .map(result -> new CallToolResult("Weather: " + result)))
      *     )
      *     // Register resources
      *     .resources(
    - *         new McpServerFeatures.AsyncResourceRegistration(fileResource,
    - *             req -> Mono.just(new ReadResourceResult(readFile(req)))),
    - *         new McpServerFeatures.AsyncResourceRegistration(dbResource,
    - *             req -> Mono.just(new ReadResourceResult(queryDb(req))))
    + *         new McpServerFeatures.AsyncResourceSpecification(fileResource,
    + *             (exchange, req) -> Mono.fromSupplier(() -> readFile(req))
    + *                 .map(ReadResourceResult::new)),
    + *         new McpServerFeatures.AsyncResourceSpecification(dbResource,
    + *             (exchange, req) -> Mono.fromSupplier(() -> queryDb(req))
    + *                 .map(ReadResourceResult::new))
      *     )
      *     // Add resource templates
      *     .resourceTemplates(
    @@ -96,10 +106,12 @@
      *     )
      *     // Register prompts
      *     .prompts(
    - *         new McpServerFeatures.AsyncPromptRegistration(analysisPrompt,
    - *             req -> Mono.just(new GetPromptResult(generateAnalysisPrompt(req)))),
    + *         new McpServerFeatures.AsyncPromptSpecification(analysisPrompt,
    + *             (exchange, req) -> Mono.fromSupplier(() -> generateAnalysisPrompt(req))
    + *                 .map(GetPromptResult::new)),
      *         new McpServerFeatures.AsyncPromptRegistration(summaryPrompt,
    - *             req -> Mono.just(new GetPromptResult(generateSummaryPrompt(req))))
    + *             (exchange, req) -> Mono.fromSupplier(() -> generateSummaryPrompt(req))
    + *                 .map(GetPromptResult::new))
      *     )
      *     .build();
      * }
    @@ -108,37 +120,896 @@ * @author Dariusz Jędrzejczyk * @see McpAsyncServer * @see McpSyncServer - * @see McpTransport + * @see McpServerTransportProvider */ public interface McpServer { /** * Starts building a synchronous MCP server that provides blocking operations. - * Synchronous servers process each request to completion before handling the next - * one, making them simpler to implement but potentially less performant for - * concurrent operations. + * Synchronous servers block the current Thread's execution upon each request before + * giving the control back to the caller, making them simpler to implement but + * potentially less scalable for concurrent operations. + * @param transportProvider The transport layer implementation for MCP communication. + * @return A new instance of {@link SyncSpecification} for configuring the server. + */ + static SyncSpecification sync(McpServerTransportProvider transportProvider) { + return new SyncSpecification(transportProvider); + } + + /** + * Starts building a synchronous MCP server that provides blocking operations. + * Synchronous servers block the current Thread's execution upon each request before + * giving the control back to the caller, making them simpler to implement but + * potentially less scalable for concurrent operations. * @param transport The transport layer implementation for MCP communication * @return A new instance of {@link SyncSpec} for configuring the server. + * @deprecated This method will be removed in 0.9.0. Use + * {@link #sync(McpServerTransportProvider)} instead. */ + @Deprecated static SyncSpec sync(ServerMcpTransport transport) { return new SyncSpec(transport); } - /** - * Starts building an asynchronous MCP server that provides blocking operations. - * Asynchronous servers can handle multiple requests concurrently using a functional - * paradigm with non-blocking server transports, making them more efficient for - * high-concurrency scenarios but more complex to implement. - * @param transport The transport layer implementation for MCP communication - * @return A new instance of {@link SyncSpec} for configuring the server. - */ - static AsyncSpec async(ServerMcpTransport transport) { - return new AsyncSpec(transport); + /** + * Starts building an asynchronous MCP server that provides non-blocking operations. + * Asynchronous servers can handle multiple requests concurrently on a single Thread + * using a functional paradigm with non-blocking server transports, making them more + * scalable for high-concurrency scenarios but more complex to implement. + * @param transportProvider The transport layer implementation for MCP communication. + * @return A new instance of {@link AsyncSpecification} for configuring the server. + */ + static AsyncSpecification async(McpServerTransportProvider transportProvider) { + return new AsyncSpecification(transportProvider); + } + + /** + * Starts building an asynchronous MCP server that provides non-blocking operations. + * Asynchronous servers can handle multiple requests concurrently on a single Thread + * using a functional paradigm with non-blocking server transports, making them more + * scalable for high-concurrency scenarios but more complex to implement. + * @param transport The transport layer implementation for MCP communication + * @return A new instance of {@link AsyncSpec} for configuring the server. + * @deprecated This method will be removed in 0.9.0. Use + * {@link #async(McpServerTransportProvider)} instead. + */ + @Deprecated + static AsyncSpec async(ServerMcpTransport transport) { + return new AsyncSpec(transport); + } + + /** + * Asynchronous server specification. + */ + class AsyncSpecification { + + private static final McpSchema.Implementation DEFAULT_SERVER_INFO = new McpSchema.Implementation("mcp-server", + "1.0.0"); + + private final McpServerTransportProvider transportProvider; + + private ObjectMapper objectMapper; + + private McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; + + private McpSchema.ServerCapabilities serverCapabilities; + + /** + * The Model Context Protocol (MCP) allows servers to expose tools that can be + * invoked by language models. Tools enable models to interact with external + * systems, such as querying databases, calling APIs, or performing computations. + * Each tool is uniquely identified by a name and includes metadata describing its + * schema. + */ + private final List tools = new ArrayList<>(); + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose resources to clients. Resources allow servers to share data that + * provides context to language models, such as files, database schemas, or + * application-specific information. Each resource is uniquely identified by a + * URI. + */ + private final Map resources = new HashMap<>(); + + private final List resourceTemplates = new ArrayList<>(); + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose prompt templates to clients. Prompts allow servers to provide structured + * messages and instructions for interacting with language models. Clients can + * discover available prompts, retrieve their contents, and provide arguments to + * customize them. + */ + private final Map prompts = new HashMap<>(); + + private final List, Mono>> rootsChangeHandlers = new ArrayList<>(); + + private AsyncSpecification(McpServerTransportProvider transportProvider) { + Assert.notNull(transportProvider, "Transport provider must not be null"); + this.transportProvider = transportProvider; + } + + /** + * Sets the server implementation information that will be shared with clients + * during connection initialization. This helps with version compatibility, + * debugging, and server identification. + * @param serverInfo The server implementation details including name and version. + * Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if serverInfo is null + */ + public AsyncSpecification serverInfo(McpSchema.Implementation serverInfo) { + Assert.notNull(serverInfo, "Server info must not be null"); + this.serverInfo = serverInfo; + return this; + } + + /** + * Sets the server implementation information using name and version strings. This + * is a convenience method alternative to + * {@link #serverInfo(McpSchema.Implementation)}. + * @param name The server name. Must not be null or empty. + * @param version The server version. Must not be null or empty. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if name or version is null or empty + * @see #serverInfo(McpSchema.Implementation) + */ + public AsyncSpecification serverInfo(String name, String version) { + Assert.hasText(name, "Name must not be null or empty"); + Assert.hasText(version, "Version must not be null or empty"); + this.serverInfo = new McpSchema.Implementation(name, version); + return this; + } + + /** + * Sets the server capabilities that will be advertised to clients during + * connection initialization. Capabilities define what features the server + * supports, such as: + *
      + *
    • Tool execution + *
    • Resource access + *
    • Prompt handling + *
    + * @param serverCapabilities The server capabilities configuration. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if serverCapabilities is null + */ + public AsyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabilities) { + Assert.notNull(serverCapabilities, "Server capabilities must not be null"); + this.serverCapabilities = serverCapabilities; + return this; + } + + /** + * Adds a single tool with its implementation handler to the server. This is a + * convenience method for registering individual tools without creating a + * {@link McpServerFeatures.AsyncToolSpecification} explicitly. + * + *

    + * Example usage:

    {@code
    +		 * .tool(
    +		 *     new Tool("calculator", "Performs calculations", schema),
    +		 *     (exchange, args) -> Mono.fromSupplier(() -> calculate(args))
    +		 *         .map(result -> new CallToolResult("Result: " + result))
    +		 * )
    +		 * }
    + * @param tool The tool definition including name, description, and schema. Must + * not be null. + * @param handler The function that implements the tool's logic. Must not be null. + * The function's first argument is an {@link McpAsyncServerExchange} upon which + * the server can interact with the connected client. The second argument is the + * map of arguments passed to the tool. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if tool or handler is null + */ + public AsyncSpecification tool(McpSchema.Tool tool, + BiFunction, Mono> handler) { + Assert.notNull(tool, "Tool must not be null"); + Assert.notNull(handler, "Handler must not be null"); + + this.tools.add(new McpServerFeatures.AsyncToolSpecification(tool, handler)); + + return this; + } + + /** + * Adds multiple tools with their handlers to the server using a List. This method + * is useful when tools are dynamically generated or loaded from a configuration + * source. + * @param toolSpecifications The list of tool specifications to add. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if toolSpecifications is null + * @see #tools(McpServerFeatures.AsyncToolSpecification...) + */ + public AsyncSpecification tools(List toolSpecifications) { + Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); + this.tools.addAll(toolSpecifications); + return this; + } + + /** + * Adds multiple tools with their handlers to the server using varargs. This + * method provides a convenient way to register multiple tools inline. + * + *

    + * Example usage:

    {@code
    +		 * .tools(
    +		 *     new McpServerFeatures.AsyncToolSpecification(calculatorTool, calculatorHandler),
    +		 *     new McpServerFeatures.AsyncToolSpecification(weatherTool, weatherHandler),
    +		 *     new McpServerFeatures.AsyncToolSpecification(fileManagerTool, fileManagerHandler)
    +		 * )
    +		 * }
    + * @param toolSpecifications The tool specifications to add. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if toolSpecifications is null + * @see #tools(List) + */ + public AsyncSpecification tools(McpServerFeatures.AsyncToolSpecification... toolSpecifications) { + Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); + for (McpServerFeatures.AsyncToolSpecification tool : toolSpecifications) { + this.tools.add(tool); + } + return this; + } + + /** + * Registers multiple resources with their handlers using a Map. This method is + * useful when resources are dynamically generated or loaded from a configuration + * source. + * @param resourceSpecifications Map of resource name to specification. Must not + * be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + * @see #resources(McpServerFeatures.AsyncResourceSpecification...) + */ + public AsyncSpecification resources( + Map resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers map must not be null"); + this.resources.putAll(resourceSpecifications); + return this; + } + + /** + * Registers multiple resources with their handlers using a List. This method is + * useful when resources need to be added in bulk from a collection. + * @param resourceSpecifications List of resource specifications. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + * @see #resources(McpServerFeatures.AsyncResourceSpecification...) + */ + public AsyncSpecification resources(List resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); + for (McpServerFeatures.AsyncResourceSpecification resource : resourceSpecifications) { + this.resources.put(resource.resource().uri(), resource); + } + return this; + } + + /** + * Registers multiple resources with their handlers using varargs. This method + * provides a convenient way to register multiple resources inline. + * + *

    + * Example usage:

    {@code
    +		 * .resources(
    +		 *     new McpServerFeatures.AsyncResourceSpecification(fileResource, fileHandler),
    +		 *     new McpServerFeatures.AsyncResourceSpecification(dbResource, dbHandler),
    +		 *     new McpServerFeatures.AsyncResourceSpecification(apiResource, apiHandler)
    +		 * )
    +		 * }
    + * @param resourceSpecifications The resource specifications to add. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + */ + public AsyncSpecification resources(McpServerFeatures.AsyncResourceSpecification... resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); + for (McpServerFeatures.AsyncResourceSpecification resource : resourceSpecifications) { + this.resources.put(resource.resource().uri(), resource); + } + return this; + } + + /** + * Sets the resource templates that define patterns for dynamic resource access. + * Templates use URI patterns with placeholders that can be filled at runtime. + * + *

    + * Example usage:

    {@code
    +		 * .resourceTemplates(
    +		 *     new ResourceTemplate("file://{path}", "Access files by path"),
    +		 *     new ResourceTemplate("db://{table}/{id}", "Access database records")
    +		 * )
    +		 * }
    + * @param resourceTemplates List of resource templates. If null, clears existing + * templates. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceTemplates is null. + * @see #resourceTemplates(ResourceTemplate...) + */ + public AsyncSpecification resourceTemplates(List resourceTemplates) { + Assert.notNull(resourceTemplates, "Resource templates must not be null"); + this.resourceTemplates.addAll(resourceTemplates); + return this; + } + + /** + * Sets the resource templates using varargs for convenience. This is an + * alternative to {@link #resourceTemplates(List)}. + * @param resourceTemplates The resource templates to set. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceTemplates is null. + * @see #resourceTemplates(List) + */ + public AsyncSpecification resourceTemplates(ResourceTemplate... resourceTemplates) { + Assert.notNull(resourceTemplates, "Resource templates must not be null"); + for (ResourceTemplate resourceTemplate : resourceTemplates) { + this.resourceTemplates.add(resourceTemplate); + } + return this; + } + + /** + * Registers multiple prompts with their handlers using a Map. This method is + * useful when prompts are dynamically generated or loaded from a configuration + * source. + * + *

    + * Example usage:

    {@code
    +		 * .prompts(Map.of("analysis", new McpServerFeatures.AsyncPromptSpecification(
    +		 *     new Prompt("analysis", "Code analysis template"),
    +		 *     request -> Mono.fromSupplier(() -> generateAnalysisPrompt(request))
    +		 *         .map(GetPromptResult::new)
    +		 * )));
    +		 * }
    + * @param prompts Map of prompt name to specification. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + */ + public AsyncSpecification prompts(Map prompts) { + Assert.notNull(prompts, "Prompts map must not be null"); + this.prompts.putAll(prompts); + return this; + } + + /** + * Registers multiple prompts with their handlers using a List. This method is + * useful when prompts need to be added in bulk from a collection. + * @param prompts List of prompt specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + * @see #prompts(McpServerFeatures.AsyncPromptSpecification...) + */ + public AsyncSpecification prompts(List prompts) { + Assert.notNull(prompts, "Prompts list must not be null"); + for (McpServerFeatures.AsyncPromptSpecification prompt : prompts) { + this.prompts.put(prompt.prompt().name(), prompt); + } + return this; + } + + /** + * Registers multiple prompts with their handlers using varargs. This method + * provides a convenient way to register multiple prompts inline. + * + *

    + * Example usage:

    {@code
    +		 * .prompts(
    +		 *     new McpServerFeatures.AsyncPromptSpecification(analysisPrompt, analysisHandler),
    +		 *     new McpServerFeatures.AsyncPromptSpecification(summaryPrompt, summaryHandler),
    +		 *     new McpServerFeatures.AsyncPromptSpecification(reviewPrompt, reviewHandler)
    +		 * )
    +		 * }
    + * @param prompts The prompt specifications to add. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + */ + public AsyncSpecification prompts(McpServerFeatures.AsyncPromptSpecification... prompts) { + Assert.notNull(prompts, "Prompts list must not be null"); + for (McpServerFeatures.AsyncPromptSpecification prompt : prompts) { + this.prompts.put(prompt.prompt().name(), prompt); + } + return this; + } + + /** + * Registers a consumer that will be notified when the list of roots changes. This + * is useful for updating resource availability dynamically, such as when new + * files are added or removed. + * @param handler The handler to register. Must not be null. The function's first + * argument is an {@link McpAsyncServerExchange} upon which the server can + * interact with the connected client. The second argument is the list of roots. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if consumer is null + */ + public AsyncSpecification rootsChangeHandler( + BiFunction, Mono> handler) { + Assert.notNull(handler, "Consumer must not be null"); + this.rootsChangeHandlers.add(handler); + return this; + } + + /** + * Registers multiple consumers that will be notified when the list of roots + * changes. This method is useful when multiple consumers need to be registered at + * once. + * @param handlers The list of handlers to register. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if consumers is null + * @see #rootsChangeHandler(BiFunction) + */ + public AsyncSpecification rootsChangeHandlers( + List, Mono>> handlers) { + Assert.notNull(handlers, "Handlers list must not be null"); + this.rootsChangeHandlers.addAll(handlers); + return this; + } + + /** + * Registers multiple consumers that will be notified when the list of roots + * changes using varargs. This method provides a convenient way to register + * multiple consumers inline. + * @param handlers The handlers to register. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if consumers is null + * @see #rootsChangeHandlers(List) + */ + public AsyncSpecification rootsChangeHandlers( + @SuppressWarnings("unchecked") BiFunction, Mono>... handlers) { + Assert.notNull(handlers, "Handlers list must not be null"); + return this.rootsChangeHandlers(Arrays.asList(handlers)); + } + + /** + * Sets the object mapper to use for serializing and deserializing JSON messages. + * @param objectMapper the instance to use. Must not be null. + * @return This builder instance for method chaining. + * @throws IllegalArgumentException if objectMapper is null + */ + public AsyncSpecification objectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + /** + * Builds an asynchronous MCP server that provides non-blocking operations. + * @return A new instance of {@link McpAsyncServer} configured with this builder's + * settings. + */ + public McpAsyncServer build() { + var features = new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, + this.resources, this.resourceTemplates, this.prompts, this.rootsChangeHandlers); + var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); + return new McpAsyncServer(this.transportProvider, mapper, features); + } + + } + + /** + * Synchronous server specification. + */ + class SyncSpecification { + + private static final McpSchema.Implementation DEFAULT_SERVER_INFO = new McpSchema.Implementation("mcp-server", + "1.0.0"); + + private final McpServerTransportProvider transportProvider; + + private ObjectMapper objectMapper; + + private McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; + + private McpSchema.ServerCapabilities serverCapabilities; + + /** + * The Model Context Protocol (MCP) allows servers to expose tools that can be + * invoked by language models. Tools enable models to interact with external + * systems, such as querying databases, calling APIs, or performing computations. + * Each tool is uniquely identified by a name and includes metadata describing its + * schema. + */ + private final List tools = new ArrayList<>(); + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose resources to clients. Resources allow servers to share data that + * provides context to language models, such as files, database schemas, or + * application-specific information. Each resource is uniquely identified by a + * URI. + */ + private final Map resources = new HashMap<>(); + + private final List resourceTemplates = new ArrayList<>(); + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose prompt templates to clients. Prompts allow servers to provide structured + * messages and instructions for interacting with language models. Clients can + * discover available prompts, retrieve their contents, and provide arguments to + * customize them. + */ + private final Map prompts = new HashMap<>(); + + private final List>> rootsChangeHandlers = new ArrayList<>(); + + private SyncSpecification(McpServerTransportProvider transportProvider) { + Assert.notNull(transportProvider, "Transport provider must not be null"); + this.transportProvider = transportProvider; + } + + /** + * Sets the server implementation information that will be shared with clients + * during connection initialization. This helps with version compatibility, + * debugging, and server identification. + * @param serverInfo The server implementation details including name and version. + * Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if serverInfo is null + */ + public SyncSpecification serverInfo(McpSchema.Implementation serverInfo) { + Assert.notNull(serverInfo, "Server info must not be null"); + this.serverInfo = serverInfo; + return this; + } + + /** + * Sets the server implementation information using name and version strings. This + * is a convenience method alternative to + * {@link #serverInfo(McpSchema.Implementation)}. + * @param name The server name. Must not be null or empty. + * @param version The server version. Must not be null or empty. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if name or version is null or empty + * @see #serverInfo(McpSchema.Implementation) + */ + public SyncSpecification serverInfo(String name, String version) { + Assert.hasText(name, "Name must not be null or empty"); + Assert.hasText(version, "Version must not be null or empty"); + this.serverInfo = new McpSchema.Implementation(name, version); + return this; + } + + /** + * Sets the server capabilities that will be advertised to clients during + * connection initialization. Capabilities define what features the server + * supports, such as: + *
      + *
    • Tool execution + *
    • Resource access + *
    • Prompt handling + *
    + * @param serverCapabilities The server capabilities configuration. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if serverCapabilities is null + */ + public SyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabilities) { + Assert.notNull(serverCapabilities, "Server capabilities must not be null"); + this.serverCapabilities = serverCapabilities; + return this; + } + + /** + * Adds a single tool with its implementation handler to the server. This is a + * convenience method for registering individual tools without creating a + * {@link McpServerFeatures.SyncToolSpecification} explicitly. + * + *

    + * Example usage:

    {@code
    +		 * .tool(
    +		 *     new Tool("calculator", "Performs calculations", schema),
    +		 *     (exchange, args) -> new CallToolResult("Result: " + calculate(args))
    +		 * )
    +		 * }
    + * @param tool The tool definition including name, description, and schema. Must + * not be null. + * @param handler The function that implements the tool's logic. Must not be null. + * The function's first argument is an {@link McpSyncServerExchange} upon which + * the server can interact with the connected client. The second argument is the + * list of arguments passed to the tool. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if tool or handler is null + */ + public SyncSpecification tool(McpSchema.Tool tool, + BiFunction, McpSchema.CallToolResult> handler) { + Assert.notNull(tool, "Tool must not be null"); + Assert.notNull(handler, "Handler must not be null"); + + this.tools.add(new McpServerFeatures.SyncToolSpecification(tool, handler)); + + return this; + } + + /** + * Adds multiple tools with their handlers to the server using a List. This method + * is useful when tools are dynamically generated or loaded from a configuration + * source. + * @param toolSpecifications The list of tool specifications to add. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if toolSpecifications is null + * @see #tools(McpServerFeatures.SyncToolSpecification...) + */ + public SyncSpecification tools(List toolSpecifications) { + Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); + this.tools.addAll(toolSpecifications); + return this; + } + + /** + * Adds multiple tools with their handlers to the server using varargs. This + * method provides a convenient way to register multiple tools inline. + * + *

    + * Example usage:

    {@code
    +		 * .tools(
    +		 *     new ToolSpecification(calculatorTool, calculatorHandler),
    +		 *     new ToolSpecification(weatherTool, weatherHandler),
    +		 *     new ToolSpecification(fileManagerTool, fileManagerHandler)
    +		 * )
    +		 * }
    + * @param toolSpecifications The tool specifications to add. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if toolSpecifications is null + * @see #tools(List) + */ + public SyncSpecification tools(McpServerFeatures.SyncToolSpecification... toolSpecifications) { + Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); + for (McpServerFeatures.SyncToolSpecification tool : toolSpecifications) { + this.tools.add(tool); + } + return this; + } + + /** + * Registers multiple resources with their handlers using a Map. This method is + * useful when resources are dynamically generated or loaded from a configuration + * source. + * @param resourceSpecifications Map of resource name to specification. Must not + * be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + * @see #resources(McpServerFeatures.SyncResourceSpecification...) + */ + public SyncSpecification resources( + Map resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers map must not be null"); + this.resources.putAll(resourceSpecifications); + return this; + } + + /** + * Registers multiple resources with their handlers using a List. This method is + * useful when resources need to be added in bulk from a collection. + * @param resourceSpecifications List of resource specifications. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + * @see #resources(McpServerFeatures.SyncResourceSpecification...) + */ + public SyncSpecification resources(List resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); + for (McpServerFeatures.SyncResourceSpecification resource : resourceSpecifications) { + this.resources.put(resource.resource().uri(), resource); + } + return this; + } + + /** + * Registers multiple resources with their handlers using varargs. This method + * provides a convenient way to register multiple resources inline. + * + *

    + * Example usage:

    {@code
    +		 * .resources(
    +		 *     new ResourceSpecification(fileResource, fileHandler),
    +		 *     new ResourceSpecification(dbResource, dbHandler),
    +		 *     new ResourceSpecification(apiResource, apiHandler)
    +		 * )
    +		 * }
    + * @param resourceSpecifications The resource specifications to add. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + */ + public SyncSpecification resources(McpServerFeatures.SyncResourceSpecification... resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); + for (McpServerFeatures.SyncResourceSpecification resource : resourceSpecifications) { + this.resources.put(resource.resource().uri(), resource); + } + return this; + } + + /** + * Sets the resource templates that define patterns for dynamic resource access. + * Templates use URI patterns with placeholders that can be filled at runtime. + * + *

    + * Example usage:

    {@code
    +		 * .resourceTemplates(
    +		 *     new ResourceTemplate("file://{path}", "Access files by path"),
    +		 *     new ResourceTemplate("db://{table}/{id}", "Access database records")
    +		 * )
    +		 * }
    + * @param resourceTemplates List of resource templates. If null, clears existing + * templates. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceTemplates is null. + * @see #resourceTemplates(ResourceTemplate...) + */ + public SyncSpecification resourceTemplates(List resourceTemplates) { + Assert.notNull(resourceTemplates, "Resource templates must not be null"); + this.resourceTemplates.addAll(resourceTemplates); + return this; + } + + /** + * Sets the resource templates using varargs for convenience. This is an + * alternative to {@link #resourceTemplates(List)}. + * @param resourceTemplates The resource templates to set. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceTemplates is null + * @see #resourceTemplates(List) + */ + public SyncSpecification resourceTemplates(ResourceTemplate... resourceTemplates) { + Assert.notNull(resourceTemplates, "Resource templates must not be null"); + for (ResourceTemplate resourceTemplate : resourceTemplates) { + this.resourceTemplates.add(resourceTemplate); + } + return this; + } + + /** + * Registers multiple prompts with their handlers using a Map. This method is + * useful when prompts are dynamically generated or loaded from a configuration + * source. + * + *

    + * Example usage:

    {@code
    +		 * Map prompts = new HashMap<>();
    +		 * prompts.put("analysis", new PromptSpecification(
    +		 *     new Prompt("analysis", "Code analysis template"),
    +		 *     (exchange, request) -> new GetPromptResult(generateAnalysisPrompt(request))
    +		 * ));
    +		 * .prompts(prompts)
    +		 * }
    + * @param prompts Map of prompt name to specification. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + */ + public SyncSpecification prompts(Map prompts) { + Assert.notNull(prompts, "Prompts map must not be null"); + this.prompts.putAll(prompts); + return this; + } + + /** + * Registers multiple prompts with their handlers using a List. This method is + * useful when prompts need to be added in bulk from a collection. + * @param prompts List of prompt specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + * @see #prompts(McpServerFeatures.SyncPromptSpecification...) + */ + public SyncSpecification prompts(List prompts) { + Assert.notNull(prompts, "Prompts list must not be null"); + for (McpServerFeatures.SyncPromptSpecification prompt : prompts) { + this.prompts.put(prompt.prompt().name(), prompt); + } + return this; + } + + /** + * Registers multiple prompts with their handlers using varargs. This method + * provides a convenient way to register multiple prompts inline. + * + *

    + * Example usage:

    {@code
    +		 * .prompts(
    +		 *     new PromptSpecification(analysisPrompt, analysisHandler),
    +		 *     new PromptSpecification(summaryPrompt, summaryHandler),
    +		 *     new PromptSpecification(reviewPrompt, reviewHandler)
    +		 * )
    +		 * }
    + * @param prompts The prompt specifications to add. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + */ + public SyncSpecification prompts(McpServerFeatures.SyncPromptSpecification... prompts) { + Assert.notNull(prompts, "Prompts list must not be null"); + for (McpServerFeatures.SyncPromptSpecification prompt : prompts) { + this.prompts.put(prompt.prompt().name(), prompt); + } + return this; + } + + /** + * Registers a consumer that will be notified when the list of roots changes. This + * is useful for updating resource availability dynamically, such as when new + * files are added or removed. + * @param handler The handler to register. Must not be null. The function's first + * argument is an {@link McpSyncServerExchange} upon which the server can interact + * with the connected client. The second argument is the list of roots. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if consumer is null + */ + public SyncSpecification rootsChangeHandler(BiConsumer> handler) { + Assert.notNull(handler, "Consumer must not be null"); + this.rootsChangeHandlers.add(handler); + return this; + } + + /** + * Registers multiple consumers that will be notified when the list of roots + * changes. This method is useful when multiple consumers need to be registered at + * once. + * @param handlers The list of handlers to register. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if consumers is null + * @see #rootsChangeHandler(BiConsumer) + */ + public SyncSpecification rootsChangeHandlers( + List>> handlers) { + Assert.notNull(handlers, "Handlers list must not be null"); + this.rootsChangeHandlers.addAll(handlers); + return this; + } + + /** + * Registers multiple consumers that will be notified when the list of roots + * changes using varargs. This method provides a convenient way to register + * multiple consumers inline. + * @param handlers The handlers to register. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if consumers is null + * @see #rootsChangeHandlers(List) + */ + public SyncSpecification rootsChangeHandlers( + BiConsumer>... handlers) { + Assert.notNull(handlers, "Handlers list must not be null"); + return this.rootsChangeHandlers(List.of(handlers)); + } + + /** + * Sets the object mapper to use for serializing and deserializing JSON messages. + * @param objectMapper the instance to use. Must not be null. + * @return This builder instance for method chaining. + * @throws IllegalArgumentException if objectMapper is null + */ + public SyncSpecification objectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + /** + * Builds a synchronous MCP server that provides blocking operations. + * @return A new instance of {@link McpSyncServer} configured with this builder's + * settings. + */ + public McpSyncServer build() { + McpServerFeatures.Sync syncFeatures = new McpServerFeatures.Sync(this.serverInfo, this.serverCapabilities, + this.tools, this.resources, this.resourceTemplates, this.prompts, this.rootsChangeHandlers); + McpServerFeatures.Async asyncFeatures = McpServerFeatures.Async.fromSync(syncFeatures); + var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); + var asyncServer = new McpAsyncServer(this.transportProvider, mapper, asyncFeatures); + + return new McpSyncServer(asyncServer); + } + } /** * Asynchronous server specification. + * + * @deprecated */ + @Deprecated class AsyncSpec { private static final McpSchema.Implementation DEFAULT_SERVER_INFO = new McpSchema.Implementation("mcp-server", @@ -146,6 +1017,8 @@ class AsyncSpec { private final ServerMcpTransport transport; + private ObjectMapper objectMapper; + private McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; private McpSchema.ServerCapabilities serverCapabilities; @@ -507,16 +1380,37 @@ public AsyncSpec rootsChangeConsumers( * settings */ public McpAsyncServer build() { - return new McpAsyncServer(this.transport, - new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, this.resources, - this.resourceTemplates, this.prompts, this.rootsChangeConsumers)); + var tools = this.tools.stream().map(McpServerFeatures.AsyncToolRegistration::toSpecification).toList(); + + var resources = this.resources.entrySet() + .stream() + .map(entry -> Map.entry(entry.getKey(), entry.getValue().toSpecification())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + var prompts = this.prompts.entrySet() + .stream() + .map(entry -> Map.entry(entry.getKey(), entry.getValue().toSpecification())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + var rootsChangeHandlers = this.rootsChangeConsumers.stream() + .map(consumer -> (BiFunction, Mono>) (exchange, + roots) -> consumer.apply(roots)) + .toList(); + + var features = new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, tools, resources, + this.resourceTemplates, prompts, rootsChangeHandlers); + + return new McpAsyncServer(this.transport, features); } } /** * Synchronous server specification. + * + * @deprecated */ + @Deprecated class SyncSpec { private static final McpSchema.Implementation DEFAULT_SERVER_INFO = new McpSchema.Implementation("mcp-server", @@ -524,6 +1418,10 @@ class SyncSpec { private final ServerMcpTransport transport; + private final McpServerTransportProvider transportProvider; + + private ObjectMapper objectMapper; + private McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; private McpSchema.ServerCapabilities serverCapabilities; @@ -559,9 +1457,16 @@ class SyncSpec { private final List>> rootsChangeConsumers = new ArrayList<>(); + private SyncSpec(McpServerTransportProvider transportProvider) { + Assert.notNull(transportProvider, "Transport provider must not be null"); + this.transportProvider = transportProvider; + this.transport = null; + } + private SyncSpec(ServerMcpTransport transport) { Assert.notNull(transport, "Transport must not be null"); this.transport = transport; + this.transportProvider = null; } /** @@ -620,7 +1525,7 @@ public SyncSpec capabilities(McpSchema.ServerCapabilities serverCapabilities) { /** * Adds a single tool with its implementation handler to the server. This is a * convenience method for registering individual tools without creating a - * {@link ToolRegistration} explicitly. + * {@link McpServerFeatures.SyncToolRegistration} explicitly. * *

    * Example usage:

    {@code
    @@ -886,10 +1791,30 @@ public SyncSpec rootsChangeConsumers(Consumer>... consumers
     		 * settings
     		 */
     		public McpSyncServer build() {
    +			var tools = this.tools.stream().map(McpServerFeatures.SyncToolRegistration::toSpecification).toList();
    +
    +			var resources = this.resources.entrySet()
    +				.stream()
    +				.map(entry -> Map.entry(entry.getKey(), entry.getValue().toSpecification()))
    +				.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
    +
    +			var prompts = this.prompts.entrySet()
    +				.stream()
    +				.map(entry -> Map.entry(entry.getKey(), entry.getValue().toSpecification()))
    +				.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
    +
    +			var rootsChangeHandlers = this.rootsChangeConsumers.stream()
    +				.map(consumer -> (BiConsumer>) (exchange, roots) -> consumer
    +					.accept(roots))
    +				.toList();
    +
     			McpServerFeatures.Sync syncFeatures = new McpServerFeatures.Sync(this.serverInfo, this.serverCapabilities,
    -					this.tools, this.resources, this.resourceTemplates, this.prompts, this.rootsChangeConsumers);
    -			return new McpSyncServer(
    -					new McpAsyncServer(this.transport, McpServerFeatures.Async.fromSync(syncFeatures)));
    +					tools, resources, this.resourceTemplates, prompts, rootsChangeHandlers);
    +
    +			McpServerFeatures.Async asyncFeatures = McpServerFeatures.Async.fromSync(syncFeatures);
    +			var asyncServer = new McpAsyncServer(this.transport, asyncFeatures);
    +
    +			return new McpSyncServer(asyncServer);
     		}
     
     	}
    diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java
    index c8f8399a..5aeeadd7 100644
    --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java
    +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java
    @@ -8,7 +8,8 @@
     import java.util.HashMap;
     import java.util.List;
     import java.util.Map;
    -import java.util.function.Consumer;
    +import java.util.function.BiConsumer;
    +import java.util.function.BiFunction;
     import java.util.function.Function;
     
     import io.modelcontextprotocol.spec.McpSchema;
    @@ -29,35 +30,35 @@ public class McpServerFeatures {
     	 *
     	 * @param serverInfo The server implementation details
     	 * @param serverCapabilities The server capabilities
    -	 * @param tools The list of tool registrations
    -	 * @param resources The map of resource registrations
    +	 * @param tools The list of tool specifications
    +	 * @param resources The map of resource specifications
     	 * @param resourceTemplates The list of resource templates
    -	 * @param prompts The map of prompt registrations
    +	 * @param prompts The map of prompt specifications
     	 * @param rootsChangeConsumers The list of consumers that will be notified when the
     	 * roots list changes
     	 */
     	record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities,
    -			List tools, Map resources,
    +			List tools, Map resources,
     			List resourceTemplates,
    -			Map prompts,
    -			List, Mono>> rootsChangeConsumers) {
    +			Map prompts,
    +			List, Mono>> rootsChangeConsumers) {
     
     		/**
     		 * Create an instance and validate the arguments.
     		 * @param serverInfo The server implementation details
     		 * @param serverCapabilities The server capabilities
    -		 * @param tools The list of tool registrations
    -		 * @param resources The map of resource registrations
    +		 * @param tools The list of tool specifications
    +		 * @param resources The map of resource specifications
     		 * @param resourceTemplates The list of resource templates
    -		 * @param prompts The map of prompt registrations
    +		 * @param prompts The map of prompt specifications
     		 * @param rootsChangeConsumers The list of consumers that will be notified when
     		 * the roots list changes
     		 */
     		Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities,
    -				List tools, Map resources,
    +				List tools, Map resources,
     				List resourceTemplates,
    -				Map prompts,
    -				List, Mono>> rootsChangeConsumers) {
    +				Map prompts,
    +				List, Mono>> rootsChangeConsumers) {
     
     			Assert.notNull(serverInfo, "Server info must not be null");
     
    @@ -89,25 +90,26 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s
     		 * user.
     		 */
     		static Async fromSync(Sync syncSpec) {
    -			List tools = new ArrayList<>();
    +			List tools = new ArrayList<>();
     			for (var tool : syncSpec.tools()) {
    -				tools.add(AsyncToolRegistration.fromSync(tool));
    +				tools.add(AsyncToolSpecification.fromSync(tool));
     			}
     
    -			Map resources = new HashMap<>();
    +			Map resources = new HashMap<>();
     			syncSpec.resources().forEach((key, resource) -> {
    -				resources.put(key, AsyncResourceRegistration.fromSync(resource));
    +				resources.put(key, AsyncResourceSpecification.fromSync(resource));
     			});
     
    -			Map prompts = new HashMap<>();
    +			Map prompts = new HashMap<>();
     			syncSpec.prompts().forEach((key, prompt) -> {
    -				prompts.put(key, AsyncPromptRegistration.fromSync(prompt));
    +				prompts.put(key, AsyncPromptSpecification.fromSync(prompt));
     			});
     
    -			List, Mono>> rootChangeConsumers = new ArrayList<>();
    +			List, Mono>> rootChangeConsumers = new ArrayList<>();
     
     			for (var rootChangeConsumer : syncSpec.rootsChangeConsumers()) {
    -				rootChangeConsumers.add(list -> Mono.fromRunnable(() -> rootChangeConsumer.accept(list))
    +				rootChangeConsumers.add((exchange, list) -> Mono
    +					.fromRunnable(() -> rootChangeConsumer.accept(new McpSyncServerExchange(exchange), list))
     					.subscribeOn(Schedulers.boundedElastic()));
     			}
     
    @@ -121,37 +123,37 @@ static Async fromSync(Sync syncSpec) {
     	 *
     	 * @param serverInfo The server implementation details
     	 * @param serverCapabilities The server capabilities
    -	 * @param tools The list of tool registrations
    -	 * @param resources The map of resource registrations
    +	 * @param tools The list of tool specifications
    +	 * @param resources The map of resource specifications
     	 * @param resourceTemplates The list of resource templates
    -	 * @param prompts The map of prompt registrations
    +	 * @param prompts The map of prompt specifications
     	 * @param rootsChangeConsumers The list of consumers that will be notified when the
     	 * roots list changes
     	 */
     	record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities,
    -			List tools,
    -			Map resources,
    +			List tools,
    +			Map resources,
     			List resourceTemplates,
    -			Map prompts,
    -			List>> rootsChangeConsumers) {
    +			Map prompts,
    +			List>> rootsChangeConsumers) {
     
     		/**
     		 * Create an instance and validate the arguments.
     		 * @param serverInfo The server implementation details
     		 * @param serverCapabilities The server capabilities
    -		 * @param tools The list of tool registrations
    -		 * @param resources The map of resource registrations
    +		 * @param tools The list of tool specifications
    +		 * @param resources The map of resource specifications
     		 * @param resourceTemplates The list of resource templates
    -		 * @param prompts The map of prompt registrations
    +		 * @param prompts The map of prompt specifications
     		 * @param rootsChangeConsumers The list of consumers that will be notified when
     		 * the roots list changes
     		 */
     		Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities,
    -				List tools,
    -				Map resources,
    +				List tools,
    +				Map resources,
     				List resourceTemplates,
    -				Map prompts,
    -				List>> rootsChangeConsumers) {
    +				Map prompts,
    +				List>> rootsChangeConsumers) {
     
     			Assert.notNull(serverInfo, "Server info must not be null");
     
    @@ -176,6 +178,255 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se
     
     	}
     
    +	/**
    +	 * Specification of a tool with its asynchronous handler function. Tools are the
    +	 * primary way for MCP servers to expose functionality to AI models. Each tool
    +	 * represents a specific capability, such as:
    +	 * 
      + *
    • Performing calculations + *
    • Accessing external APIs + *
    • Querying databases + *
    • Manipulating files + *
    • Executing system commands + *
    + * + *

    + * Example tool specification:

    {@code
    +	 * new McpServerFeatures.AsyncToolSpecification(
    +	 *     new Tool(
    +	 *         "calculator",
    +	 *         "Performs mathematical calculations",
    +	 *         new JsonSchemaObject()
    +	 *             .required("expression")
    +	 *             .property("expression", JsonSchemaType.STRING)
    +	 *     ),
    +	 *     (exchange, args) -> {
    +	 *         String expr = (String) args.get("expression");
    +	 *         return Mono.fromSupplier(() -> evaluate(expr))
    +	 *             .map(result -> new CallToolResult("Result: " + result));
    +	 *     }
    +	 * )
    +	 * }
    + * + * @param tool The tool definition including name, description, and parameter schema + * @param call The function that implements the tool's logic, receiving arguments and + * returning results. The function's first argument is an + * {@link McpAsyncServerExchange} upon which the server can interact with the + * connected client. The second arguments is a map of tool arguments. + */ + public record AsyncToolSpecification(McpSchema.Tool tool, + BiFunction, Mono> call) { + + static AsyncToolSpecification fromSync(SyncToolSpecification tool) { + // FIXME: This is temporary, proper validation should be implemented + if (tool == null) { + return null; + } + return new AsyncToolSpecification(tool.tool(), + (exchange, map) -> Mono + .fromCallable(() -> tool.call().apply(new McpSyncServerExchange(exchange), map)) + .subscribeOn(Schedulers.boundedElastic())); + } + } + + /** + * Specification of a resource with its asynchronous handler function. Resources + * provide context to AI models by exposing data such as: + *
      + *
    • File contents + *
    • Database records + *
    • API responses + *
    • System information + *
    • Application state + *
    + * + *

    + * Example resource specification:

    {@code
    +	 * new McpServerFeatures.AsyncResourceSpecification(
    +	 *     new Resource("docs", "Documentation files", "text/markdown"),
    +	 *     (exchange, request) ->
    +	 *         Mono.fromSupplier(() -> readFile(request.getPath()))
    +	 *             .map(ReadResourceResult::new)
    +	 * )
    +	 * }
    + * + * @param resource The resource definition including name, description, and MIME type + * @param readHandler The function that handles resource read requests. The function's + * first argument is an {@link McpAsyncServerExchange} upon which the server can + * interact with the connected client. The second arguments is a + * {@link io.modelcontextprotocol.spec.McpSchema.ReadResourceRequest}. + */ + public record AsyncResourceSpecification(McpSchema.Resource resource, + BiFunction> readHandler) { + + static AsyncResourceSpecification fromSync(SyncResourceSpecification resource) { + // FIXME: This is temporary, proper validation should be implemented + if (resource == null) { + return null; + } + return new AsyncResourceSpecification(resource.resource(), + (exchange, req) -> Mono + .fromCallable(() -> resource.readHandler().apply(new McpSyncServerExchange(exchange), req)) + .subscribeOn(Schedulers.boundedElastic())); + } + } + + /** + * Specification of a prompt template with its asynchronous handler function. Prompts + * provide structured templates for AI model interactions, supporting: + *
      + *
    • Consistent message formatting + *
    • Parameter substitution + *
    • Context injection + *
    • Response formatting + *
    • Instruction templating + *
    + * + *

    + * Example prompt specification:

    {@code
    +	 * new McpServerFeatures.AsyncPromptSpecification(
    +	 *     new Prompt("analyze", "Code analysis template"),
    +	 *     (exchange, request) -> {
    +	 *         String code = request.getArguments().get("code");
    +	 *         return Mono.just(new GetPromptResult(
    +	 *             "Analyze this code:\n\n" + code + "\n\nProvide feedback on:"
    +	 *         ));
    +	 *     }
    +	 * )
    +	 * }
    + * + * @param prompt The prompt definition including name and description + * @param promptHandler The function that processes prompt requests and returns + * formatted templates. The function's first argument is an + * {@link McpAsyncServerExchange} upon which the server can interact with the + * connected client. The second arguments is a + * {@link io.modelcontextprotocol.spec.McpSchema.GetPromptRequest}. + */ + public record AsyncPromptSpecification(McpSchema.Prompt prompt, + BiFunction> promptHandler) { + + static AsyncPromptSpecification fromSync(SyncPromptSpecification prompt) { + // FIXME: This is temporary, proper validation should be implemented + if (prompt == null) { + return null; + } + return new AsyncPromptSpecification(prompt.prompt(), + (exchange, req) -> Mono + .fromCallable(() -> prompt.promptHandler().apply(new McpSyncServerExchange(exchange), req)) + .subscribeOn(Schedulers.boundedElastic())); + } + } + + /** + * Specification of a tool with its synchronous handler function. Tools are the + * primary way for MCP servers to expose functionality to AI models. Each tool + * represents a specific capability, such as: + *
      + *
    • Performing calculations + *
    • Accessing external APIs + *
    • Querying databases + *
    • Manipulating files + *
    • Executing system commands + *
    + * + *

    + * Example tool specification:

    {@code
    +	 * new McpServerFeatures.SyncToolSpecification(
    +	 *     new Tool(
    +	 *         "calculator",
    +	 *         "Performs mathematical calculations",
    +	 *         new JsonSchemaObject()
    +	 *             .required("expression")
    +	 *             .property("expression", JsonSchemaType.STRING)
    +	 *     ),
    +	 *     (exchange, args) -> {
    +	 *         String expr = (String) args.get("expression");
    +	 *         return new CallToolResult("Result: " + evaluate(expr));
    +	 *     }
    +	 * )
    +	 * }
    + * + * @param tool The tool definition including name, description, and parameter schema + * @param call The function that implements the tool's logic, receiving arguments and + * returning results. The function's first argument is an + * {@link McpSyncServerExchange} upon which the server can interact with the connected + * client. The second arguments is a map of arguments passed to the tool. + */ + public record SyncToolSpecification(McpSchema.Tool tool, + BiFunction, McpSchema.CallToolResult> call) { + } + + /** + * Specification of a resource with its synchronous handler function. Resources + * provide context to AI models by exposing data such as: + *
      + *
    • File contents + *
    • Database records + *
    • API responses + *
    • System information + *
    • Application state + *
    + * + *

    + * Example resource specification:

    {@code
    +	 * new McpServerFeatures.SyncResourceSpecification(
    +	 *     new Resource("docs", "Documentation files", "text/markdown"),
    +	 *     (exchange, request) -> {
    +	 *         String content = readFile(request.getPath());
    +	 *         return new ReadResourceResult(content);
    +	 *     }
    +	 * )
    +	 * }
    + * + * @param resource The resource definition including name, description, and MIME type + * @param readHandler The function that handles resource read requests. The function's + * first argument is an {@link McpSyncServerExchange} upon which the server can + * interact with the connected client. The second arguments is a + * {@link io.modelcontextprotocol.spec.McpSchema.ReadResourceRequest}. + */ + public record SyncResourceSpecification(McpSchema.Resource resource, + BiFunction readHandler) { + } + + /** + * Specification of a prompt template with its synchronous handler function. Prompts + * provide structured templates for AI model interactions, supporting: + *
      + *
    • Consistent message formatting + *
    • Parameter substitution + *
    • Context injection + *
    • Response formatting + *
    • Instruction templating + *
    + * + *

    + * Example prompt specification:

    {@code
    +	 * new McpServerFeatures.SyncPromptSpecification(
    +	 *     new Prompt("analyze", "Code analysis template"),
    +	 *     (exchange, request) -> {
    +	 *         String code = request.getArguments().get("code");
    +	 *         return new GetPromptResult(
    +	 *             "Analyze this code:\n\n" + code + "\n\nProvide feedback on:"
    +	 *         );
    +	 *     }
    +	 * )
    +	 * }
    + * + * @param prompt The prompt definition including name and description + * @param promptHandler The function that processes prompt requests and returns + * formatted templates. The function's first argument is an + * {@link McpSyncServerExchange} upon which the server can interact with the connected + * client. The second arguments is a + * {@link io.modelcontextprotocol.spec.McpSchema.GetPromptRequest}. + */ + public record SyncPromptSpecification(McpSchema.Prompt prompt, + BiFunction promptHandler) { + } + + // --------------------------------------- + // Deprecated registrations + // --------------------------------------- + /** * Registration of a tool with its asynchronous handler function. Tools are the * primary way for MCP servers to expose functionality to AI models. Each tool @@ -208,7 +459,10 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se * @param tool The tool definition including name, description, and parameter schema * @param call The function that implements the tool's logic, receiving arguments and * returning results + * @deprecated This class is deprecated and will be removed in 0.9.0. Use + * {@link AsyncToolSpecification}. */ + @Deprecated public record AsyncToolRegistration(McpSchema.Tool tool, Function, Mono> call) { @@ -220,6 +474,10 @@ static AsyncToolRegistration fromSync(SyncToolRegistration tool) { return new AsyncToolRegistration(tool.tool(), map -> Mono.fromCallable(() -> tool.call().apply(map)).subscribeOn(Schedulers.boundedElastic())); } + + public AsyncToolSpecification toSpecification() { + return new AsyncToolSpecification(tool(), (exchange, map) -> call.apply(map)); + } } /** @@ -246,7 +504,10 @@ static AsyncToolRegistration fromSync(SyncToolRegistration tool) { * * @param resource The resource definition including name, description, and MIME type * @param readHandler The function that handles resource read requests + * @deprecated This class is deprecated and will be removed in 0.9.0. Use + * {@link AsyncResourceSpecification}. */ + @Deprecated public record AsyncResourceRegistration(McpSchema.Resource resource, Function> readHandler) { @@ -259,6 +520,10 @@ static AsyncResourceRegistration fromSync(SyncResourceRegistration resource) { req -> Mono.fromCallable(() -> resource.readHandler().apply(req)) .subscribeOn(Schedulers.boundedElastic())); } + + public AsyncResourceSpecification toSpecification() { + return new AsyncResourceSpecification(resource(), (exchange, request) -> readHandler.apply(request)); + } } /** @@ -288,7 +553,10 @@ static AsyncResourceRegistration fromSync(SyncResourceRegistration resource) { * @param prompt The prompt definition including name and description * @param promptHandler The function that processes prompt requests and returns * formatted templates + * @deprecated This class is deprecated and will be removed in 0.9.0. Use + * {@link AsyncPromptSpecification}. */ + @Deprecated public record AsyncPromptRegistration(McpSchema.Prompt prompt, Function> promptHandler) { @@ -301,6 +569,10 @@ static AsyncPromptRegistration fromSync(SyncPromptRegistration prompt) { req -> Mono.fromCallable(() -> prompt.promptHandler().apply(req)) .subscribeOn(Schedulers.boundedElastic())); } + + public AsyncPromptSpecification toSpecification() { + return new AsyncPromptSpecification(prompt(), (exchange, request) -> promptHandler.apply(request)); + } } /** @@ -335,9 +607,15 @@ static AsyncPromptRegistration fromSync(SyncPromptRegistration prompt) { * @param tool The tool definition including name, description, and parameter schema * @param call The function that implements the tool's logic, receiving arguments and * returning results + * @deprecated This class is deprecated and will be removed in 0.9.0. Use + * {@link SyncToolSpecification}. */ + @Deprecated public record SyncToolRegistration(McpSchema.Tool tool, Function, McpSchema.CallToolResult> call) { + public SyncToolSpecification toSpecification() { + return new SyncToolSpecification(tool, (exchange, map) -> call.apply(map)); + } } /** @@ -364,9 +642,15 @@ public record SyncToolRegistration(McpSchema.Tool tool, * * @param resource The resource definition including name, description, and MIME type * @param readHandler The function that handles resource read requests + * @deprecated This class is deprecated and will be removed in 0.9.0. Use + * {@link SyncResourceSpecification}. */ + @Deprecated public record SyncResourceRegistration(McpSchema.Resource resource, Function readHandler) { + public SyncResourceSpecification toSpecification() { + return new SyncResourceSpecification(resource, (exchange, request) -> readHandler.apply(request)); + } } /** @@ -396,9 +680,15 @@ public record SyncResourceRegistration(McpSchema.Resource resource, * @param prompt The prompt definition including name and description * @param promptHandler The function that processes prompt requests and returns * formatted templates + * @deprecated This class is deprecated and will be removed in 0.9.0. Use + * {@link SyncPromptSpecification}. */ + @Deprecated public record SyncPromptRegistration(McpSchema.Prompt prompt, Function promptHandler) { + public SyncPromptSpecification toSpecification() { + return new SyncPromptSpecification(prompt, (exchange, request) -> promptHandler.apply(request)); + } } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java index 1de0139b..60662d98 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java @@ -68,7 +68,10 @@ public McpSyncServer(McpAsyncServer asyncServer) { /** * Retrieves the list of all roots provided by the client. * @return The list of roots + * @deprecated This method will be removed in 0.9.0. Use + * {@link McpSyncServerExchange#listRoots()}. */ + @Deprecated public McpSchema.ListRootsResult listRoots() { return this.listRoots(null); } @@ -77,7 +80,10 @@ public McpSchema.ListRootsResult listRoots() { * Retrieves a paginated list of roots provided by the server. * @param cursor Optional pagination cursor from a previous list request * @return The list of roots + * @deprecated This method will be removed in 0.9.0. Use + * {@link McpSyncServerExchange#listRoots(String)}. */ + @Deprecated public McpSchema.ListRootsResult listRoots(String cursor) { return this.asyncServer.listRoots(cursor).block(); } @@ -85,11 +91,22 @@ public McpSchema.ListRootsResult listRoots(String cursor) { /** * Add a new tool handler. * @param toolHandler The tool handler to add + * @deprecated This method will be removed in 0.9.0. Use + * {@link #addTool(McpServerFeatures.SyncToolSpecification)}. */ + @Deprecated public void addTool(McpServerFeatures.SyncToolRegistration toolHandler) { this.asyncServer.addTool(McpServerFeatures.AsyncToolRegistration.fromSync(toolHandler)).block(); } + /** + * Add a new tool handler. + * @param toolHandler The tool handler to add + */ + public void addTool(McpServerFeatures.SyncToolSpecification toolHandler) { + this.asyncServer.addTool(McpServerFeatures.AsyncToolSpecification.fromSync(toolHandler)).block(); + } + /** * Remove a tool handler. * @param toolName The name of the tool handler to remove @@ -101,11 +118,22 @@ public void removeTool(String toolName) { /** * Add a new resource handler. * @param resourceHandler The resource handler to add + * @deprecated This method will be removed in 0.9.0. Use + * {@link #addResource(McpServerFeatures.SyncResourceSpecification)}. */ + @Deprecated public void addResource(McpServerFeatures.SyncResourceRegistration resourceHandler) { this.asyncServer.addResource(McpServerFeatures.AsyncResourceRegistration.fromSync(resourceHandler)).block(); } + /** + * Add a new resource handler. + * @param resourceHandler The resource handler to add + */ + public void addResource(McpServerFeatures.SyncResourceSpecification resourceHandler) { + this.asyncServer.addResource(McpServerFeatures.AsyncResourceSpecification.fromSync(resourceHandler)).block(); + } + /** * Remove a resource handler. * @param resourceUri The URI of the resource handler to remove @@ -117,11 +145,22 @@ public void removeResource(String resourceUri) { /** * Add a new prompt handler. * @param promptRegistration The prompt registration to add + * @deprecated This method will be removed in 0.9.0. Use + * {@link #addPrompt(McpServerFeatures.SyncPromptSpecification)}. */ + @Deprecated public void addPrompt(McpServerFeatures.SyncPromptRegistration promptRegistration) { this.asyncServer.addPrompt(McpServerFeatures.AsyncPromptRegistration.fromSync(promptRegistration)).block(); } + /** + * Add a new prompt handler. + * @param promptSpecification The prompt specification to add + */ + public void addPrompt(McpServerFeatures.SyncPromptSpecification promptSpecification) { + this.asyncServer.addPrompt(McpServerFeatures.AsyncPromptSpecification.fromSync(promptSpecification)).block(); + } + /** * Remove a prompt handler. * @param promptName The name of the prompt handler to remove @@ -156,7 +195,10 @@ public McpSchema.Implementation getServerInfo() { /** * Get the client capabilities that define the supported features and functionality. * @return The client capabilities + * @deprecated This method will be removed in 0.9.0. Use + * {@link McpSyncServerExchange#getClientCapabilities()}. */ + @Deprecated public ClientCapabilities getClientCapabilities() { return this.asyncServer.getClientCapabilities(); } @@ -164,7 +206,10 @@ public ClientCapabilities getClientCapabilities() { /** * Get the client implementation information. * @return The client implementation details + * @deprecated This method will be removed in 0.9.0. Use + * {@link McpSyncServerExchange#getClientInfo()}. */ + @Deprecated public McpSchema.Implementation getClientInfo() { return this.asyncServer.getClientInfo(); } @@ -237,7 +282,10 @@ public McpAsyncServer getAsyncServer() { * @see Sampling * Specification + * @deprecated This method will be removed in 0.9.0. Use + * {@link McpSyncServerExchange#createMessage(McpSchema.CreateMessageRequest)}. */ + @Deprecated public McpSchema.CreateMessageResult createMessage(McpSchema.CreateMessageRequest createMessageRequest) { return this.asyncServer.createMessage(createMessageRequest).block(); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java new file mode 100644 index 00000000..f121db55 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java @@ -0,0 +1,78 @@ +package io.modelcontextprotocol.server; + +import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.spec.McpSchema; + +/** + * Represents a synchronous exchange with a Model Context Protocol (MCP) client. The + * exchange provides methods to interact with the client and query its capabilities. + * + * @author Dariusz Jędrzejczyk + */ +public class McpSyncServerExchange { + + private final McpAsyncServerExchange exchange; + + /** + * Create a new synchronous exchange with the client using the provided asynchronous + * implementation as a delegate. + * @param exchange The asynchronous exchange to delegate to. + */ + public McpSyncServerExchange(McpAsyncServerExchange exchange) { + this.exchange = exchange; + } + + /** + * Get the client capabilities that define the supported features and functionality. + * @return The client capabilities + */ + public McpSchema.ClientCapabilities getClientCapabilities() { + return this.exchange.getClientCapabilities(); + } + + /** + * Get the client implementation information. + * @return The client implementation details + */ + public McpSchema.Implementation getClientInfo() { + return this.exchange.getClientInfo(); + } + + /** + * Create a new message using the sampling capabilities of the client. The Model + * Context Protocol (MCP) provides a standardized way for servers to request LLM + * sampling (“completions” or “generations”) from language models via clients. This + * flow allows clients to maintain control over model access, selection, and + * permissions while enabling servers to leverage AI capabilities—with no server API + * keys necessary. Servers can request text or image-based interactions and optionally + * include context from MCP servers in their prompts. + * @param createMessageRequest The request to create a new message + * @return A result containing the details of the sampling response + * @see McpSchema.CreateMessageRequest + * @see McpSchema.CreateMessageResult + * @see Sampling + * Specification + */ + public McpSchema.CreateMessageResult createMessage(McpSchema.CreateMessageRequest createMessageRequest) { + return this.exchange.createMessage(createMessageRequest).block(); + } + + /** + * Retrieves the list of all roots provided by the client. + * @return The list of roots result. + */ + public McpSchema.ListRootsResult listRoots() { + return this.exchange.listRoots().block(); + } + + /** + * Retrieves a paginated list of roots provided by the client. + * @param cursor Optional pagination cursor from a previous list request + * @return The list of roots result + */ + public McpSchema.ListRootsResult listRoots(String cursor) { + return this.exchange.listRoots(cursor).block(); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransport.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransport.java index 98b8ea58..fa5dcf1c 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransport.java @@ -32,6 +32,9 @@ * specification. This implementation provides similar functionality to * WebFluxSseServerTransport but uses the traditional Servlet API instead of WebFlux. * + * @deprecated This class will be removed in 0.9.0. Use + * {@link HttpServletSseServerTransportProvider}. + * *

    * The transport handles two types of endpoints: *

      @@ -48,7 +51,6 @@ *
    • Graceful shutdown support
    • *
    • Error handling and response formatting
    • *
    - * * @author Christian Tzolov * @author Alexandros Pappas * @see ServerMcpTransport @@ -56,6 +58,7 @@ */ @WebServlet(asyncSupported = true) +@Deprecated public class HttpServletSseServerTransport extends HttpServlet implements ServerMcpTransport { /** Logger for this class */ diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java new file mode 100644 index 00000000..152462b1 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java @@ -0,0 +1,432 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ +package io.modelcontextprotocol.server.transport; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.PrintWriter; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpServerTransport; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import jakarta.servlet.AsyncContext; +import jakarta.servlet.ServletException; +import jakarta.servlet.annotation.WebServlet; +import jakarta.servlet.http.HttpServlet; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** + * A Servlet-based implementation of the MCP HTTP with Server-Sent Events (SSE) transport + * specification. This implementation provides similar functionality to + * WebFluxSseServerTransportProvider but uses the traditional Servlet API instead of + * WebFlux. + * + *

    + * The transport handles two types of endpoints: + *

      + *
    • SSE endpoint (/sse) - Establishes a long-lived connection for server-to-client + * events
    • + *
    • Message endpoint (configurable) - Handles client-to-server message requests
    • + *
    + * + *

    + * Features: + *

      + *
    • Asynchronous message handling using Servlet 6.0 async support
    • + *
    • Session management for multiple client connections
    • + *
    • Graceful shutdown support
    • + *
    • Error handling and response formatting
    • + *
    + * + * @author Christian Tzolov + * @author Alexandros Pappas + * @see McpServerTransportProvider + * @see HttpServlet + */ + +@WebServlet(asyncSupported = true) +public class HttpServletSseServerTransportProvider extends HttpServlet implements McpServerTransportProvider { + + /** Logger for this class */ + private static final Logger logger = LoggerFactory.getLogger(HttpServletSseServerTransportProvider.class); + + public static final String UTF_8 = "UTF-8"; + + public static final String APPLICATION_JSON = "application/json"; + + public static final String FAILED_TO_SEND_ERROR_RESPONSE = "Failed to send error response: {}"; + + /** Default endpoint path for SSE connections */ + public static final String DEFAULT_SSE_ENDPOINT = "/sse"; + + /** Event type for regular messages */ + public static final String MESSAGE_EVENT_TYPE = "message"; + + /** Event type for endpoint information */ + public static final String ENDPOINT_EVENT_TYPE = "endpoint"; + + /** JSON object mapper for serialization/deserialization */ + private final ObjectMapper objectMapper; + + /** The endpoint path for handling client messages */ + private final String messageEndpoint; + + /** The endpoint path for handling SSE connections */ + private final String sseEndpoint; + + /** Map of active client sessions, keyed by session ID */ + private final Map sessions = new ConcurrentHashMap<>(); + + /** Flag indicating if the transport is in the process of shutting down */ + private final AtomicBoolean isClosing = new AtomicBoolean(false); + + /** Session factory for creating new sessions */ + private McpServerSession.Factory sessionFactory; + + /** + * Creates a new HttpServletSseServerTransportProvider instance with a custom SSE + * endpoint. + * @param objectMapper The JSON object mapper to use for message + * serialization/deserialization + * @param messageEndpoint The endpoint path where clients will send their messages + * @param sseEndpoint The endpoint path where clients will establish SSE connections + */ + public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, + String sseEndpoint) { + this.objectMapper = objectMapper; + this.messageEndpoint = messageEndpoint; + this.sseEndpoint = sseEndpoint; + } + + /** + * Creates a new HttpServletSseServerTransportProvider instance with the default SSE + * endpoint. + * @param objectMapper The JSON object mapper to use for message + * serialization/deserialization + * @param messageEndpoint The endpoint path where clients will send their messages + */ + public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) { + this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); + } + + /** + * Sets the session factory for creating new sessions. + * @param sessionFactory The session factory to use + */ + @Override + public void setSessionFactory(McpServerSession.Factory sessionFactory) { + this.sessionFactory = sessionFactory; + } + + /** + * Broadcasts a notification to all connected clients. + * @param method The method name for the notification + * @param params The parameters for the notification + * @return A Mono that completes when the broadcast attempt is finished + */ + @Override + public Mono notifyClients(String method, Map params) { + if (sessions.isEmpty()) { + logger.debug("No active sessions to broadcast message to"); + return Mono.empty(); + } + + logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); + + return Flux.fromIterable(sessions.values()) + .flatMap(session -> session.sendNotification(method, params) + .doOnError( + e -> logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage())) + .onErrorComplete()) + .then(); + } + + /** + * Handles GET requests to establish SSE connections. + *

    + * This method sets up a new SSE connection when a client connects to the SSE + * endpoint. It configures the response headers for SSE, creates a new session, and + * sends the initial endpoint information to the client. + * @param request The HTTP servlet request + * @param response The HTTP servlet response + * @throws ServletException If a servlet-specific error occurs + * @throws IOException If an I/O error occurs + */ + @Override + protected void doGet(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + String pathInfo = request.getPathInfo(); + if (!sseEndpoint.equals(pathInfo)) { + response.sendError(HttpServletResponse.SC_NOT_FOUND); + return; + } + + if (isClosing.get()) { + response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down"); + return; + } + + response.setContentType("text/event-stream"); + response.setCharacterEncoding(UTF_8); + response.setHeader("Cache-Control", "no-cache"); + response.setHeader("Connection", "keep-alive"); + response.setHeader("Access-Control-Allow-Origin", "*"); + + String sessionId = UUID.randomUUID().toString(); + AsyncContext asyncContext = request.startAsync(); + asyncContext.setTimeout(0); + + PrintWriter writer = response.getWriter(); + + // Create a new session transport + HttpServletMcpSessionTransport sessionTransport = new HttpServletMcpSessionTransport(sessionId, asyncContext, + writer); + + // Create a new session using the session factory + McpServerSession session = sessionFactory.create(sessionTransport); + this.sessions.put(sessionId, session); + + // Send initial endpoint event + this.sendEvent(writer, ENDPOINT_EVENT_TYPE, messageEndpoint + "?sessionId=" + sessionId); + } + + /** + * Handles POST requests for client messages. + *

    + * This method processes incoming messages from clients, routes them through the + * session handler, and sends back the appropriate response. It handles error cases + * and formats error responses according to the MCP specification. + * @param request The HTTP servlet request + * @param response The HTTP servlet response + * @throws ServletException If a servlet-specific error occurs + * @throws IOException If an I/O error occurs + */ + @Override + protected void doPost(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + if (isClosing.get()) { + response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down"); + return; + } + + String pathInfo = request.getPathInfo(); + if (!messageEndpoint.equals(pathInfo)) { + response.sendError(HttpServletResponse.SC_NOT_FOUND); + return; + } + + // Get the session ID from the request parameter + String sessionId = request.getParameter("sessionId"); + if (sessionId == null) { + response.setContentType(APPLICATION_JSON); + response.setCharacterEncoding(UTF_8); + response.setStatus(HttpServletResponse.SC_BAD_REQUEST); + String jsonError = objectMapper.writeValueAsString(new McpError("Session ID missing in message endpoint")); + PrintWriter writer = response.getWriter(); + writer.write(jsonError); + writer.flush(); + return; + } + + // Get the session from the sessions map + McpServerSession session = sessions.get(sessionId); + if (session == null) { + response.setContentType(APPLICATION_JSON); + response.setCharacterEncoding(UTF_8); + response.setStatus(HttpServletResponse.SC_NOT_FOUND); + String jsonError = objectMapper.writeValueAsString(new McpError("Session not found: " + sessionId)); + PrintWriter writer = response.getWriter(); + writer.write(jsonError); + writer.flush(); + return; + } + + try { + BufferedReader reader = request.getReader(); + StringBuilder body = new StringBuilder(); + String line; + while ((line = reader.readLine()) != null) { + body.append(line); + } + + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body.toString()); + + // Process the message through the session's handle method + session.handle(message).block(); // Block for Servlet compatibility + + response.setStatus(HttpServletResponse.SC_OK); + } + catch (Exception e) { + logger.error("Error processing message: {}", e.getMessage()); + try { + McpError mcpError = new McpError(e.getMessage()); + response.setContentType(APPLICATION_JSON); + response.setCharacterEncoding(UTF_8); + response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR); + String jsonError = objectMapper.writeValueAsString(mcpError); + PrintWriter writer = response.getWriter(); + writer.write(jsonError); + writer.flush(); + } + catch (IOException ex) { + logger.error(FAILED_TO_SEND_ERROR_RESPONSE, ex.getMessage()); + response.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "Error processing message"); + } + } + } + + /** + * Initiates a graceful shutdown of the transport. + *

    + * This method marks the transport as closing and closes all active client sessions. + * New connection attempts will be rejected during shutdown. + * @return A Mono that completes when all sessions have been closed + */ + @Override + public Mono closeGracefully() { + isClosing.set(true); + logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size()); + + return Flux.fromIterable(sessions.values()).flatMap(McpServerSession::closeGracefully).then(); + } + + /** + * Sends an SSE event to a client. + * @param writer The writer to send the event through + * @param eventType The type of event (message or endpoint) + * @param data The event data + * @throws IOException If an error occurs while writing the event + */ + private void sendEvent(PrintWriter writer, String eventType, String data) throws IOException { + writer.write("event: " + eventType + "\n"); + writer.write("data: " + data + "\n\n"); + writer.flush(); + + if (writer.checkError()) { + throw new IOException("Client disconnected"); + } + } + + /** + * Cleans up resources when the servlet is being destroyed. + *

    + * This method ensures a graceful shutdown by closing all client connections before + * calling the parent's destroy method. + */ + @Override + public void destroy() { + closeGracefully().block(); + super.destroy(); + } + + /** + * Implementation of McpServerTransport for HttpServlet SSE sessions. This class + * handles the transport-level communication for a specific client session. + */ + private class HttpServletMcpSessionTransport implements McpServerTransport { + + private final String sessionId; + + private final AsyncContext asyncContext; + + private final PrintWriter writer; + + /** + * Creates a new session transport with the specified ID and SSE writer. + * @param sessionId The unique identifier for this session + * @param asyncContext The async context for the session + * @param writer The writer for sending server events to the client + */ + HttpServletMcpSessionTransport(String sessionId, AsyncContext asyncContext, PrintWriter writer) { + this.sessionId = sessionId; + this.asyncContext = asyncContext; + this.writer = writer; + logger.debug("Session transport {} initialized with SSE writer", sessionId); + } + + /** + * Sends a JSON-RPC message to the client through the SSE connection. + * @param message The JSON-RPC message to send + * @return A Mono that completes when the message has been sent + */ + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + return Mono.fromRunnable(() -> { + try { + String jsonText = objectMapper.writeValueAsString(message); + sendEvent(writer, MESSAGE_EVENT_TYPE, jsonText); + logger.debug("Message sent to session {}", sessionId); + } + catch (Exception e) { + logger.error("Failed to send message to session {}: {}", sessionId, e.getMessage()); + sessions.remove(sessionId); + asyncContext.complete(); + } + }); + } + + /** + * Converts data from one type to another using the configured ObjectMapper. + * @param data The source data object to convert + * @param typeRef The target type reference + * @return The converted object of type T + * @param The target type + */ + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return objectMapper.convertValue(data, typeRef); + } + + /** + * Initiates a graceful shutdown of the transport. + * @return A Mono that completes when the shutdown is complete + */ + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + logger.debug("Closing session transport: {}", sessionId); + try { + sessions.remove(sessionId); + asyncContext.complete(); + logger.debug("Successfully completed async context for session {}", sessionId); + } + catch (Exception e) { + logger.warn("Failed to complete async context for session {}: {}", sessionId, e.getMessage()); + } + }); + } + + /** + * Closes the transport immediately. + */ + @Override + public void close() { + try { + sessions.remove(sessionId); + asyncContext.complete(); + logger.debug("Successfully completed async context for session {}", sessionId); + } + catch (Exception e) { + logger.warn("Failed to complete async context for session {}: {}", sessionId, e.getMessage()); + } + } + + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransport.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransport.java index e375cd10..78264ca3 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransport.java @@ -33,7 +33,10 @@ * over stdin/stdout, with errors and debug information sent to stderr. * * @author Christian Tzolov + * @deprecated This method will be removed in 0.9.0. Use + * {@link io.modelcontextprotocol.server.transport.StdioServerTransportProvider} instead. */ +@Deprecated public class StdioServerTransport implements ServerMcpTransport { private static final Logger logger = LoggerFactory.getLogger(StdioServerTransport.class); diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java new file mode 100644 index 00000000..6a7d2903 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java @@ -0,0 +1,306 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.io.Reader; +import java.nio.charset.StandardCharsets; +import java.util.Map; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpServerTransport; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.util.Assert; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; + +/** + * Implementation of the MCP Stdio transport provider for servers that communicates using + * standard input/output streams. Messages are exchanged as newline-delimited JSON-RPC + * messages over stdin/stdout, with errors and debug information sent to stderr. + * + * @author Christian Tzolov + */ +public class StdioServerTransportProvider implements McpServerTransportProvider { + + private static final Logger logger = LoggerFactory.getLogger(StdioServerTransportProvider.class); + + private final ObjectMapper objectMapper; + + private final InputStream inputStream; + + private final OutputStream outputStream; + + private McpServerSession session; + + private final AtomicBoolean isClosing = new AtomicBoolean(false); + + private final Sinks.One inboundReady = Sinks.one(); + + /** + * Creates a new StdioServerTransportProvider with a default ObjectMapper and System + * streams. + */ + public StdioServerTransportProvider() { + this(new ObjectMapper()); + } + + /** + * Creates a new StdioServerTransportProvider with the specified ObjectMapper and + * System streams. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + */ + public StdioServerTransportProvider(ObjectMapper objectMapper) { + this(objectMapper, System.in, System.out); + } + + /** + * Creates a new StdioServerTransportProvider with the specified ObjectMapper and + * streams. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * @param inputStream The input stream to read from + * @param outputStream The output stream to write to + */ + public StdioServerTransportProvider(ObjectMapper objectMapper, InputStream inputStream, OutputStream outputStream) { + Assert.notNull(objectMapper, "The ObjectMapper can not be null"); + Assert.notNull(inputStream, "The InputStream can not be null"); + Assert.notNull(outputStream, "The OutputStream can not be null"); + + this.objectMapper = objectMapper; + this.inputStream = inputStream; + this.outputStream = outputStream; + } + + @Override + public void setSessionFactory(McpServerSession.Factory sessionFactory) { + // Create a single session for the stdio connection + this.session = sessionFactory.create(new StdioMcpSessionTransport()); + } + + @Override + public Mono notifyClients(String method, Map params) { + if (this.session == null) { + return Mono.error(new McpError("No session to close")); + } + return this.session.sendNotification(method, params) + .doOnError(e -> logger.error("Failed to send notification: {}", e.getMessage())); + } + + @Override + public Mono closeGracefully() { + if (this.session == null) { + return Mono.empty(); + } + return this.session.closeGracefully(); + } + + /** + * Implementation of McpServerTransport for the stdio session. + */ + private class StdioMcpSessionTransport implements McpServerTransport { + + private final Sinks.Many inboundSink; + + private final Sinks.Many outboundSink; + + private final AtomicBoolean isStarted = new AtomicBoolean(false); + + /** Scheduler for handling inbound messages */ + private Scheduler inboundScheduler; + + /** Scheduler for handling outbound messages */ + private Scheduler outboundScheduler; + + private final Sinks.One outboundReady = Sinks.one(); + + public StdioMcpSessionTransport() { + + this.inboundSink = Sinks.many().unicast().onBackpressureBuffer(); + this.outboundSink = Sinks.many().unicast().onBackpressureBuffer(); + + // Use bounded schedulers for better resource management + this.inboundScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), + "stdio-inbound"); + this.outboundScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), + "stdio-outbound"); + + handleIncomingMessages(); + startInboundProcessing(); + startOutboundProcessing(); + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + + return Mono.zip(inboundReady.asMono(), outboundReady.asMono()).then(Mono.defer(() -> { + if (outboundSink.tryEmitNext(message).isSuccess()) { + return Mono.empty(); + } + else { + return Mono.error(new RuntimeException("Failed to enqueue message")); + } + })); + } + + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return objectMapper.convertValue(data, typeRef); + } + + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + isClosing.set(true); + logger.debug("Session transport closing gracefully"); + inboundSink.tryEmitComplete(); + }); + } + + @Override + public void close() { + isClosing.set(true); + logger.debug("Session transport closed"); + } + + private void handleIncomingMessages() { + this.inboundSink.asFlux().flatMap(message -> session.handle(message)).doOnTerminate(() -> { + // The outbound processing will dispose its scheduler upon completion + this.outboundSink.tryEmitComplete(); + this.inboundScheduler.dispose(); + }).subscribe(); + } + + /** + * Starts the inbound processing thread that reads JSON-RPC messages from stdin. + * Messages are deserialized and passed to the session for handling. + */ + private void startInboundProcessing() { + if (isStarted.compareAndSet(false, true)) { + this.inboundScheduler.schedule(() -> { + inboundReady.tryEmitValue(null); + BufferedReader reader = null; + try { + reader = new BufferedReader(new InputStreamReader(inputStream)); + while (!isClosing.get()) { + try { + String line = reader.readLine(); + if (line == null || isClosing.get()) { + break; + } + + logger.debug("Received JSON message: {}", line); + + try { + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, + line); + if (!this.inboundSink.tryEmitNext(message).isSuccess()) { + // logIfNotClosing("Failed to enqueue message"); + break; + } + + } + catch (Exception e) { + logIfNotClosing("Error processing inbound message", e); + break; + } + } + catch (IOException e) { + logIfNotClosing("Error reading from stdin", e); + break; + } + } + } + catch (Exception e) { + logIfNotClosing("Error in inbound processing", e); + } + finally { + isClosing.set(true); + if (session != null) { + session.close(); + } + inboundSink.tryEmitComplete(); + } + }); + } + } + + /** + * Starts the outbound processing thread that writes JSON-RPC messages to stdout. + * Messages are serialized to JSON and written with a newline delimiter. + */ + private void startOutboundProcessing() { + Function, Flux> outboundConsumer = messages -> messages // @formatter:off + .doOnSubscribe(subscription -> outboundReady.tryEmitValue(null)) + .publishOn(outboundScheduler) + .handle((message, sink) -> { + if (message != null && !isClosing.get()) { + try { + String jsonMessage = objectMapper.writeValueAsString(message); + // Escape any embedded newlines in the JSON message as per spec + jsonMessage = jsonMessage.replace("\r\n", "\\n").replace("\n", "\\n").replace("\r", "\\n"); + + synchronized (outputStream) { + outputStream.write(jsonMessage.getBytes(StandardCharsets.UTF_8)); + outputStream.write("\n".getBytes(StandardCharsets.UTF_8)); + outputStream.flush(); + } + sink.next(message); + } + catch (IOException e) { + if (!isClosing.get()) { + logger.error("Error writing message", e); + sink.error(new RuntimeException(e)); + } + else { + logger.debug("Stream closed during shutdown", e); + } + } + } + else if (isClosing.get()) { + sink.complete(); + } + }) + .doOnComplete(() -> { + isClosing.set(true); + outboundScheduler.dispose(); + }) + .doOnError(e -> { + if (!isClosing.get()) { + logger.error("Error in outbound processing", e); + isClosing.set(true); + outboundScheduler.dispose(); + } + }) + .map(msg -> (JSONRPCMessage) msg); + + outboundConsumer.apply(outboundSink.asFlux()).subscribe(); + } // @formatter:on + + private void logIfNotClosing(String message, Exception e) { + if (!isClosing.get()) { + logger.error(message, e); + } + } + + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/ClientMcpTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/ClientMcpTransport.java index 8a9b4ce0..8464b6ae 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/ClientMcpTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/ClientMcpTransport.java @@ -7,7 +7,9 @@ * Marker interface for the client-side MCP transport. * * @author Christian Tzolov + * @deprecated This class will be removed in 0.9.0. Use {@link McpClientTransport}. */ +@Deprecated public interface ClientMcpTransport extends McpTransport { } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpSession.java index 46aefafc..83de4c09 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpSession.java @@ -34,7 +34,10 @@ * * @author Christian Tzolov * @author Dariusz Jędrzejczyk + * @deprecated This method will be removed in 0.9.0. Use {@link McpClientSession} instead */ +@Deprecated + public class DefaultMcpSession implements McpSession { /** Logger for this class */ diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java new file mode 100644 index 00000000..6657e362 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -0,0 +1,288 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import java.time.Duration; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; + +import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.util.Assert; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.Disposable; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoSink; + +/** + * Default implementation of the MCP (Model Context Protocol) session that manages + * bidirectional JSON-RPC communication between clients and servers. This implementation + * follows the MCP specification for message exchange and transport handling. + * + *

    + * The session manages: + *

      + *
    • Request/response handling with unique message IDs
    • + *
    • Notification processing
    • + *
    • Message timeout management
    • + *
    • Transport layer abstraction
    • + *
    + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + */ +public class McpClientSession implements McpSession { + + /** Logger for this class */ + private static final Logger logger = LoggerFactory.getLogger(McpClientSession.class); + + /** Duration to wait for request responses before timing out */ + private final Duration requestTimeout; + + /** Transport layer implementation for message exchange */ + private final McpTransport transport; + + /** Map of pending responses keyed by request ID */ + private final ConcurrentHashMap> pendingResponses = new ConcurrentHashMap<>(); + + /** Map of request handlers keyed by method name */ + private final ConcurrentHashMap> requestHandlers = new ConcurrentHashMap<>(); + + /** Map of notification handlers keyed by method name */ + private final ConcurrentHashMap notificationHandlers = new ConcurrentHashMap<>(); + + /** Session-specific prefix for request IDs */ + private final String sessionPrefix = UUID.randomUUID().toString().substring(0, 8); + + /** Atomic counter for generating unique request IDs */ + private final AtomicLong requestCounter = new AtomicLong(0); + + private final Disposable connection; + + /** + * Functional interface for handling incoming JSON-RPC requests. Implementations + * should process the request parameters and return a response. + * + * @param Response type + */ + @FunctionalInterface + public interface RequestHandler { + + /** + * Handles an incoming request with the given parameters. + * @param params The request parameters + * @return A Mono containing the response object + */ + Mono handle(Object params); + + } + + /** + * Functional interface for handling incoming JSON-RPC notifications. Implementations + * should process the notification parameters without returning a response. + */ + @FunctionalInterface + public interface NotificationHandler { + + /** + * Handles an incoming notification with the given parameters. + * @param params The notification parameters + * @return A Mono that completes when the notification is processed + */ + Mono handle(Object params); + + } + + /** + * Creates a new McpClientSession with the specified configuration and handlers. + * @param requestTimeout Duration to wait for responses + * @param transport Transport implementation for message exchange + * @param requestHandlers Map of method names to request handlers + * @param notificationHandlers Map of method names to notification handlers + */ + public McpClientSession(Duration requestTimeout, McpTransport transport, + Map> requestHandlers, Map notificationHandlers) { + + Assert.notNull(requestTimeout, "The requstTimeout can not be null"); + Assert.notNull(transport, "The transport can not be null"); + Assert.notNull(requestHandlers, "The requestHandlers can not be null"); + Assert.notNull(notificationHandlers, "The notificationHandlers can not be null"); + + this.requestTimeout = requestTimeout; + this.transport = transport; + this.requestHandlers.putAll(requestHandlers); + this.notificationHandlers.putAll(notificationHandlers); + + // TODO: consider mono.transformDeferredContextual where the Context contains + // the + // Observation associated with the individual message - it can be used to + // create child Observation and emit it together with the message to the + // consumer + this.connection = this.transport.connect(mono -> mono.doOnNext(message -> { + if (message instanceof McpSchema.JSONRPCResponse response) { + logger.debug("Received Response: {}", response); + var sink = pendingResponses.remove(response.id()); + if (sink == null) { + logger.warn("Unexpected response for unkown id {}", response.id()); + } + else { + sink.success(response); + } + } + else if (message instanceof McpSchema.JSONRPCRequest request) { + logger.debug("Received request: {}", request); + handleIncomingRequest(request).subscribe(response -> transport.sendMessage(response).subscribe(), + error -> { + var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), + null, new McpSchema.JSONRPCResponse.JSONRPCError( + McpSchema.ErrorCodes.INTERNAL_ERROR, error.getMessage(), null)); + transport.sendMessage(errorResponse).subscribe(); + }); + } + else if (message instanceof McpSchema.JSONRPCNotification notification) { + logger.debug("Received notification: {}", notification); + handleIncomingNotification(notification).subscribe(null, + error -> logger.error("Error handling notification: {}", error.getMessage())); + } + })).subscribe(); + } + + /** + * Handles an incoming JSON-RPC request by routing it to the appropriate handler. + * @param request The incoming JSON-RPC request + * @return A Mono containing the JSON-RPC response + */ + private Mono handleIncomingRequest(McpSchema.JSONRPCRequest request) { + return Mono.defer(() -> { + var handler = this.requestHandlers.get(request.method()); + if (handler == null) { + MethodNotFoundError error = getMethodNotFoundError(request.method()); + return Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, + new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.METHOD_NOT_FOUND, + error.message(), error.data()))); + } + + return handler.handle(request.params()) + .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null)) + .onErrorResume(error -> Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), + null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + error.getMessage(), null)))); // TODO: add error message + // through the data field + }); + } + + record MethodNotFoundError(String method, String message, Object data) { + } + + public static MethodNotFoundError getMethodNotFoundError(String method) { + switch (method) { + case McpSchema.METHOD_ROOTS_LIST: + return new MethodNotFoundError(method, "Roots not supported", + Map.of("reason", "Client does not have roots capability")); + default: + return new MethodNotFoundError(method, "Method not found: " + method, null); + } + } + + /** + * Handles an incoming JSON-RPC notification by routing it to the appropriate handler. + * @param notification The incoming JSON-RPC notification + * @return A Mono that completes when the notification is processed + */ + private Mono handleIncomingNotification(McpSchema.JSONRPCNotification notification) { + return Mono.defer(() -> { + var handler = notificationHandlers.get(notification.method()); + if (handler == null) { + logger.error("No handler registered for notification method: {}", notification.method()); + return Mono.empty(); + } + return handler.handle(notification.params()); + }); + } + + /** + * Generates a unique request ID in a non-blocking way. Combines a session-specific + * prefix with an atomic counter to ensure uniqueness. + * @return A unique request ID string + */ + private String generateRequestId() { + return this.sessionPrefix + "-" + this.requestCounter.getAndIncrement(); + } + + /** + * Sends a JSON-RPC request and returns the response. + * @param The expected response type + * @param method The method name to call + * @param requestParams The request parameters + * @param typeRef Type reference for response deserialization + * @return A Mono containing the response + */ + @Override + public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { + String requestId = this.generateRequestId(); + + return Mono.create(sink -> { + this.pendingResponses.put(requestId, sink); + McpSchema.JSONRPCRequest jsonrpcRequest = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, method, + requestId, requestParams); + this.transport.sendMessage(jsonrpcRequest) + // TODO: It's most efficient to create a dedicated Subscriber here + .subscribe(v -> { + }, error -> { + this.pendingResponses.remove(requestId); + sink.error(error); + }); + }).timeout(this.requestTimeout).handle((jsonRpcResponse, sink) -> { + if (jsonRpcResponse.error() != null) { + sink.error(new McpError(jsonRpcResponse.error())); + } + else { + if (typeRef.getType().equals(Void.class)) { + sink.complete(); + } + else { + sink.next(this.transport.unmarshalFrom(jsonRpcResponse.result(), typeRef)); + } + } + }); + } + + /** + * Sends a JSON-RPC notification. + * @param method The method name for the notification + * @param params The notification parameters + * @return A Mono that completes when the notification is sent + */ + @Override + public Mono sendNotification(String method, Map params) { + McpSchema.JSONRPCNotification jsonrpcNotification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, + method, params); + return this.transport.sendMessage(jsonrpcNotification); + } + + /** + * Closes the session gracefully, allowing pending operations to complete. + * @return A Mono that completes when the session is closed + */ + @Override + public Mono closeGracefully() { + return Mono.defer(() -> { + this.connection.dispose(); + return transport.closeGracefully(); + }); + } + + /** + * Closes the session immediately, potentially interrupting pending operations. + */ + @Override + public void close() { + this.connection.dispose(); + transport.close(); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java new file mode 100644 index 00000000..45897965 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java @@ -0,0 +1,21 @@ +/* +* Copyright 2024 - 2024 the original author or authors. +*/ +package io.modelcontextprotocol.spec; + +import java.util.function.Function; + +import reactor.core.publisher.Mono; + +/** + * Marker interface for the client-side MCP transport. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + */ +public interface McpClientTransport extends ClientMcpTransport { + + @Override + Mono connect(Function, Mono> handler); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java new file mode 100644 index 00000000..bcdf2248 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -0,0 +1,354 @@ +package io.modelcontextprotocol.spec; + +import java.time.Duration; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; + +import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.server.McpAsyncServerExchange; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoSink; +import reactor.core.publisher.Sinks; + +/** + * Represents a Model Control Protocol (MCP) session on the server side. It manages + * bidirectional JSON-RPC communication with the client. + */ +public class McpServerSession implements McpSession { + + private static final Logger logger = LoggerFactory.getLogger(McpServerSession.class); + + private final ConcurrentHashMap> pendingResponses = new ConcurrentHashMap<>(); + + private final String id; + + private final AtomicLong requestCounter = new AtomicLong(0); + + private final InitRequestHandler initRequestHandler; + + private final InitNotificationHandler initNotificationHandler; + + private final Map> requestHandlers; + + private final Map notificationHandlers; + + private final McpServerTransport transport; + + private final Sinks.One exchangeSink = Sinks.one(); + + private final AtomicReference clientCapabilities = new AtomicReference<>(); + + private final AtomicReference clientInfo = new AtomicReference<>(); + + private static final int STATE_UNINITIALIZED = 0; + + private static final int STATE_INITIALIZING = 1; + + private static final int STATE_INITIALIZED = 2; + + private final AtomicInteger state = new AtomicInteger(STATE_UNINITIALIZED); + + /** + * Creates a new server session with the given parameters and the transport to use. + * @param id session id + * @param transport the transport to use + * @param initHandler called when a + * {@link io.modelcontextprotocol.spec.McpSchema.InitializeRequest} is received by the + * server + * @param initNotificationHandler called when a + * {@link McpSchema.METHOD_NOTIFICATION_INITIALIZED} is received. + * @param requestHandlers map of request handlers to use + * @param notificationHandlers map of notification handlers to use + */ + public McpServerSession(String id, McpServerTransport transport, InitRequestHandler initHandler, + InitNotificationHandler initNotificationHandler, Map> requestHandlers, + Map notificationHandlers) { + this.id = id; + this.transport = transport; + this.initRequestHandler = initHandler; + this.initNotificationHandler = initNotificationHandler; + this.requestHandlers = requestHandlers; + this.notificationHandlers = notificationHandlers; + } + + /** + * Retrieve the session id. + * @return session id + */ + public String getId() { + return this.id; + } + + /** + * Called upon successful initialization sequence between the client and the server + * with the client capabilities and information. + * + * Initialization + * Spec + * @param clientCapabilities the capabilities the connected client provides + * @param clientInfo the information about the connected client + */ + public void init(McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo) { + this.clientCapabilities.lazySet(clientCapabilities); + this.clientInfo.lazySet(clientInfo); + } + + private String generateRequestId() { + return this.id + "-" + this.requestCounter.getAndIncrement(); + } + + @Override + public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { + String requestId = this.generateRequestId(); + + return Mono.create(sink -> { + this.pendingResponses.put(requestId, sink); + McpSchema.JSONRPCRequest jsonrpcRequest = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, method, + requestId, requestParams); + this.transport.sendMessage(jsonrpcRequest).subscribe(v -> { + }, error -> { + this.pendingResponses.remove(requestId); + sink.error(error); + }); + }).timeout(Duration.ofSeconds(10)).handle((jsonRpcResponse, sink) -> { + if (jsonRpcResponse.error() != null) { + sink.error(new McpError(jsonRpcResponse.error())); + } + else { + if (typeRef.getType().equals(Void.class)) { + sink.complete(); + } + else { + sink.next(this.transport.unmarshalFrom(jsonRpcResponse.result(), typeRef)); + } + } + }); + } + + @Override + public Mono sendNotification(String method, Map params) { + McpSchema.JSONRPCNotification jsonrpcNotification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, + method, params); + return this.transport.sendMessage(jsonrpcNotification); + } + + /** + * Called by the {@link McpServerTransportProvider} once the session is determined. + * The purpose of this method is to dispatch the message to an appropriate handler as + * specified by the MCP server implementation + * ({@link io.modelcontextprotocol.server.McpAsyncServer} or + * {@link io.modelcontextprotocol.server.McpSyncServer}) via + * {@link McpServerSession.Factory} that the server creates. + * @param message the incoming JSON-RPC message + * @return a Mono that completes when the message is processed + */ + public Mono handle(McpSchema.JSONRPCMessage message) { + return Mono.defer(() -> { + // TODO handle errors for communication to without initialization happening + // first + if (message instanceof McpSchema.JSONRPCResponse response) { + logger.debug("Received Response: {}", response); + var sink = pendingResponses.remove(response.id()); + if (sink == null) { + logger.warn("Unexpected response for unknown id {}", response.id()); + } + else { + sink.success(response); + } + return Mono.empty(); + } + else if (message instanceof McpSchema.JSONRPCRequest request) { + logger.debug("Received request: {}", request); + return handleIncomingRequest(request).onErrorResume(error -> { + var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, + new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + error.getMessage(), null)); + // TODO: Should the error go to SSE or back as POST return? + return this.transport.sendMessage(errorResponse).then(Mono.empty()); + }).flatMap(this.transport::sendMessage); + } + else if (message instanceof McpSchema.JSONRPCNotification notification) { + // TODO handle errors for communication to without initialization + // happening first + logger.debug("Received notification: {}", notification); + // TODO: in case of error, should the POST request be signalled? + return handleIncomingNotification(notification) + .doOnError(error -> logger.error("Error handling notification: {}", error.getMessage())); + } + else { + logger.warn("Received unknown message type: {}", message); + return Mono.empty(); + } + }); + } + + /** + * Handles an incoming JSON-RPC request by routing it to the appropriate handler. + * @param request The incoming JSON-RPC request + * @return A Mono containing the JSON-RPC response + */ + private Mono handleIncomingRequest(McpSchema.JSONRPCRequest request) { + return Mono.defer(() -> { + Mono resultMono; + if (McpSchema.METHOD_INITIALIZE.equals(request.method())) { + // TODO handle situation where already initialized! + McpSchema.InitializeRequest initializeRequest = transport.unmarshalFrom(request.params(), + new TypeReference() { + }); + + this.state.lazySet(STATE_INITIALIZING); + this.init(initializeRequest.capabilities(), initializeRequest.clientInfo()); + resultMono = this.initRequestHandler.handle(initializeRequest); + } + else { + // TODO handle errors for communication to this session without + // initialization happening first + var handler = this.requestHandlers.get(request.method()); + if (handler == null) { + MethodNotFoundError error = getMethodNotFoundError(request.method()); + return Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, + new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.METHOD_NOT_FOUND, + error.message(), error.data()))); + } + + resultMono = this.exchangeSink.asMono().flatMap(exchange -> handler.handle(exchange, request.params())); + } + return resultMono + .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null)) + .onErrorResume(error -> Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), + null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + error.getMessage(), null)))); // TODO: add error message + // through the data field + }); + } + + /** + * Handles an incoming JSON-RPC notification by routing it to the appropriate handler. + * @param notification The incoming JSON-RPC notification + * @return A Mono that completes when the notification is processed + */ + private Mono handleIncomingNotification(McpSchema.JSONRPCNotification notification) { + return Mono.defer(() -> { + if (McpSchema.METHOD_NOTIFICATION_INITIALIZED.equals(notification.method())) { + this.state.lazySet(STATE_INITIALIZED); + exchangeSink.tryEmitValue(new McpAsyncServerExchange(this, clientCapabilities.get(), clientInfo.get())); + return this.initNotificationHandler.handle(); + } + + var handler = notificationHandlers.get(notification.method()); + if (handler == null) { + logger.error("No handler registered for notification method: {}", notification.method()); + return Mono.empty(); + } + return this.exchangeSink.asMono().flatMap(exchange -> handler.handle(exchange, notification.params())); + }); + } + + record MethodNotFoundError(String method, String message, Object data) { + } + + static MethodNotFoundError getMethodNotFoundError(String method) { + switch (method) { + case McpSchema.METHOD_ROOTS_LIST: + return new MethodNotFoundError(method, "Roots not supported", + Map.of("reason", "Client does not have roots capability")); + default: + return new MethodNotFoundError(method, "Method not found: " + method, null); + } + } + + @Override + public Mono closeGracefully() { + return this.transport.closeGracefully(); + } + + @Override + public void close() { + this.transport.close(); + } + + /** + * Request handler for the initialization request. + */ + public interface InitRequestHandler { + + /** + * Handles the initialization request. + * @param initializeRequest the initialization request by the client + * @return a Mono that will emit the result of the initialization + */ + Mono handle(McpSchema.InitializeRequest initializeRequest); + + } + + /** + * Notification handler for the initialization notification from the client. + */ + public interface InitNotificationHandler { + + /** + * Specifies an action to take upon successful initialization. + * @return a Mono that will complete when the initialization is acted upon. + */ + Mono handle(); + + } + + /** + * A handler for client-initiated notifications. + */ + public interface NotificationHandler { + + /** + * Handles a notification from the client. + * @param exchange the exchange associated with the client that allows calling + * back to the connected client or inspecting its capabilities. + * @param params the parameters of the notification. + * @return a Mono that completes once the notification is handled. + */ + Mono handle(McpAsyncServerExchange exchange, Object params); + + } + + /** + * A handler for client-initiated requests. + * + * @param the type of the response that is expected as a result of handling the + * request. + */ + public interface RequestHandler { + + /** + * Handles a request from the client. + * @param exchange the exchange associated with the client that allows calling + * back to the connected client or inspecting its capabilities. + * @param params the parameters of the request. + * @return a Mono that will emit the response to the request. + */ + Mono handle(McpAsyncServerExchange exchange, Object params); + + } + + /** + * Factory for creating server sessions which delegate to a provided 1:1 transport + * with a connected client. + */ + @FunctionalInterface + public interface Factory { + + /** + * Creates a new 1:1 representation of the client-server interaction. + * @param sessionTransport the transport to use for communication with the client. + * @return a new server session. + */ + McpServerSession create(McpServerTransport sessionTransport); + + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransport.java new file mode 100644 index 00000000..632b8cee --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransport.java @@ -0,0 +1,11 @@ +package io.modelcontextprotocol.spec; + +/** + * Marker interface for the server-side MCP transport. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + */ +public interface McpServerTransport extends McpTransport { + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java new file mode 100644 index 00000000..dba8cc43 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java @@ -0,0 +1,66 @@ +package io.modelcontextprotocol.spec; + +import java.util.Map; + +import reactor.core.publisher.Mono; + +/** + * The core building block providing the server-side MCP transport. Implement this + * interface to bridge between a particular server-side technology and the MCP server + * transport layer. + * + *

    + * The lifecycle of the provider dictates that it be created first, upon application + * startup, and then passed into either + * {@link io.modelcontextprotocol.server.McpServer#sync(McpServerTransportProvider)} or + * {@link io.modelcontextprotocol.server.McpServer#async(McpServerTransportProvider)}. As + * a result of the MCP server creation, the provider will be notified of a + * {@link McpServerSession.Factory} which will be used to handle a 1:1 communication + * between a newly connected client and the server. The provider's responsibility is to + * create instances of {@link McpServerTransport} that the session will utilise during the + * session lifetime. + * + *

    + * Finally, the {@link McpServerTransport}s can be closed in bulk when {@link #close()} or + * {@link #closeGracefully()} are called as part of the normal application shutdown event. + * Individual {@link McpServerTransport}s can also be closed on a per-session basis, where + * the {@link McpServerSession#close()} or {@link McpServerSession#closeGracefully()} + * closes the provided transport. + * + * @author Dariusz Jędrzejczyk + */ +public interface McpServerTransportProvider { + + /** + * Sets the session factory that will be used to create sessions for new clients. An + * implementation of the MCP server MUST call this method before any MCP interactions + * take place. + * @param sessionFactory the session factory to be used for initiating client sessions + */ + void setSessionFactory(McpServerSession.Factory sessionFactory); + + /** + * Sends a notification to all connected clients. + * @param method the name of the notification method to be called on the clients + * @param params a map of parameters to be sent with the notification + * @return a Mono that completes when the notification has been broadcast + * @see McpSession#sendNotification(String, Map) + */ + Mono notifyClients(String method, Map params); + + /** + * Immediately closes all the transports with connected clients and releases any + * associated resources. + */ + default void close() { + this.closeGracefully().subscribe(); + } + + /** + * Gracefully closes all the transports with connected clients and releases any + * associated resources asynchronously. + * @return a {@link Mono} that completes when the connections have been closed. + */ + Mono closeGracefully(); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java index 92b46075..b97c3ccc 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java @@ -26,14 +26,15 @@ public interface McpSession { /** - * Sends a request to the model server and expects a response of type T. + * Sends a request to the model counterparty and expects a response of type T. * *

    * This method handles the request-response pattern where a response is expected from - * the server. The response type is determined by the provided TypeReference. + * the client or server. The response type is determined by the provided + * TypeReference. *

    * @param the type of the expected response - * @param method the name of the method to be called on the server + * @param method the name of the method to be called on the counterparty * @param requestParams the parameters to be sent with the request * @param typeRef the TypeReference describing the expected response type * @return a Mono that will emit the response when received @@ -41,11 +42,11 @@ public interface McpSession { Mono sendRequest(String method, Object requestParams, TypeReference typeRef); /** - * Sends a notification to the model server without parameters. + * Sends a notification to the model client or server without parameters. * *

    * This method implements the notification pattern where no response is expected from - * the server. It's useful for fire-and-forget scenarios. + * the counterparty. It's useful for fire-and-forget scenarios. *

    * @param method the name of the notification method to be called on the server * @return a Mono that completes when the notification has been sent @@ -55,13 +56,13 @@ default Mono sendNotification(String method) { } /** - * Sends a notification to the model server with parameters. + * Sends a notification to the model client or server with parameters. * *

    * Similar to {@link #sendNotification(String)} but allows sending additional * parameters with the notification. *

    - * @param method the name of the notification method to be called on the server + * @param method the name of the notification method to be sent to the counterparty * @param params a map of parameters to be sent with the notification * @return a Mono that completes when the notification has been sent */ diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransport.java index 344a50bf..f698d878 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransport.java @@ -46,8 +46,13 @@ public interface McpTransport { * This method should be called before any message exchange can occur. It sets up the * necessary resources and establishes the connection to the server. *

    + * @deprecated This is only relevant for client-side transports and will be removed + * from this interface in 0.9.0. */ - Mono connect(Function, Mono> handler); + @Deprecated + default Mono connect(Function, Mono> handler) { + return Mono.empty(); + } /** * Closes the transport connection and releases any associated resources. @@ -69,7 +74,7 @@ default void close() { Mono closeGracefully(); /** - * Sends a message to the server asynchronously. + * Sends a message to the peer asynchronously. * *

    * This method handles the transmission of messages to the server in an asynchronous diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpTransport.java index 13591432..704daee0 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/ServerMcpTransport.java @@ -7,7 +7,9 @@ * Marker interface for the server-side MCP transport. * * @author Christian Tzolov + * @deprecated This class will be removed in 0.9.0. Use {@link McpServerTransport}. */ +@Deprecated public interface ServerMcpTransport extends McpTransport { } diff --git a/mcp/src/test/java/io/modelcontextprotocol/MockMcpTransport.java b/mcp/src/test/java/io/modelcontextprotocol/MockMcpTransport.java index d4e48ea7..12f30d12 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/MockMcpTransport.java +++ b/mcp/src/test/java/io/modelcontextprotocol/MockMcpTransport.java @@ -11,7 +11,7 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.ServerMcpTransport; import io.modelcontextprotocol.spec.McpSchema.JSONRPCNotification; @@ -20,10 +20,10 @@ import reactor.core.publisher.Sinks; /** - * A mock implementation of the {@link ClientMcpTransport} and {@link ServerMcpTransport} + * A mock implementation of the {@link McpClientTransport} and {@link ServerMcpTransport} * interfaces. */ -public class MockMcpTransport implements ClientMcpTransport, ServerMcpTransport { +public class MockMcpTransport implements McpClientTransport, ServerMcpTransport { private final Sinks.Many inbound = Sinks.many().unicast().onBackpressureBuffer(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index f7a0a492..ac7b9e5e 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -12,7 +12,7 @@ import java.util.function.Consumer; import java.util.function.Function; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; @@ -50,7 +50,7 @@ public abstract class AbstractMcpAsyncClientTests { private static final String ECHO_TEST_MESSAGE = "Hello MCP Spring AI!"; - abstract protected ClientMcpTransport createMcpTransport(); + abstract protected McpClientTransport createMcpTransport(); protected void onStart() { } @@ -66,11 +66,11 @@ protected Duration getInitializationTimeout() { return Duration.ofSeconds(2); } - McpAsyncClient client(ClientMcpTransport transport) { + McpAsyncClient client(McpClientTransport transport) { return client(transport, Function.identity()); } - McpAsyncClient client(ClientMcpTransport transport, Function customizer) { + McpAsyncClient client(McpClientTransport transport, Function customizer) { AtomicReference client = new AtomicReference<>(); assertThatCode(() -> { @@ -85,11 +85,11 @@ McpAsyncClient client(ClientMcpTransport transport, Function c) { + void withClient(McpClientTransport transport, Consumer c) { withClient(transport, Function.identity(), c); } - void withClient(ClientMcpTransport transport, Function customizer, + void withClient(McpClientTransport transport, Function customizer, Consumer c) { var client = client(transport, customizer); try { diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java index f4d8dbdb..24c161eb 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -11,7 +11,7 @@ import java.util.function.Consumer; import java.util.function.Function; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; @@ -50,7 +50,7 @@ public abstract class AbstractMcpSyncClientTests { private static final String TEST_MESSAGE = "Hello MCP Spring AI!"; - abstract protected ClientMcpTransport createMcpTransport(); + abstract protected McpClientTransport createMcpTransport(); protected void onStart() { } @@ -66,11 +66,11 @@ protected Duration getInitializationTimeout() { return Duration.ofSeconds(2); } - McpSyncClient client(ClientMcpTransport transport) { + McpSyncClient client(McpClientTransport transport) { return client(transport, Function.identity()); } - McpSyncClient client(ClientMcpTransport transport, Function customizer) { + McpSyncClient client(McpClientTransport transport, Function customizer) { AtomicReference client = new AtomicReference<>(); assertThatCode(() -> { @@ -85,11 +85,11 @@ McpSyncClient client(ClientMcpTransport transport, Function c) { + void withClient(McpClientTransport transport, Consumer c) { withClient(transport, Function.identity(), c); } - void withClient(ClientMcpTransport transport, Function customizer, + void withClient(McpClientTransport transport, Function customizer, Consumer c) { var client = client(transport, customizer); try { diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java index ac0fef24..15749d4f 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java @@ -4,10 +4,8 @@ package io.modelcontextprotocol.client; -import java.time.Duration; - import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import org.junit.jupiter.api.Timeout; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; @@ -30,7 +28,7 @@ class HttpSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests { .waitingFor(Wait.forHttp("/").forStatusCode(404)); @Override - protected ClientMcpTransport createMcpTransport() { + protected McpClientTransport createMcpTransport() { return new HttpClientSseClientTransport(host); } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java index 8772e620..067f9295 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpSyncClientTests.java @@ -4,10 +4,8 @@ package io.modelcontextprotocol.client; -import java.time.Duration; - import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import org.junit.jupiter.api.Timeout; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; @@ -30,7 +28,7 @@ class HttpSseMcpSyncClientTests extends AbstractMcpSyncClientTests { .waitingFor(Wait.forHttp("/").forStatusCode(404)); @Override - protected ClientMcpTransport createMcpTransport() { + protected McpClientTransport createMcpTransport() { return new HttpClientSseClientTransport(host); } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java index c285e2c6..95230942 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java @@ -8,7 +8,7 @@ import io.modelcontextprotocol.client.transport.ServerParameters; import io.modelcontextprotocol.client.transport.StdioClientTransport; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import org.junit.jupiter.api.Timeout; /** @@ -21,7 +21,7 @@ class StdioMcpAsyncClientTests extends AbstractMcpAsyncClientTests { @Override - protected ClientMcpTransport createMcpTransport() { + protected McpClientTransport createMcpTransport() { ServerParameters stdioParams = ServerParameters.builder("npx") .args("-y", "@modelcontextprotocol/server-everything", "dir") .build(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java index ebf10b9a..925852b5 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java @@ -11,7 +11,7 @@ import io.modelcontextprotocol.client.transport.ServerParameters; import io.modelcontextprotocol.client.transport.StdioClientTransport; -import io.modelcontextprotocol.spec.ClientMcpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; import reactor.core.publisher.Sinks; @@ -29,7 +29,7 @@ class StdioMcpSyncClientTests extends AbstractMcpSyncClientTests { @Override - protected ClientMcpTransport createMcpTransport() { + protected McpClientTransport createMcpTransport() { ServerParameters stdioParams = ServerParameters.builder("npx") .args("-y", "@modelcontextprotocol/server-everything", "dir") .build(); @@ -42,7 +42,7 @@ void customErrorHandlerShouldReceiveErrors() throws InterruptedException { CountDownLatch latch = new CountDownLatch(1); AtomicReference receivedError = new AtomicReference<>(); - ClientMcpTransport transport = createMcpTransport(); + McpClientTransport transport = createMcpTransport(); StepVerifier.create(transport.connect(msg -> msg)).verifyComplete(); ((StdioClientTransport) transport).setStdErrorHandler(error -> { diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerDeprecatedTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerDeprecatedTests.java new file mode 100644 index 00000000..b9a19de6 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerDeprecatedTests.java @@ -0,0 +1,466 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.time.Duration; +import java.util.List; + +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptMessage; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.Resource; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpTransport; +import io.modelcontextprotocol.spec.ServerMcpTransport; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Test suite for the {@link McpAsyncServer} that can be used with different + * {@link McpTransport} implementations. + * + * @author Christian Tzolov + */ +// KEEP IN SYNC with the class in mcp-test module +@Deprecated +public abstract class AbstractMcpAsyncServerDeprecatedTests { + + private static final String TEST_TOOL_NAME = "test-tool"; + + private static final String TEST_RESOURCE_URI = "test://resource"; + + private static final String TEST_PROMPT_NAME = "test-prompt"; + + abstract protected ServerMcpTransport createMcpTransport(); + + protected void onStart() { + } + + protected void onClose() { + } + + @BeforeEach + void setUp() { + } + + @AfterEach + void tearDown() { + onClose(); + } + + // --------------------------------------- + // Server Lifecycle Tests + // --------------------------------------- + + @Test + void testConstructorWithInvalidArguments() { + assertThatThrownBy(() -> McpServer.async((ServerMcpTransport) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Transport must not be null"); + + assertThatThrownBy(() -> McpServer.async(createMcpTransport()).serverInfo((McpSchema.Implementation) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Server info must not be null"); + } + + @Test + void testGracefulShutdown() { + var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + StepVerifier.create(mcpAsyncServer.closeGracefully()).verifyComplete(); + } + + @Test + void testImmediateClose() { + var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + assertThatCode(() -> mcpAsyncServer.close()).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Tools Tests + // --------------------------------------- + String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + + @Test + void testAddTool() { + Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolRegistration(newTool, + args -> Mono.just(new CallToolResult(List.of(), false))))) + .verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddDuplicateTool() { + Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tool(duplicateTool, args -> Mono.just(new CallToolResult(List.of(), false))) + .build(); + + StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolRegistration(duplicateTool, + args -> Mono.just(new CallToolResult(List.of(), false))))) + .verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); + }); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testRemoveTool() { + Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tool(too, args -> Mono.just(new CallToolResult(List.of(), false))) + .build(); + + StepVerifier.create(mcpAsyncServer.removeTool(TEST_TOOL_NAME)).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentTool() { + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + StepVerifier.create(mcpAsyncServer.removeTool("nonexistent-tool")).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class).hasMessage("Tool with name 'nonexistent-tool' not found"); + }); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testNotifyToolsListChanged() { + Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tool(too, args -> Mono.just(new CallToolResult(List.of(), false))) + .build(); + + StepVerifier.create(mcpAsyncServer.notifyToolsListChanged()).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Resources Tests + // --------------------------------------- + + @Test + void testNotifyResourcesListChanged() { + var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + StepVerifier.create(mcpAsyncServer.notifyResourcesListChanged()).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddResource() { + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", + null); + McpServerFeatures.AsyncResourceRegistration registration = new McpServerFeatures.AsyncResourceRegistration( + resource, req -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier.create(mcpAsyncServer.addResource(registration)).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddResourceWithNullRegistration() { + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + StepVerifier.create(mcpAsyncServer.addResource((McpServerFeatures.AsyncResourceRegistration) null)) + .verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class).hasMessage("Resource must not be null"); + }); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddResourceWithoutCapability() { + // Create a server without resource capabilities + McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .build(); + + Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", + null); + McpServerFeatures.AsyncResourceRegistration registration = new McpServerFeatures.AsyncResourceRegistration( + resource, req -> Mono.just(new ReadResourceResult(List.of()))); + + StepVerifier.create(serverWithoutResources.addResource(registration)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with resource capabilities"); + }); + } + + @Test + void testRemoveResourceWithoutCapability() { + // Create a server without resource capabilities + McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .build(); + + StepVerifier.create(serverWithoutResources.removeResource(TEST_RESOURCE_URI)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with resource capabilities"); + }); + } + + // --------------------------------------- + // Prompts Tests + // --------------------------------------- + + @Test + void testNotifyPromptsListChanged() { + var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + StepVerifier.create(mcpAsyncServer.notifyPromptsListChanged()).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testAddPromptWithNullRegistration() { + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(false).build()) + .build(); + + StepVerifier.create(mcpAsyncServer.addPrompt((McpServerFeatures.AsyncPromptRegistration) null)) + .verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class).hasMessage("Prompt registration must not be null"); + }); + } + + @Test + void testAddPromptWithoutCapability() { + // Create a server without prompt capabilities + McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .build(); + + Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); + McpServerFeatures.AsyncPromptRegistration registration = new McpServerFeatures.AsyncPromptRegistration(prompt, + req -> Mono.just(new GetPromptResult("Test prompt description", List + .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); + + StepVerifier.create(serverWithoutPrompts.addPrompt(registration)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with prompt capabilities"); + }); + } + + @Test + void testRemovePromptWithoutCapability() { + // Create a server without prompt capabilities + McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .build(); + + StepVerifier.create(serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with prompt capabilities"); + }); + } + + @Test + void testRemovePrompt() { + String TEST_PROMPT_NAME_TO_REMOVE = "TEST_PROMPT_NAME678"; + + Prompt prompt = new Prompt(TEST_PROMPT_NAME_TO_REMOVE, "Test Prompt", List.of()); + McpServerFeatures.AsyncPromptRegistration registration = new McpServerFeatures.AsyncPromptRegistration(prompt, + req -> Mono.just(new GetPromptResult("Test prompt description", List + .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); + + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(true).build()) + .prompts(registration) + .build(); + + StepVerifier.create(mcpAsyncServer.removePrompt(TEST_PROMPT_NAME_TO_REMOVE)).verifyComplete(); + + assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentPrompt() { + var mcpAsyncServer2 = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(true).build()) + .build(); + + StepVerifier.create(mcpAsyncServer2.removePrompt("nonexistent-prompt")).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class) + .hasMessage("Prompt with name 'nonexistent-prompt' not found"); + }); + + assertThatCode(() -> mcpAsyncServer2.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + } + + // --------------------------------------- + // Roots Tests + // --------------------------------------- + + @Test + void testRootsChangeConsumers() { + // Test with single consumer + var rootsReceived = new McpSchema.Root[1]; + var consumerCalled = new boolean[1]; + + var singleConsumerServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .rootsChangeConsumers(List.of(roots -> Mono.fromRunnable(() -> { + consumerCalled[0] = true; + if (!roots.isEmpty()) { + rootsReceived[0] = roots.get(0); + } + }))) + .build(); + + assertThat(singleConsumerServer).isNotNull(); + assertThatCode(() -> singleConsumerServer.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + onClose(); + + // Test with multiple consumers + var consumer1Called = new boolean[1]; + var consumer2Called = new boolean[1]; + var rootsContent = new List[1]; + + var multipleConsumersServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .rootsChangeConsumers(List.of(roots -> Mono.fromRunnable(() -> { + consumer1Called[0] = true; + rootsContent[0] = roots; + }), roots -> Mono.fromRunnable(() -> consumer2Called[0] = true))) + .build(); + + assertThat(multipleConsumersServer).isNotNull(); + assertThatCode(() -> multipleConsumersServer.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + onClose(); + + // Test error handling + var errorHandlingServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .rootsChangeConsumers(List.of(roots -> { + throw new RuntimeException("Test error"); + })) + .build(); + + assertThat(errorHandlingServer).isNotNull(); + assertThatCode(() -> errorHandlingServer.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + onClose(); + + // Test without consumers + var noConsumersServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + assertThat(noConsumersServer).isNotNull(); + assertThatCode(() -> noConsumersServer.closeGracefully().block(Duration.ofSeconds(10))) + .doesNotThrowAnyException(); + } + + // --------------------------------------- + // Logging Tests + // --------------------------------------- + + @Test + void testLoggingLevels() { + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().logging().build()) + .build(); + + // Test all logging levels + for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { + var notification = McpSchema.LoggingMessageNotification.builder() + .level(level) + .logger("test-logger") + .data("Test message with level " + level) + .build(); + + StepVerifier.create(mcpAsyncServer.loggingNotification(notification)).verifyComplete(); + } + } + + @Test + void testLoggingWithoutCapability() { + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().build()) // No logging capability + .build(); + + var notification = McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.INFO) + .logger("test-logger") + .data("Test log message") + .build(); + + StepVerifier.create(mcpAsyncServer.loggingNotification(notification)).verifyComplete(); + } + + @Test + void testLoggingWithNullNotification() { + var mcpAsyncServer = McpServer.async(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().logging().build()) + .build(); + + StepVerifier.create(mcpAsyncServer.loggingNotification(null)).verifyError(McpError.class); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java index dcc103b5..4b4fc434 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java @@ -17,8 +17,7 @@ import io.modelcontextprotocol.spec.McpSchema.Resource; import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; import io.modelcontextprotocol.spec.McpSchema.Tool; -import io.modelcontextprotocol.spec.McpTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -31,11 +30,10 @@ /** * Test suite for the {@link McpAsyncServer} that can be used with different - * {@link McpTransport} implementations. + * {@link McpTransportProvider} implementations. * * @author Christian Tzolov */ -// KEEP IN SYNC with the class in mcp-test module public abstract class AbstractMcpAsyncServerTests { private static final String TEST_TOOL_NAME = "test-tool"; @@ -44,7 +42,7 @@ public abstract class AbstractMcpAsyncServerTests { private static final String TEST_PROMPT_NAME = "test-prompt"; - abstract protected ServerMcpTransport createMcpTransport(); + abstract protected McpServerTransportProvider createMcpTransportProvider(); protected void onStart() { } @@ -67,24 +65,26 @@ void tearDown() { @Test void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> McpServer.async(null)).isInstanceOf(IllegalArgumentException.class) - .hasMessage("Transport must not be null"); + assertThatThrownBy(() -> McpServer.async((McpServerTransportProvider) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Transport provider must not be null"); - assertThatThrownBy(() -> McpServer.async(createMcpTransport()).serverInfo((McpSchema.Implementation) null)) + assertThatThrownBy( + () -> McpServer.async(createMcpTransportProvider()).serverInfo((McpSchema.Implementation) null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Server info must not be null"); } @Test void testGracefulShutdown() { - var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); StepVerifier.create(mcpAsyncServer.closeGracefully()).verifyComplete(); } @Test void testImmediateClose() { - var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpAsyncServer.close()).doesNotThrowAnyException(); } @@ -103,13 +103,13 @@ void testImmediateClose() { @Test void testAddTool() { Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransport()) + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); - StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolRegistration(newTool, - args -> Mono.just(new CallToolResult(List.of(), false))))) + StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(newTool, + (excnage, args) -> Mono.just(new CallToolResult(List.of(), false))))) .verifyComplete(); assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); @@ -119,14 +119,15 @@ void testAddTool() { void testAddDuplicateTool() { Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransport()) + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(duplicateTool, args -> Mono.just(new CallToolResult(List.of(), false))) + .tool(duplicateTool, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) .build(); - StepVerifier.create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolRegistration(duplicateTool, - args -> Mono.just(new CallToolResult(List.of(), false))))) + StepVerifier + .create(mcpAsyncServer.addTool(new McpServerFeatures.AsyncToolSpecification(duplicateTool, + (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))))) .verifyErrorSatisfies(error -> { assertThat(error).isInstanceOf(McpError.class) .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); @@ -139,10 +140,10 @@ void testAddDuplicateTool() { void testRemoveTool() { Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransport()) + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(too, args -> Mono.just(new CallToolResult(List.of(), false))) + .tool(too, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) .build(); StepVerifier.create(mcpAsyncServer.removeTool(TEST_TOOL_NAME)).verifyComplete(); @@ -152,7 +153,7 @@ void testRemoveTool() { @Test void testRemoveNonexistentTool() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); @@ -168,10 +169,10 @@ void testRemoveNonexistentTool() { void testNotifyToolsListChanged() { Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransport()) + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(too, args -> Mono.just(new CallToolResult(List.of(), false))) + .tool(too, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) .build(); StepVerifier.create(mcpAsyncServer.notifyToolsListChanged()).verifyComplete(); @@ -185,7 +186,7 @@ void testNotifyToolsListChanged() { @Test void testNotifyResourcesListChanged() { - var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); StepVerifier.create(mcpAsyncServer.notifyResourcesListChanged()).verifyComplete(); @@ -194,29 +195,29 @@ void testNotifyResourcesListChanged() { @Test void testAddResource() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", null); - McpServerFeatures.AsyncResourceRegistration registration = new McpServerFeatures.AsyncResourceRegistration( - resource, req -> Mono.just(new ReadResourceResult(List.of()))); + McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification( + resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); - StepVerifier.create(mcpAsyncServer.addResource(registration)).verifyComplete(); + StepVerifier.create(mcpAsyncServer.addResource(specification)).verifyComplete(); assertThatCode(() -> mcpAsyncServer.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); } @Test - void testAddResourceWithNullRegistration() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) + void testAddResourceWithNullSpecification() { + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); - StepVerifier.create(mcpAsyncServer.addResource((McpServerFeatures.AsyncResourceRegistration) null)) + StepVerifier.create(mcpAsyncServer.addResource((McpServerFeatures.AsyncResourceSpecification) null)) .verifyErrorSatisfies(error -> { assertThat(error).isInstanceOf(McpError.class).hasMessage("Resource must not be null"); }); @@ -227,16 +228,16 @@ void testAddResourceWithNullRegistration() { @Test void testAddResourceWithoutCapability() { // Create a server without resource capabilities - McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransport()) + McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .build(); Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", null); - McpServerFeatures.AsyncResourceRegistration registration = new McpServerFeatures.AsyncResourceRegistration( - resource, req -> Mono.just(new ReadResourceResult(List.of()))); + McpServerFeatures.AsyncResourceSpecification specification = new McpServerFeatures.AsyncResourceSpecification( + resource, (exchange, req) -> Mono.just(new ReadResourceResult(List.of()))); - StepVerifier.create(serverWithoutResources.addResource(registration)).verifyErrorSatisfies(error -> { + StepVerifier.create(serverWithoutResources.addResource(specification)).verifyErrorSatisfies(error -> { assertThat(error).isInstanceOf(McpError.class) .hasMessage("Server must be configured with resource capabilities"); }); @@ -245,7 +246,7 @@ void testAddResourceWithoutCapability() { @Test void testRemoveResourceWithoutCapability() { // Create a server without resource capabilities - McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransport()) + McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .build(); @@ -261,7 +262,7 @@ void testRemoveResourceWithoutCapability() { @Test void testNotifyPromptsListChanged() { - var mcpAsyncServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); StepVerifier.create(mcpAsyncServer.notifyPromptsListChanged()).verifyComplete(); @@ -269,31 +270,31 @@ void testNotifyPromptsListChanged() { } @Test - void testAddPromptWithNullRegistration() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) + void testAddPromptWithNullSpecification() { + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(false).build()) .build(); - StepVerifier.create(mcpAsyncServer.addPrompt((McpServerFeatures.AsyncPromptRegistration) null)) + StepVerifier.create(mcpAsyncServer.addPrompt((McpServerFeatures.AsyncPromptSpecification) null)) .verifyErrorSatisfies(error -> { - assertThat(error).isInstanceOf(McpError.class).hasMessage("Prompt registration must not be null"); + assertThat(error).isInstanceOf(McpError.class).hasMessage("Prompt specification must not be null"); }); } @Test void testAddPromptWithoutCapability() { // Create a server without prompt capabilities - McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransport()) + McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .build(); Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); - McpServerFeatures.AsyncPromptRegistration registration = new McpServerFeatures.AsyncPromptRegistration(prompt, - req -> Mono.just(new GetPromptResult("Test prompt description", List + McpServerFeatures.AsyncPromptSpecification specification = new McpServerFeatures.AsyncPromptSpecification( + prompt, (exchange, req) -> Mono.just(new GetPromptResult("Test prompt description", List .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); - StepVerifier.create(serverWithoutPrompts.addPrompt(registration)).verifyErrorSatisfies(error -> { + StepVerifier.create(serverWithoutPrompts.addPrompt(specification)).verifyErrorSatisfies(error -> { assertThat(error).isInstanceOf(McpError.class) .hasMessage("Server must be configured with prompt capabilities"); }); @@ -302,7 +303,7 @@ void testAddPromptWithoutCapability() { @Test void testRemovePromptWithoutCapability() { // Create a server without prompt capabilities - McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransport()) + McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .build(); @@ -317,14 +318,14 @@ void testRemovePrompt() { String TEST_PROMPT_NAME_TO_REMOVE = "TEST_PROMPT_NAME678"; Prompt prompt = new Prompt(TEST_PROMPT_NAME_TO_REMOVE, "Test Prompt", List.of()); - McpServerFeatures.AsyncPromptRegistration registration = new McpServerFeatures.AsyncPromptRegistration(prompt, - req -> Mono.just(new GetPromptResult("Test prompt description", List + McpServerFeatures.AsyncPromptSpecification specification = new McpServerFeatures.AsyncPromptSpecification( + prompt, (exchange, req) -> Mono.just(new GetPromptResult("Test prompt description", List .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); - var mcpAsyncServer = McpServer.async(createMcpTransport()) + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) - .prompts(registration) + .prompts(specification) .build(); StepVerifier.create(mcpAsyncServer.removePrompt(TEST_PROMPT_NAME_TO_REMOVE)).verifyComplete(); @@ -334,7 +335,7 @@ void testRemovePrompt() { @Test void testRemoveNonexistentPrompt() { - var mcpAsyncServer2 = McpServer.async(createMcpTransport()) + var mcpAsyncServer2 = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) .build(); @@ -353,14 +354,14 @@ void testRemoveNonexistentPrompt() { // --------------------------------------- @Test - void testRootsChangeConsumers() { + void testRootsChangeHandlers() { // Test with single consumer var rootsReceived = new McpSchema.Root[1]; var consumerCalled = new boolean[1]; - var singleConsumerServer = McpServer.async(createMcpTransport()) + var singleConsumerServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> Mono.fromRunnable(() -> { + .rootsChangeHandlers(List.of((exchange, roots) -> Mono.fromRunnable(() -> { consumerCalled[0] = true; if (!roots.isEmpty()) { rootsReceived[0] = roots.get(0); @@ -378,12 +379,12 @@ void testRootsChangeConsumers() { var consumer2Called = new boolean[1]; var rootsContent = new List[1]; - var multipleConsumersServer = McpServer.async(createMcpTransport()) + var multipleConsumersServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> Mono.fromRunnable(() -> { + .rootsChangeHandlers(List.of((exchange, roots) -> Mono.fromRunnable(() -> { consumer1Called[0] = true; rootsContent[0] = roots; - }), roots -> Mono.fromRunnable(() -> consumer2Called[0] = true))) + }), (exchange, roots) -> Mono.fromRunnable(() -> consumer2Called[0] = true))) .build(); assertThat(multipleConsumersServer).isNotNull(); @@ -392,9 +393,9 @@ void testRootsChangeConsumers() { onClose(); // Test error handling - var errorHandlingServer = McpServer.async(createMcpTransport()) + var errorHandlingServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> { + .rootsChangeHandlers(List.of((exchange, roots) -> { throw new RuntimeException("Test error"); })) .build(); @@ -405,7 +406,9 @@ void testRootsChangeConsumers() { onClose(); // Test without consumers - var noConsumersServer = McpServer.async(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var noConsumersServer = McpServer.async(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .build(); assertThat(noConsumersServer).isNotNull(); assertThatCode(() -> noConsumersServer.closeGracefully().block(Duration.ofSeconds(10))) @@ -418,7 +421,7 @@ void testRootsChangeConsumers() { @Test void testLoggingLevels() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().logging().build()) .build(); @@ -437,7 +440,7 @@ void testLoggingLevels() { @Test void testLoggingWithoutCapability() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().build()) // No logging capability .build(); @@ -453,7 +456,7 @@ void testLoggingWithoutCapability() { @Test void testLoggingWithNullNotification() { - var mcpAsyncServer = McpServer.async(createMcpTransport()) + var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().logging().build()) .build(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerDeprecatedTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerDeprecatedTests.java new file mode 100644 index 00000000..16bc2d6e --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerDeprecatedTests.java @@ -0,0 +1,433 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.util.List; + +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptMessage; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.Resource; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpTransport; +import io.modelcontextprotocol.spec.ServerMcpTransport; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Test suite for the {@link McpSyncServer} that can be used with different + * {@link McpTransport} implementations. + * + * @author Christian Tzolov + */ +// KEEP IN SYNC with the class in mcp-test module +@Deprecated +public abstract class AbstractMcpSyncServerDeprecatedTests { + + private static final String TEST_TOOL_NAME = "test-tool"; + + private static final String TEST_RESOURCE_URI = "test://resource"; + + private static final String TEST_PROMPT_NAME = "test-prompt"; + + abstract protected ServerMcpTransport createMcpTransport(); + + protected void onStart() { + } + + protected void onClose() { + } + + @BeforeEach + void setUp() { + // onStart(); + } + + @AfterEach + void tearDown() { + onClose(); + } + + // --------------------------------------- + // Server Lifecycle Tests + // --------------------------------------- + + @Test + void testConstructorWithInvalidArguments() { + assertThatThrownBy(() -> McpServer.sync((ServerMcpTransport) null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Transport must not be null"); + + assertThatThrownBy(() -> McpServer.sync(createMcpTransport()).serverInfo(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Server info must not be null"); + } + + @Test + void testGracefulShutdown() { + var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testImmediateClose() { + var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + assertThatCode(() -> mcpSyncServer.close()).doesNotThrowAnyException(); + } + + @Test + void testGetAsyncServer() { + var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + assertThat(mcpSyncServer.getAsyncServer()).isNotNull(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Tools Tests + // --------------------------------------- + + String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + + @Test + void testAddTool() { + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); + assertThatCode(() -> mcpSyncServer + .addTool(new McpServerFeatures.SyncToolRegistration(newTool, args -> new CallToolResult(List.of(), false)))) + .doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testAddDuplicateTool() { + Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); + + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tool(duplicateTool, args -> new CallToolResult(List.of(), false)) + .build(); + + assertThatThrownBy(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolRegistration(duplicateTool, + args -> new CallToolResult(List.of(), false)))) + .isInstanceOf(McpError.class) + .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testRemoveTool() { + Tool tool = new McpSchema.Tool(TEST_TOOL_NAME, "Test tool", emptyJsonSchema); + + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tool(tool, args -> new CallToolResult(List.of(), false)) + .build(); + + assertThatCode(() -> mcpSyncServer.removeTool(TEST_TOOL_NAME)).doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentTool() { + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + assertThatThrownBy(() -> mcpSyncServer.removeTool("nonexistent-tool")).isInstanceOf(McpError.class) + .hasMessage("Tool with name 'nonexistent-tool' not found"); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testNotifyToolsListChanged() { + var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + assertThatCode(() -> mcpSyncServer.notifyToolsListChanged()).doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Resources Tests + // --------------------------------------- + + @Test + void testNotifyResourcesListChanged() { + var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + assertThatCode(() -> mcpSyncServer.notifyResourcesListChanged()).doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testAddResource() { + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", + null); + McpServerFeatures.SyncResourceRegistration registration = new McpServerFeatures.SyncResourceRegistration( + resource, req -> new ReadResourceResult(List.of())); + + assertThatCode(() -> mcpSyncServer.addResource(registration)).doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testAddResourceWithNullRegistration() { + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().resources(true, false).build()) + .build(); + + assertThatThrownBy(() -> mcpSyncServer.addResource((McpServerFeatures.SyncResourceRegistration) null)) + .isInstanceOf(McpError.class) + .hasMessage("Resource must not be null"); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testAddResourceWithoutCapability() { + var serverWithoutResources = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", + null); + McpServerFeatures.SyncResourceRegistration registration = new McpServerFeatures.SyncResourceRegistration( + resource, req -> new ReadResourceResult(List.of())); + + assertThatThrownBy(() -> serverWithoutResources.addResource(registration)).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with resource capabilities"); + } + + @Test + void testRemoveResourceWithoutCapability() { + var serverWithoutResources = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + assertThatThrownBy(() -> serverWithoutResources.removeResource(TEST_RESOURCE_URI)).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with resource capabilities"); + } + + // --------------------------------------- + // Prompts Tests + // --------------------------------------- + + @Test + void testNotifyPromptsListChanged() { + var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + assertThatCode(() -> mcpSyncServer.notifyPromptsListChanged()).doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testAddPromptWithNullRegistration() { + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(false).build()) + .build(); + + assertThatThrownBy(() -> mcpSyncServer.addPrompt((McpServerFeatures.SyncPromptRegistration) null)) + .isInstanceOf(McpError.class) + .hasMessage("Prompt registration must not be null"); + } + + @Test + void testAddPromptWithoutCapability() { + var serverWithoutPrompts = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); + McpServerFeatures.SyncPromptRegistration registration = new McpServerFeatures.SyncPromptRegistration(prompt, + req -> new GetPromptResult("Test prompt description", List + .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); + + assertThatThrownBy(() -> serverWithoutPrompts.addPrompt(registration)).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with prompt capabilities"); + } + + @Test + void testRemovePromptWithoutCapability() { + var serverWithoutPrompts = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + assertThatThrownBy(() -> serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).isInstanceOf(McpError.class) + .hasMessage("Server must be configured with prompt capabilities"); + } + + @Test + void testRemovePrompt() { + Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); + McpServerFeatures.SyncPromptRegistration registration = new McpServerFeatures.SyncPromptRegistration(prompt, + req -> new GetPromptResult("Test prompt description", List + .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); + + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(true).build()) + .prompts(registration) + .build(); + + assertThatCode(() -> mcpSyncServer.removePrompt(TEST_PROMPT_NAME)).doesNotThrowAnyException(); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + @Test + void testRemoveNonexistentPrompt() { + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().prompts(true).build()) + .build(); + + assertThatThrownBy(() -> mcpSyncServer.removePrompt("nonexistent-prompt")).isInstanceOf(McpError.class) + .hasMessage("Prompt with name 'nonexistent-prompt' not found"); + + assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Roots Tests + // --------------------------------------- + + @Test + void testRootsChangeConsumers() { + // Test with single consumer + var rootsReceived = new McpSchema.Root[1]; + var consumerCalled = new boolean[1]; + + var singleConsumerServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .rootsChangeConsumers(List.of(roots -> { + consumerCalled[0] = true; + if (!roots.isEmpty()) { + rootsReceived[0] = roots.get(0); + } + })) + .build(); + + assertThat(singleConsumerServer).isNotNull(); + assertThatCode(() -> singleConsumerServer.closeGracefully()).doesNotThrowAnyException(); + onClose(); + + // Test with multiple consumers + var consumer1Called = new boolean[1]; + var consumer2Called = new boolean[1]; + var rootsContent = new List[1]; + + var multipleConsumersServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .rootsChangeConsumers(List.of(roots -> { + consumer1Called[0] = true; + rootsContent[0] = roots; + }, roots -> consumer2Called[0] = true)) + .build(); + + assertThat(multipleConsumersServer).isNotNull(); + assertThatCode(() -> multipleConsumersServer.closeGracefully()).doesNotThrowAnyException(); + onClose(); + + // Test error handling + var errorHandlingServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .rootsChangeConsumers(List.of(roots -> { + throw new RuntimeException("Test error"); + })) + .build(); + + assertThat(errorHandlingServer).isNotNull(); + assertThatCode(() -> errorHandlingServer.closeGracefully()).doesNotThrowAnyException(); + onClose(); + + // Test without consumers + var noConsumersServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + + assertThat(noConsumersServer).isNotNull(); + assertThatCode(() -> noConsumersServer.closeGracefully()).doesNotThrowAnyException(); + } + + // --------------------------------------- + // Logging Tests + // --------------------------------------- + + @Test + void testLoggingLevels() { + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().logging().build()) + .build(); + + // Test all logging levels + for (McpSchema.LoggingLevel level : McpSchema.LoggingLevel.values()) { + var notification = McpSchema.LoggingMessageNotification.builder() + .level(level) + .logger("test-logger") + .data("Test message with level " + level) + .build(); + + assertThatCode(() -> mcpSyncServer.loggingNotification(notification)).doesNotThrowAnyException(); + } + } + + @Test + void testLoggingWithoutCapability() { + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().build()) // No logging capability + .build(); + + var notification = McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.INFO) + .logger("test-logger") + .data("Test log message") + .build(); + + assertThatCode(() -> mcpSyncServer.loggingNotification(notification)).doesNotThrowAnyException(); + } + + @Test + void testLoggingWithNullNotification() { + var mcpSyncServer = McpServer.sync(createMcpTransport()) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().logging().build()) + .build(); + + assertThatThrownBy(() -> mcpSyncServer.loggingNotification(null)).isInstanceOf(McpError.class); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java index bdcd7ae3..17feb36e 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java @@ -16,8 +16,7 @@ import io.modelcontextprotocol.spec.McpSchema.Resource; import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; import io.modelcontextprotocol.spec.McpSchema.Tool; -import io.modelcontextprotocol.spec.McpTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -28,11 +27,10 @@ /** * Test suite for the {@link McpSyncServer} that can be used with different - * {@link McpTransport} implementations. + * {@link McpTransportProvider} implementations. * * @author Christian Tzolov */ -// KEEP IN SYNC with the class in mcp-test module public abstract class AbstractMcpSyncServerTests { private static final String TEST_TOOL_NAME = "test-tool"; @@ -41,7 +39,7 @@ public abstract class AbstractMcpSyncServerTests { private static final String TEST_PROMPT_NAME = "test-prompt"; - abstract protected ServerMcpTransport createMcpTransport(); + abstract protected McpServerTransportProvider createMcpTransportProvider(); protected void onStart() { } @@ -65,31 +63,32 @@ void tearDown() { @Test void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> McpServer.sync(null)).isInstanceOf(IllegalArgumentException.class) - .hasMessage("Transport must not be null"); + assertThatThrownBy(() -> McpServer.sync((McpServerTransportProvider) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Transport provider must not be null"); - assertThatThrownBy(() -> McpServer.sync(createMcpTransport()).serverInfo(null)) + assertThatThrownBy(() -> McpServer.sync(createMcpTransportProvider()).serverInfo(null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Server info must not be null"); } @Test void testGracefulShutdown() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); } @Test void testImmediateClose() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpSyncServer.close()).doesNotThrowAnyException(); } @Test void testGetAsyncServer() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); assertThat(mcpSyncServer.getAsyncServer()).isNotNull(); @@ -110,14 +109,14 @@ void testGetAsyncServer() { @Test void testAddTool() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); - assertThatCode(() -> mcpSyncServer - .addTool(new McpServerFeatures.SyncToolRegistration(newTool, args -> new CallToolResult(List.of(), false)))) + assertThatCode(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(newTool, + (exchange, args) -> new CallToolResult(List.of(), false)))) .doesNotThrowAnyException(); assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); @@ -127,14 +126,14 @@ void testAddTool() { void testAddDuplicateTool() { Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - var mcpSyncServer = McpServer.sync(createMcpTransport()) + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(duplicateTool, args -> new CallToolResult(List.of(), false)) + .tool(duplicateTool, (exchange, args) -> new CallToolResult(List.of(), false)) .build(); - assertThatThrownBy(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolRegistration(duplicateTool, - args -> new CallToolResult(List.of(), false)))) + assertThatThrownBy(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(duplicateTool, + (exchange, args) -> new CallToolResult(List.of(), false)))) .isInstanceOf(McpError.class) .hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists"); @@ -145,10 +144,10 @@ void testAddDuplicateTool() { void testRemoveTool() { Tool tool = new McpSchema.Tool(TEST_TOOL_NAME, "Test tool", emptyJsonSchema); - var mcpSyncServer = McpServer.sync(createMcpTransport()) + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) - .tool(tool, args -> new CallToolResult(List.of(), false)) + .tool(tool, (exchange, args) -> new CallToolResult(List.of(), false)) .build(); assertThatCode(() -> mcpSyncServer.removeTool(TEST_TOOL_NAME)).doesNotThrowAnyException(); @@ -158,7 +157,7 @@ void testRemoveTool() { @Test void testRemoveNonexistentTool() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); @@ -171,7 +170,7 @@ void testRemoveNonexistentTool() { @Test void testNotifyToolsListChanged() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpSyncServer.notifyToolsListChanged()).doesNotThrowAnyException(); @@ -184,7 +183,7 @@ void testNotifyToolsListChanged() { @Test void testNotifyResourcesListChanged() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpSyncServer.notifyResourcesListChanged()).doesNotThrowAnyException(); @@ -193,29 +192,29 @@ void testNotifyResourcesListChanged() { @Test void testAddResource() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", null); - McpServerFeatures.SyncResourceRegistration registration = new McpServerFeatures.SyncResourceRegistration( - resource, req -> new ReadResourceResult(List.of())); + McpServerFeatures.SyncResourceSpecification specificaiton = new McpServerFeatures.SyncResourceSpecification( + resource, (exchange, req) -> new ReadResourceResult(List.of())); - assertThatCode(() -> mcpSyncServer.addResource(registration)).doesNotThrowAnyException(); + assertThatCode(() -> mcpSyncServer.addResource(specificaiton)).doesNotThrowAnyException(); assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); } @Test - void testAddResourceWithNullRegistration() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) + void testAddResourceWithNullSpecifiation() { + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); - assertThatThrownBy(() -> mcpSyncServer.addResource((McpServerFeatures.SyncResourceRegistration) null)) + assertThatThrownBy(() -> mcpSyncServer.addResource((McpServerFeatures.SyncResourceSpecification) null)) .isInstanceOf(McpError.class) .hasMessage("Resource must not be null"); @@ -224,20 +223,24 @@ void testAddResourceWithNullRegistration() { @Test void testAddResourceWithoutCapability() { - var serverWithoutResources = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var serverWithoutResources = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .build(); Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", null); - McpServerFeatures.SyncResourceRegistration registration = new McpServerFeatures.SyncResourceRegistration( - resource, req -> new ReadResourceResult(List.of())); + McpServerFeatures.SyncResourceSpecification specification = new McpServerFeatures.SyncResourceSpecification( + resource, (exchange, req) -> new ReadResourceResult(List.of())); - assertThatThrownBy(() -> serverWithoutResources.addResource(registration)).isInstanceOf(McpError.class) + assertThatThrownBy(() -> serverWithoutResources.addResource(specification)).isInstanceOf(McpError.class) .hasMessage("Server must be configured with resource capabilities"); } @Test void testRemoveResourceWithoutCapability() { - var serverWithoutResources = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var serverWithoutResources = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .build(); assertThatThrownBy(() -> serverWithoutResources.removeResource(TEST_RESOURCE_URI)).isInstanceOf(McpError.class) .hasMessage("Server must be configured with resource capabilities"); @@ -249,7 +252,7 @@ void testRemoveResourceWithoutCapability() { @Test void testNotifyPromptsListChanged() { - var mcpSyncServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpSyncServer.notifyPromptsListChanged()).doesNotThrowAnyException(); @@ -257,33 +260,37 @@ void testNotifyPromptsListChanged() { } @Test - void testAddPromptWithNullRegistration() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) + void testAddPromptWithNullSpecification() { + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(false).build()) .build(); - assertThatThrownBy(() -> mcpSyncServer.addPrompt((McpServerFeatures.SyncPromptRegistration) null)) + assertThatThrownBy(() -> mcpSyncServer.addPrompt((McpServerFeatures.SyncPromptSpecification) null)) .isInstanceOf(McpError.class) - .hasMessage("Prompt registration must not be null"); + .hasMessage("Prompt specification must not be null"); } @Test void testAddPromptWithoutCapability() { - var serverWithoutPrompts = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var serverWithoutPrompts = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .build(); Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); - McpServerFeatures.SyncPromptRegistration registration = new McpServerFeatures.SyncPromptRegistration(prompt, - req -> new GetPromptResult("Test prompt description", List + McpServerFeatures.SyncPromptSpecification specificaiton = new McpServerFeatures.SyncPromptSpecification(prompt, + (exchange, req) -> new GetPromptResult("Test prompt description", List .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); - assertThatThrownBy(() -> serverWithoutPrompts.addPrompt(registration)).isInstanceOf(McpError.class) + assertThatThrownBy(() -> serverWithoutPrompts.addPrompt(specificaiton)).isInstanceOf(McpError.class) .hasMessage("Server must be configured with prompt capabilities"); } @Test void testRemovePromptWithoutCapability() { - var serverWithoutPrompts = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var serverWithoutPrompts = McpServer.sync(createMcpTransportProvider()) + .serverInfo("test-server", "1.0.0") + .build(); assertThatThrownBy(() -> serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).isInstanceOf(McpError.class) .hasMessage("Server must be configured with prompt capabilities"); @@ -292,14 +299,14 @@ void testRemovePromptWithoutCapability() { @Test void testRemovePrompt() { Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", List.of()); - McpServerFeatures.SyncPromptRegistration registration = new McpServerFeatures.SyncPromptRegistration(prompt, - req -> new GetPromptResult("Test prompt description", List + McpServerFeatures.SyncPromptSpecification specificaiton = new McpServerFeatures.SyncPromptSpecification(prompt, + (exchange, req) -> new GetPromptResult("Test prompt description", List .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); - var mcpSyncServer = McpServer.sync(createMcpTransport()) + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) - .prompts(registration) + .prompts(specificaiton) .build(); assertThatCode(() -> mcpSyncServer.removePrompt(TEST_PROMPT_NAME)).doesNotThrowAnyException(); @@ -309,7 +316,7 @@ void testRemovePrompt() { @Test void testRemoveNonexistentPrompt() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) .build(); @@ -325,14 +332,14 @@ void testRemoveNonexistentPrompt() { // --------------------------------------- @Test - void testRootsChangeConsumers() { + void testRootsChangeHandlers() { // Test with single consumer var rootsReceived = new McpSchema.Root[1]; var consumerCalled = new boolean[1]; - var singleConsumerServer = McpServer.sync(createMcpTransport()) + var singleConsumerServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> { + .rootsChangeHandlers(List.of((exchage, roots) -> { consumerCalled[0] = true; if (!roots.isEmpty()) { rootsReceived[0] = roots.get(0); @@ -349,12 +356,12 @@ void testRootsChangeConsumers() { var consumer2Called = new boolean[1]; var rootsContent = new List[1]; - var multipleConsumersServer = McpServer.sync(createMcpTransport()) + var multipleConsumersServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> { + .rootsChangeHandlers(List.of((exchange, roots) -> { consumer1Called[0] = true; rootsContent[0] = roots; - }, roots -> consumer2Called[0] = true)) + }, (exchange, roots) -> consumer2Called[0] = true)) .build(); assertThat(multipleConsumersServer).isNotNull(); @@ -362,9 +369,9 @@ void testRootsChangeConsumers() { onClose(); // Test error handling - var errorHandlingServer = McpServer.sync(createMcpTransport()) + var errorHandlingServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") - .rootsChangeConsumers(List.of(roots -> { + .rootsChangeHandlers(List.of((exchange, roots) -> { throw new RuntimeException("Test error"); })) .build(); @@ -374,7 +381,7 @@ void testRootsChangeConsumers() { onClose(); // Test without consumers - var noConsumersServer = McpServer.sync(createMcpTransport()).serverInfo("test-server", "1.0.0").build(); + var noConsumersServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); assertThat(noConsumersServer).isNotNull(); assertThatCode(() -> noConsumersServer.closeGracefully()).doesNotThrowAnyException(); @@ -386,7 +393,7 @@ void testRootsChangeConsumers() { @Test void testLoggingLevels() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().logging().build()) .build(); @@ -405,7 +412,7 @@ void testLoggingLevels() { @Test void testLoggingWithoutCapability() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().build()) // No logging capability .build(); @@ -421,7 +428,7 @@ void testLoggingWithoutCapability() { @Test void testLoggingWithNullNotification() { - var mcpSyncServer = McpServer.sync(createMcpTransport()) + var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) .serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().logging().build()) .build(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/BaseMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/BaseMcpAsyncServerTests.java new file mode 100644 index 00000000..208bcb71 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/BaseMcpAsyncServerTests.java @@ -0,0 +1,5 @@ +package io.modelcontextprotocol.server; + +public abstract class BaseMcpAsyncServerTests { + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerDeprecatedTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerDeprecatedTests.java new file mode 100644 index 00000000..2c80d45c --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerDeprecatedTests.java @@ -0,0 +1,26 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.transport.HttpServletSseServerTransport; +import io.modelcontextprotocol.spec.ServerMcpTransport; +import org.junit.jupiter.api.Timeout; + +/** + * Tests for {@link McpAsyncServer} using {@link HttpServletSseServerTransport}. + * + * @author Christian Tzolov + */ +@Deprecated +@Timeout(15) // Giving extra time beyond the client timeout +class ServletSseMcpAsyncServerDeprecatedTests extends AbstractMcpAsyncServerDeprecatedTests { + + @Override + protected ServerMcpTransport createMcpTransport() { + return new HttpServletSseServerTransport(new ObjectMapper(), "/mcp/message"); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java index 715f636d..9de186b4 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java @@ -5,12 +5,12 @@ package io.modelcontextprotocol.server; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.server.transport.HttpServletSseServerTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.server.transport.HttpServletSseServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.Timeout; /** - * Tests for {@link McpAsyncServer} using {@link HttpServletSseServerTransport}. + * Tests for {@link McpAsyncServer} using {@link HttpServletSseServerTransportProvider}. * * @author Christian Tzolov */ @@ -18,8 +18,8 @@ class ServletSseMcpAsyncServerTests extends AbstractMcpAsyncServerTests { @Override - protected ServerMcpTransport createMcpTransport() { - return new HttpServletSseServerTransport(new ObjectMapper(), "/mcp/message"); + protected McpServerTransportProvider createMcpTransportProvider() { + return new HttpServletSseServerTransportProvider(new ObjectMapper(), "/mcp/message"); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerDeprecatedTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerDeprecatedTests.java new file mode 100644 index 00000000..8cdd08c5 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerDeprecatedTests.java @@ -0,0 +1,26 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.transport.HttpServletSseServerTransport; +import io.modelcontextprotocol.spec.ServerMcpTransport; +import org.junit.jupiter.api.Timeout; + +/** + * Tests for {@link McpSyncServer} using {@link HttpServletSseServerTransport}. + * + * @author Christian Tzolov + */ +@Deprecated +@Timeout(15) // Giving extra time beyond the client timeout +class ServletSseMcpSyncServerDeprecatedTests extends AbstractMcpSyncServerDeprecatedTests { + + @Override + protected ServerMcpTransport createMcpTransport() { + return new HttpServletSseServerTransport(new ObjectMapper(), "/mcp/message"); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java index 208de7f7..60dc53a4 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java @@ -5,12 +5,12 @@ package io.modelcontextprotocol.server; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.server.transport.HttpServletSseServerTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.server.transport.HttpServletSseServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.Timeout; /** - * Tests for {@link McpSyncServer} using {@link HttpServletSseServerTransport}. + * Tests for {@link McpSyncServer} using {@link HttpServletSseServerTransportProvider}. * * @author Christian Tzolov */ @@ -18,8 +18,8 @@ class ServletSseMcpSyncServerTests extends AbstractMcpSyncServerTests { @Override - protected ServerMcpTransport createMcpTransport() { - return new HttpServletSseServerTransport(new ObjectMapper(), "/mcp/message"); + protected McpServerTransportProvider createMcpTransportProvider() { + return new HttpServletSseServerTransportProvider(new ObjectMapper(), "/mcp/message"); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerDeprecatedTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerDeprecatedTests.java new file mode 100644 index 00000000..db95db07 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerDeprecatedTests.java @@ -0,0 +1,25 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.server.transport.StdioServerTransport; +import io.modelcontextprotocol.spec.ServerMcpTransport; +import org.junit.jupiter.api.Timeout; + +/** + * Tests for {@link McpAsyncServer} using {@link StdioServerTransport}. + * + * @author Christian Tzolov + */ +@Deprecated +@Timeout(15) // Giving extra time beyond the client timeout +class StdioMcpAsyncServerDeprecatedTests extends AbstractMcpAsyncServerDeprecatedTests { + + @Override + protected ServerMcpTransport createMcpTransport() { + return new StdioServerTransport(); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java index e933d638..27ff53c9 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java @@ -5,7 +5,8 @@ package io.modelcontextprotocol.server; import io.modelcontextprotocol.server.transport.StdioServerTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.server.transport.StdioServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.Timeout; /** @@ -17,8 +18,8 @@ class StdioMcpAsyncServerTests extends AbstractMcpAsyncServerTests { @Override - protected ServerMcpTransport createMcpTransport() { - return new StdioServerTransport(); + protected McpServerTransportProvider createMcpTransportProvider() { + return new StdioServerTransportProvider(); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerDeprecatedTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerDeprecatedTests.java new file mode 100644 index 00000000..149f7281 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerDeprecatedTests.java @@ -0,0 +1,25 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.server.transport.StdioServerTransport; +import io.modelcontextprotocol.spec.ServerMcpTransport; +import org.junit.jupiter.api.Timeout; + +/** + * Tests for {@link McpSyncServer} using {@link StdioServerTransport}. + * + * @author Christian Tzolov + */ +@Deprecated +@Timeout(15) // Giving extra time beyond the client timeout +class StdioMcpSyncServerDeprecatedTests extends AbstractMcpSyncServerDeprecatedTests { + + @Override + protected ServerMcpTransport createMcpTransport() { + return new StdioServerTransport(); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerTests.java index d9350417..a71c3849 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerTests.java @@ -4,12 +4,12 @@ package io.modelcontextprotocol.server; -import io.modelcontextprotocol.server.transport.StdioServerTransport; -import io.modelcontextprotocol.spec.ServerMcpTransport; +import io.modelcontextprotocol.server.transport.StdioServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; import org.junit.jupiter.api.Timeout; /** - * Tests for {@link McpSyncServer} using {@link StdioServerTransport}. + * Tests for {@link McpSyncServer} using {@link StdioServerTransportProvider}. * * @author Christian Tzolov */ @@ -17,8 +17,8 @@ class StdioMcpSyncServerTests extends AbstractMcpSyncServerTests { @Override - protected ServerMcpTransport createMcpTransport() { - return new StdioServerTransport(); + protected McpServerTransportProvider createMcpTransportProvider() { + return new StdioServerTransportProvider(); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/BlockingInputStream.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/BlockingInputStream.java deleted file mode 100644 index 0ab72a99..00000000 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/BlockingInputStream.java +++ /dev/null @@ -1,69 +0,0 @@ -/* -* Copyright 2024 - 2024 the original author or authors. -*/ -package io.modelcontextprotocol.server.transport; - -import java.io.IOException; -import java.io.InputStream; -import java.util.concurrent.BlockingQueue; -import java.util.concurrent.LinkedBlockingQueue; - -public class BlockingInputStream extends InputStream { - - private final BlockingQueue queue = new LinkedBlockingQueue<>(); - - private volatile boolean completed = false; - - private volatile boolean closed = false; - - @Override - public int read() throws IOException { - if (closed) { - throw new IOException("Stream is closed"); - } - - try { - Integer value = queue.poll(); - if (value == null) { - if (completed) { - return -1; - } - value = queue.take(); // Blocks until data is available - if (value == null && completed) { - return -1; - } - } - return value; - } - catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new IOException("Read interrupted", e); - } - } - - public void write(int b) { - if (!closed && !completed) { - queue.offer(b); - } - } - - public void write(byte[] data) { - if (!closed && !completed) { - for (byte b : data) { - queue.offer((int) b & 0xFF); - } - } - } - - public void complete() { - this.completed = true; - } - - @Override - public void close() { - this.closed = true; - this.completed = true; - this.queue.clear(); - } - -} \ No newline at end of file diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java new file mode 100644 index 00000000..290141bb --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java @@ -0,0 +1,493 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ +package io.modelcontextprotocol.server.transport; + +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.Role; +import io.modelcontextprotocol.spec.McpSchema.Root; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import org.apache.catalina.Context; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import org.springframework.web.client.RestClient; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.awaitility.Awaitility.await; +import static org.mockito.Mockito.mock; + +public class HttpServletSseServerTransportProviderIntegrationTests { + + private static final int PORT = 8185; + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private HttpServletSseServerTransportProvider mcpServerTransportProvider; + + McpClient.SyncSpec clientBuilder; + + private Tomcat tomcat; + + @BeforeEach + public void before() { + tomcat = new Tomcat(); + tomcat.setPort(PORT); + + String baseDir = System.getProperty("java.io.tmpdir"); + tomcat.setBaseDir(baseDir); + + Context context = tomcat.addContext("", baseDir); + + // Create and configure the transport provider + mcpServerTransportProvider = new HttpServletSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT); + + // Add transport servlet to Tomcat + org.apache.catalina.Wrapper wrapper = context.createWrapper(); + wrapper.setName("mcpServlet"); + wrapper.setServlet(mcpServerTransportProvider); + wrapper.setLoadOnStartup(1); + wrapper.setAsyncSupported(true); + context.addChild(wrapper); + context.addServletMappingDecoded("/*", "mcpServlet"); + + try { + var connector = tomcat.getConnector(); + connector.setAsyncTimeout(3000); + tomcat.start(); + assertThat(tomcat.getServer().getState() == LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + this.clientBuilder = McpClient.sync(new HttpClientSseClientTransport("http://localhost:" + PORT)); + } + + @AfterEach + public void after() { + if (mcpServerTransportProvider != null) { + mcpServerTransportProvider.closeGracefully().block(); + } + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + // --------------------------------------- + // Sampling Tests + // --------------------------------------- + @Test + @Disabled + void testCreateMessageWithoutSamplingCapabilities() { + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + exchange.createMessage(mock(McpSchema.CreateMessageRequest.class)).block(); + + return Mono.just(mock(CallToolResult.class)); + }); + + McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); + + // Create client without sampling capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")).build(); + + assertThat(client.initialize()).isNotNull(); + + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with sampling capabilities"); + } + } + + @Test + void testCreateMessageSuccess() throws InterruptedException { + + // Client + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var messages = List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message"))); + var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); + + var craeteMessageRequest = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, + McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), + Map.of()); + + StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.close(); + mcpServer.close(); + } + + // --------------------------------------- + // Roots Tests + // --------------------------------------- + @Test + void testRootsSuccess() { + List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); + + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) + .build(); + + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(rootsRef.get()).isNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + + // Remove a root + mcpClient.removeRoot(roots.get(0).uri()); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); + }); + + // Add a new root + var root3 = new Root("uri3://", "root3"); + mcpClient.addRoot(root3); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); + }); + + mcpClient.close(); + mcpServer.close(); + } + + @Test + void testRootsWithoutCapability() { + + McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + exchange.listRoots(); // try to list roots + + return mock(CallToolResult.class); + }); + + var mcpServer = McpServer.sync(mcpServerTransportProvider).rootsChangeHandler((exchange, rootsUpdate) -> { + }).tools(tool).build(); + + // Create client without roots capability + // No roots capability + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build(); + + assertThat(mcpClient.initialize()).isNotNull(); + + // Attempt to list roots should fail + try { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); + } + + mcpClient.close(); + mcpServer.close(); + } + + @Test + void testRootsNotifciationWithEmptyRootsList() { + AtomicReference> rootsRef = new AtomicReference<>(); + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) + .build(); + + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(List.of()) // Empty roots list + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); + + mcpClient.close(); + mcpServer.close(); + } + + @Test + void testRootsWithMultipleHandlers() { + List roots = List.of(new Root("uri1://", "root1")); + + AtomicReference> rootsRef1 = new AtomicReference<>(); + AtomicReference> rootsRef2 = new AtomicReference<>(); + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef1.set(rootsUpdate)) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef2.set(rootsUpdate)) + .build(); + + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build(); + + assertThat(mcpClient.initialize()).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef1.get()).containsAll(roots); + assertThat(rootsRef2.get()).containsAll(roots); + }); + + mcpClient.close(); + mcpServer.close(); + } + + @Test + void testRootsServerCloseWithActiveSubscription() { + List roots = List.of(new Root("uri1://", "root1")); + + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) + .build(); + + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + + // Close server while subscription is active + mcpServer.close(); + + // Verify client can handle server closure gracefully + mcpClient.close(); + } + + // --------------------------------------- + // Tools Tests + // --------------------------------------- + + String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + + @Test + void testToolCallSuccess() { + + var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + // perform a blocking call to a remote service + String response = RestClient.create() + .get() + .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .retrieve() + .body(String.class); + assertThat(response).isNotBlank(); + return callResponse; + }); + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + var mcpClient = clientBuilder.build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.close(); + mcpServer.close(); + } + + @Test + void testToolListChangeHandlingSuccess() { + + var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + // perform a blocking call to a remote service + String response = RestClient.create() + .get() + .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .retrieve() + .body(String.class); + assertThat(response).isNotBlank(); + return callResponse; + }); + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + AtomicReference> rootsRef = new AtomicReference<>(); + var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { + // perform a blocking call to a remote service + String response = RestClient.create() + .get() + .uri("https://github.com/modelcontextprotocol/specification/blob/main/README.md") + .retrieve() + .body(String.class); + assertThat(response).isNotBlank(); + rootsRef.set(toolsUpdate); + }).build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(rootsRef.get()).isNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + mcpServer.notifyToolsListChanged(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(tool1.tool())); + }); + + // Remove a tool + mcpServer.removeTool("tool1"); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); + + // Add a new tool + McpServerFeatures.SyncToolSpecification tool2 = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema), (exchange, request) -> callResponse); + + mcpServer.addTool(tool2); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(tool2.tool())); + }); + + mcpClient.close(); + mcpServer.close(); + } + + @Test + void testInitialize() { + + var mcpServer = McpServer.sync(mcpServerTransportProvider).build(); + + var mcpClient = clientBuilder.build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.close(); + mcpServer.close(); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java new file mode 100644 index 00000000..14987b5a --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java @@ -0,0 +1,227 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.InputStream; +import java.io.PrintStream; +import java.nio.charset.StandardCharsets; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpServerTransport; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link StdioServerTransportProvider}. + * + * @author Christian Tzolov + */ +@Disabled +class StdioServerTransportProviderTests { + + private final PrintStream originalOut = System.out; + + private final PrintStream originalErr = System.err; + + private ByteArrayOutputStream testErr; + + private PrintStream testOutPrintStream; + + private StdioServerTransportProvider transportProvider; + + private ObjectMapper objectMapper; + + private McpServerSession.Factory sessionFactory; + + private McpServerSession mockSession; + + @BeforeEach + void setUp() { + testErr = new ByteArrayOutputStream(); + + testOutPrintStream = new PrintStream(testErr, true); + System.setOut(testOutPrintStream); + System.setErr(testOutPrintStream); + + objectMapper = new ObjectMapper(); + + // Create mocks for session factory and session + mockSession = mock(McpServerSession.class); + sessionFactory = mock(McpServerSession.Factory.class); + + // Configure mock behavior + when(sessionFactory.create(any(McpServerTransport.class))).thenReturn(mockSession); + when(mockSession.closeGracefully()).thenReturn(Mono.empty()); + when(mockSession.sendNotification(any(), any())).thenReturn(Mono.empty()); + + transportProvider = new StdioServerTransportProvider(objectMapper, System.in, testOutPrintStream); + } + + @AfterEach + void tearDown() { + if (transportProvider != null) { + transportProvider.closeGracefully().block(); + } + if (testOutPrintStream != null) { + testOutPrintStream.close(); + } + System.setOut(originalOut); + System.setErr(originalErr); + } + + @Test + void shouldCreateSessionWhenSessionFactoryIsSet() { + // Set session factory + transportProvider.setSessionFactory(sessionFactory); + + // Verify session was created with a transport + assertThat(testErr.toString()).doesNotContain("Error"); + } + + @Test + void shouldHandleIncomingMessages() throws Exception { + + String jsonMessage = "{\"jsonrpc\":\"2.0\",\"method\":\"test\",\"params\":{},\"id\":1}\n"; + InputStream stream = new ByteArrayInputStream(jsonMessage.getBytes(StandardCharsets.UTF_8)); + + transportProvider = new StdioServerTransportProvider(objectMapper, stream, System.out); + // Set up a real session to capture the message + AtomicReference capturedMessage = new AtomicReference<>(); + CountDownLatch messageLatch = new CountDownLatch(1); + + McpServerSession.Factory realSessionFactory = transport -> { + McpServerSession session = mock(McpServerSession.class); + when(session.handle(any())).thenAnswer(invocation -> { + capturedMessage.set(invocation.getArgument(0)); + messageLatch.countDown(); + return Mono.empty(); + }); + when(session.closeGracefully()).thenReturn(Mono.empty()); + return session; + }; + + // Set session factory + transportProvider.setSessionFactory(realSessionFactory); + + // Wait for the message to be processed using the latch + StepVerifier.create(Mono.fromCallable(() -> messageLatch.await(100, TimeUnit.SECONDS)).flatMap(success -> { + if (!success) { + return Mono.error(new AssertionError("Timeout waiting for message processing")); + } + return Mono.just(capturedMessage.get()); + })).assertNext(message -> { + assertThat(message).isNotNull(); + assertThat(message).isInstanceOf(McpSchema.JSONRPCRequest.class); + McpSchema.JSONRPCRequest request = (McpSchema.JSONRPCRequest) message; + assertThat(request.method()).isEqualTo("test"); + assertThat(request.id()).isEqualTo(1); + }).verifyComplete(); + } + + @Test + void shouldNotifyClients() { + // Set session factory + transportProvider.setSessionFactory(sessionFactory); + + // Send notification + String method = "testNotification"; + Map params = Map.of("key", "value"); + + StepVerifier.create(transportProvider.notifyClients(method, params)).verifyComplete(); + + // Error log should be empty + assertThat(testErr.toString()).doesNotContain("Error"); + } + + @Test + void shouldCloseGracefully() { + // Set session factory + transportProvider.setSessionFactory(sessionFactory); + + // Close gracefully + StepVerifier.create(transportProvider.closeGracefully()).verifyComplete(); + + // Error log should be empty + assertThat(testErr.toString()).doesNotContain("Error"); + } + + @Test + void shouldHandleMultipleCloseGracefullyCalls() { + // Set session factory + transportProvider.setSessionFactory(sessionFactory); + + // Close gracefully multiple times + StepVerifier + .create(transportProvider.closeGracefully() + .then(transportProvider.closeGracefully()) + .then(transportProvider.closeGracefully())) + .verifyComplete(); + + // Error log should be empty + assertThat(testErr.toString()).doesNotContain("Error"); + } + + @Test + void shouldHandleNotificationBeforeSessionFactoryIsSet() { + + transportProvider = new StdioServerTransportProvider(objectMapper); + // Send notification before setting session factory + StepVerifier.create(transportProvider.notifyClients("testNotification", Map.of("key", "value"))) + .verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class); + }); + } + + @Test + void shouldHandleInvalidJsonMessage() throws Exception { + + // Write an invalid JSON message to the input stream + String jsonMessage = "{invalid json}\n"; + InputStream stream = new ByteArrayInputStream(jsonMessage.getBytes(StandardCharsets.UTF_8)); + + transportProvider = new StdioServerTransportProvider(objectMapper, stream, testOutPrintStream); + + // Set up a session factory + transportProvider.setSessionFactory(sessionFactory); + + // Use StepVerifier with a timeout to wait for the error to be processed + StepVerifier + .create(Mono.delay(java.time.Duration.ofMillis(500)).then(Mono.fromCallable(() -> testErr.toString()))) + .assertNext(errorOutput -> assertThat(errorOutput).contains("Error processing inbound message")) + .verifyComplete(); + } + + @Test + void shouldHandleSessionClose() throws Exception { + // Set session factory + transportProvider.setSessionFactory(sessionFactory); + + // Close the transport provider + transportProvider.close(); + + // Verify session was closed + verify(mockSession).closeGracefully(); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/DefaultMcpSessionTests.java b/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java similarity index 90% rename from mcp/src/test/java/io/modelcontextprotocol/spec/DefaultMcpSessionTests.java rename to mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java index 9d011aff..79a1d0d9 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/spec/DefaultMcpSessionTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java @@ -22,14 +22,14 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; /** - * Test suite for {@link DefaultMcpSession} that verifies its JSON-RPC message handling, + * Test suite for {@link McpClientSession} that verifies its JSON-RPC message handling, * request-response correlation, and notification processing. * * @author Christian Tzolov */ -class DefaultMcpSessionTests { +class McpClientSessionTests { - private static final Logger logger = LoggerFactory.getLogger(DefaultMcpSessionTests.class); + private static final Logger logger = LoggerFactory.getLogger(McpClientSessionTests.class); private static final Duration TIMEOUT = Duration.ofSeconds(5); @@ -39,14 +39,14 @@ class DefaultMcpSessionTests { private static final String ECHO_METHOD = "echo"; - private DefaultMcpSession session; + private McpClientSession session; private MockMcpTransport transport; @BeforeEach void setUp() { transport = new MockMcpTransport(); - session = new DefaultMcpSession(TIMEOUT, transport, Map.of(), + session = new McpClientSession(TIMEOUT, transport, Map.of(), Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: " + params)))); } @@ -59,11 +59,11 @@ void tearDown() { @Test void testConstructorWithInvalidArguments() { - assertThatThrownBy(() -> new DefaultMcpSession(null, transport, Map.of(), Map.of())) + assertThatThrownBy(() -> new McpClientSession(null, transport, Map.of(), Map.of())) .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("requstTimeout can not be null"); - assertThatThrownBy(() -> new DefaultMcpSession(TIMEOUT, null, Map.of(), Map.of())) + assertThatThrownBy(() -> new McpClientSession(TIMEOUT, null, Map.of(), Map.of())) .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("transport can not be null"); } @@ -137,10 +137,10 @@ void testSendNotification() { @Test void testRequestHandling() { String echoMessage = "Hello MCP!"; - Map> requestHandlers = Map.of(ECHO_METHOD, + Map> requestHandlers = Map.of(ECHO_METHOD, params -> Mono.just(params)); transport = new MockMcpTransport(); - session = new DefaultMcpSession(TIMEOUT, transport, requestHandlers, Map.of()); + session = new McpClientSession(TIMEOUT, transport, requestHandlers, Map.of()); // Simulate incoming request McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, ECHO_METHOD, @@ -160,7 +160,7 @@ void testNotificationHandling() { Sinks.One receivedParams = Sinks.one(); transport = new MockMcpTransport(); - session = new DefaultMcpSession(TIMEOUT, transport, Map.of(), + session = new McpClientSession(TIMEOUT, transport, Map.of(), Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> receivedParams.tryEmitValue(params)))); // Simulate incoming notification from the server diff --git a/migration-0.8.0.md b/migration-0.8.0.md new file mode 100644 index 00000000..3ba29a10 --- /dev/null +++ b/migration-0.8.0.md @@ -0,0 +1,328 @@ +# MCP Java SDK Migration Guide: 0.7.0 to 0.8.0 + +This document outlines the breaking changes and provides guidance on how to migrate your code from version 0.7.0 to 0.8.0. + +The 0.8.0 refactoring introduces a session-based architecture for server-side MCP implementations. +It improves the SDK's ability to handle multiple concurrent client connections and provides an API better aligned with the MCP specification. +The main changes include: + +1. Introduction of a session-based architecture +2. New transport provider abstraction +3. Exchange objects for client interaction +4. Renamed and reorganized interfaces +5. Updated handler signatures + +## Breaking Changes + +### 1. Interface Renaming + +Several interfaces have been renamed to better reflect their roles: + +| 0.7.0 (Old) | 0.8.0 (New) | +|-------------|-------------| +| `ClientMcpTransport` | `McpClientTransport` | +| `ServerMcpTransport` | `McpServerTransport` | +| `DefaultMcpSession` | `McpClientSession`, `McpServerSession` | + +### 2. New Server Transport Architecture + +The most significant change is the introduction of the `McpServerTransportProvider` interface, which replaces direct usage of `ServerMcpTransport` when creating servers. This new pattern separates the concerns of: + +1. **Transport Provider**: Manages connections with clients and creates individual transports for each connection +2. **Server Transport**: Handles communication with a specific client connection + +| 0.7.0 (Old) | 0.8.0 (New) | +|-------------|-------------| +| `ServerMcpTransport` | `McpServerTransportProvider` + `McpServerTransport` | +| Direct transport usage | Session-based transport usage | + +#### Before (0.7.0): + +```java +// Create a transport +ServerMcpTransport transport = new WebFluxSseServerTransport(objectMapper, "/mcp/message"); + +// Create a server with the transport +McpServer.sync(transport) + .serverInfo("my-server", "1.0.0") + .build(); +``` + +#### After (0.8.0): + +```java +// Create a transport provider +McpServerTransportProvider transportProvider = new WebFluxSseServerTransportProvider(objectMapper, "/mcp/message"); + +// Create a server with the transport provider +McpServer.sync(transportProvider) + .serverInfo("my-server", "1.0.0") + .build(); +``` + +### 3. Handler Method Signature Changes + +Tool, resource, and prompt handlers now receive an additional `exchange` parameter that provides access to client capabilities and methods to interact with the client: + +| 0.7.0 (Old) | 0.8.0 (New) | +|-------------|-------------| +| `(args) -> result` | `(exchange, args) -> result` | + +The exchange objects (`McpAsyncServerExchange` and `McpSyncServerExchange`) provide context for the current session and access to session-specific operations. + +#### Before (0.7.0): + +```java +// Tool handler +.tool(calculatorTool, args -> new CallToolResult("Result: " + calculate(args))) + +// Resource handler +.resource(fileResource, req -> new ReadResourceResult(readFile(req))) + +// Prompt handler +.prompt(analysisPrompt, req -> new GetPromptResult("Analysis prompt")) +``` + +#### After (0.8.0): + +```java +// Tool handler +.tool(calculatorTool, (exchange, args) -> new CallToolResult("Result: " + calculate(args))) + +// Resource handler +.resource(fileResource, (exchange, req) -> new ReadResourceResult(readFile(req))) + +// Prompt handler +.prompt(analysisPrompt, (exchange, req) -> new GetPromptResult("Analysis prompt")) +``` + +### 4. Registration vs. Specification + +The naming convention for handlers has changed from "Registration" to "Specification": + +| 0.7.0 (Old) | 0.8.0 (New) | +|-------------|-------------| +| `AsyncToolRegistration` | `AsyncToolSpecification` | +| `SyncToolRegistration` | `SyncToolSpecification` | +| `AsyncResourceRegistration` | `AsyncResourceSpecification` | +| `SyncResourceRegistration` | `SyncResourceSpecification` | +| `AsyncPromptRegistration` | `AsyncPromptSpecification` | +| `SyncPromptRegistration` | `SyncPromptSpecification` | + +### 5. Roots Change Handler Updates + +The roots change handlers now receive an exchange parameter: + +#### Before (0.7.0): + +```java +.rootsChangeConsumers(List.of( + roots -> { + // Process roots + } +)) +``` + +#### After (0.8.0): + +```java +.rootsChangeHandlers(List.of( + (exchange, roots) -> { + // Process roots with access to exchange + } +)) +``` + +### 6. Server Creation Method Changes + +The `McpServer` factory methods now accept `McpServerTransportProvider` instead of `ServerMcpTransport`: + +| 0.7.0 (Old) | 0.8.0 (New) | +|-------------|-------------| +| `McpServer.async(ServerMcpTransport)` | `McpServer.async(McpServerTransportProvider)` | +| `McpServer.sync(ServerMcpTransport)` | `McpServer.sync(McpServerTransportProvider)` | + +The method names for creating servers have been updated: + +Root change handlers now receive an exchange object: + +| 0.7.0 (Old) | 0.8.0 (New) | +|-------------|-------------| +| `rootsChangeConsumers(List>>)` | `rootsChangeHandlers(List>>)` | +| `rootsChangeConsumer(Consumer>)` | `rootsChangeHandler(BiConsumer>)` | + +### 7. Direct Server Methods Moving to Exchange + +Several methods that were previously available directly on the server are now accessed through the exchange object: + +| 0.7.0 (Old) | 0.8.0 (New) | +|-------------|-------------| +| `server.listRoots()` | `exchange.listRoots()` | +| `server.createMessage()` | `exchange.createMessage()` | +| `server.getClientCapabilities()` | `exchange.getClientCapabilities()` | +| `server.getClientInfo()` | `exchange.getClientInfo()` | + +The direct methods are deprecated and will be removed in 0.9.0: + +- `McpSyncServer.listRoots()` +- `McpSyncServer.getClientCapabilities()` +- `McpSyncServer.getClientInfo()` +- `McpSyncServer.createMessage()` +- `McpAsyncServer.listRoots()` +- `McpAsyncServer.getClientCapabilities()` +- `McpAsyncServer.getClientInfo()` +- `McpAsyncServer.createMessage()` + +## Deprecation Notices + +The following components are deprecated in 0.8.0 and will be removed in 0.9.0: + +- `ClientMcpTransport` interface (use `McpClientTransport` instead) +- `ServerMcpTransport` interface (use `McpServerTransport` instead) +- `DefaultMcpSession` class (use `McpClientSession` instead) +- `WebFluxSseServerTransport` class (use `WebFluxSseServerTransportProvider` instead) +- `WebMvcSseServerTransport` class (use `WebMvcSseServerTransportProvider` instead) +- `StdioServerTransport` class (use `StdioServerTransportProvider` instead) +- All `*Registration` classes (use corresponding `*Specification` classes instead) +- Direct server methods for client interaction (use exchange object instead) + +## Migration Examples + +### Example 1: Creating a Server + +#### Before (0.7.0): + +```java +// Create a transport +ServerMcpTransport transport = new WebFluxSseServerTransport(objectMapper, "/mcp/message"); + +// Create a server with the transport +var server = McpServer.sync(transport) + .serverInfo("my-server", "1.0.0") + .tool(calculatorTool, args -> new CallToolResult("Result: " + calculate(args))) + .rootsChangeConsumers(List.of( + roots -> System.out.println("Roots changed: " + roots) + )) + .build(); + +// Get client capabilities directly from server +ClientCapabilities capabilities = server.getClientCapabilities(); +``` + +#### After (0.8.0): + +```java +// Create a transport provider +McpServerTransportProvider transportProvider = new WebFluxSseServerTransportProvider(objectMapper, "/mcp/message"); + +// Create a server with the transport provider +var server = McpServer.sync(transportProvider) + .serverInfo("my-server", "1.0.0") + .tool(calculatorTool, (exchange, args) -> { + // Get client capabilities from exchange + ClientCapabilities capabilities = exchange.getClientCapabilities(); + return new CallToolResult("Result: " + calculate(args)); + }) + .rootsChangeHandlers(List.of( + (exchange, roots) -> System.out.println("Roots changed: " + roots) + )) + .build(); +``` + +### Example 2: Implementing a Tool with Client Interaction + +#### Before (0.7.0): + +```java +McpServerFeatures.SyncToolRegistration tool = new McpServerFeatures.SyncToolRegistration( + new Tool("weather", "Get weather information", schema), + args -> { + String location = (String) args.get("location"); + // Cannot interact with client from here + return new CallToolResult("Weather for " + location + ": Sunny"); + } +); + +var server = McpServer.sync(transport) + .tools(tool) + .build(); + +// Separate call to create a message +CreateMessageResult result = server.createMessage(new CreateMessageRequest(...)); +``` + +#### After (0.8.0): + +```java +McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification( + new Tool("weather", "Get weather information", schema), + (exchange, args) -> { + String location = (String) args.get("location"); + + // Can interact with client directly from the tool handler + CreateMessageResult result = exchange.createMessage(new CreateMessageRequest(...)); + + return new CallToolResult("Weather for " + location + ": " + result.content()); + } +); + +var server = McpServer.sync(transportProvider) + .tools(tool) + .build(); +``` + +### Example 3: Converting Existing Registration Classes + +If you have custom implementations of the registration classes, you can convert them to the new specification classes: + +#### Before (0.7.0): + +```java +McpServerFeatures.AsyncToolRegistration toolReg = new McpServerFeatures.AsyncToolRegistration( + tool, + args -> Mono.just(new CallToolResult("Result")) +); + +McpServerFeatures.AsyncResourceRegistration resourceReg = new McpServerFeatures.AsyncResourceRegistration( + resource, + req -> Mono.just(new ReadResourceResult(List.of())) +); +``` + +#### After (0.8.0): + +```java +// Option 1: Create new specification directly +McpServerFeatures.AsyncToolSpecification toolSpec = new McpServerFeatures.AsyncToolSpecification( + tool, + (exchange, args) -> Mono.just(new CallToolResult("Result")) +); + +// Option 2: Convert from existing registration (during transition) +McpServerFeatures.AsyncToolRegistration oldToolReg = /* existing registration */; +McpServerFeatures.AsyncToolSpecification toolSpec = oldToolReg.toSpecification(); + +// Similarly for resources +McpServerFeatures.AsyncResourceSpecification resourceSpec = new McpServerFeatures.AsyncResourceSpecification( + resource, + (exchange, req) -> Mono.just(new ReadResourceResult(List.of())) +); +``` + +## Architecture Changes + +### Session-Based Architecture + +In 0.8.0, the MCP Java SDK introduces a session-based architecture where each client connection has its own session. This allows for better isolation between clients and more efficient resource management. + +The `McpServerTransportProvider` is responsible for creating `McpServerTransport` instances for each session, and the `McpServerSession` manages the communication with a specific client. + +### Exchange Objects + +The new exchange objects (`McpAsyncServerExchange` and `McpSyncServerExchange`) provide access to client-specific information and methods. They are passed to handler functions as the first parameter, allowing handlers to interact with the specific client that made the request. + +## Conclusion + +The changes in version 0.8.0 represent a significant architectural improvement to the MCP Java SDK. While they require some code changes, the new design provides a more flexible and maintainable foundation for building MCP applications. + +For assistance with migration or to report issues, please open an issue on the GitHub repository.