Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(client): Improve initialization state handling in McpAsyncClient #39

Merged
merged 6 commits into from
Mar 13, 2025
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 117 additions & 69 deletions mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;

import com.fasterxml.jackson.core.type.TypeReference;
Expand All @@ -35,6 +37,7 @@
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Sinks;

/**
* The Model Context Protocol (MCP) client implementation that provides asynchronous
Expand Down Expand Up @@ -79,6 +82,12 @@ public class McpAsyncClient {
private static TypeReference<Void> VOID_TYPE_REFERENCE = new TypeReference<>() {
};

protected final Sinks.One<McpSchema.InitializeResult> initializedSink = Sinks.one();

private AtomicBoolean initialized = new AtomicBoolean(false);

private final Duration initializedTimeout;

/**
* The MCP session implementation that manages bidirectional JSON-RPC communication
* between clients and servers.
Expand Down Expand Up @@ -149,6 +158,7 @@ public class McpAsyncClient {
this.clientCapabilities = features.clientCapabilities();
this.transport = transport;
this.roots = new ConcurrentHashMap<>(features.roots());
this.initializedTimeout = requestTimeout.multipliedBy(2);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should it be this value and undocumented? I fear it would be worthwhile to add logging with a note after how long the client gave up. Also, this will have to be revisited once reconnects are considered.


// Request Handlers
Map<String, RequestHandler<?>> requestHandlers = new HashMap<>();
Expand Down Expand Up @@ -253,8 +263,8 @@ public Mono<McpSchema.InitializeResult> initialize() {

McpSchema.InitializeRequest initializeRequest = new McpSchema.InitializeRequest(// @formatter:off
latestVersion,
this.clientCapabilities,
this.clientInfo); // @formatter:on
this.clientCapabilities,
this.clientInfo); // @formatter:on

Mono<McpSchema.InitializeResult> result = this.mcpSession.sendRequest(McpSchema.METHOD_INITIALIZE,
initializeRequest, new TypeReference<McpSchema.InitializeResult>() {
Expand All @@ -273,10 +283,11 @@ public Mono<McpSchema.InitializeResult> initialize() {
return Mono.error(new McpError(
"Unsupported protocol version from the server: " + initializeResult.protocolVersion()));
}
else {
return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_INITIALIZED, null)
.thenReturn(initializeResult);
}

return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_INITIALIZED, null).doOnSuccess(v -> {
this.initialized.set(true);
this.initializedSink.tryEmitValue(initializeResult);
}).thenReturn(initializeResult);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome, this will make the client more robust :)

});
}

Expand All @@ -301,7 +312,7 @@ public McpSchema.Implementation getServerInfo() {
* @return true if the client-server connection is initialized
*/
public boolean isInitialized() {
return this.serverCapabilities != null;
return this.initialized.get();
}

/**
Expand Down Expand Up @@ -335,6 +346,26 @@ public Mono<Void> closeGracefully() {
return this.mcpSession.closeGracefully();
}

// --------------------------
// Utility Methods
// --------------------------
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a not some "utility" but a fundamental prerequisite to almost any operation, perhaps we can group it under "initialization methods"?


/**
* Utility method to handle the common pattern of checking initialization before
* executing an operation.
* @param <T> The type of the result Mono
* @param errorMessage The error message to use if the client is not initialized
* @param operation The operation to execute if the client is initialized
* @return A Mono that completes with the result of the operation
*/
private <T> Mono<T> withInitializationCheck(String errorMessage,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"Client must be initialized before " is repeated in each use, consider replacing errorMessage with action and just specify what's about to happen, e.g. action="pinging the server".

Function<McpSchema.InitializeResult, Mono<T>> operation) {
return this.initializedSink.asMono()
.timeout(this.initializedTimeout)
.onErrorResume(TimeoutException.class, ex -> Mono.error(new McpError(errorMessage)))
.flatMap(operation);
}

// --------------------------
// Basic Utilites
// --------------------------
Expand All @@ -344,8 +375,10 @@ public Mono<Void> closeGracefully() {
* @return A Mono that completes with the server's ping response
*/
public Mono<Object> ping() {
return this.mcpSession.sendRequest(McpSchema.METHOD_PING, null, new TypeReference<Object>() {
});
return withInitializationCheck("Client must be initialized before pinging the server",
initializedResult -> this.mcpSession.sendRequest(McpSchema.METHOD_PING, null,
new TypeReference<Object>() {
}));
}

// --------------------------
Expand Down Expand Up @@ -375,7 +408,12 @@ public Mono<Void> addRoot(Root root) {
logger.debug("Added root: {}", root);

if (this.clientCapabilities.roots().listChanged()) {
return this.rootsListChangedNotification();
if (this.isInitialized()) {
return this.rootsListChangedNotification();
}
else {
logger.warn("Client is not initialized, ignore sending a roots list changed notification");
}
}
return Mono.empty();
}
Expand All @@ -400,7 +438,13 @@ public Mono<Void> removeRoot(String rootUri) {
if (removed != null) {
logger.debug("Removed Root: {}", rootUri);
if (this.clientCapabilities.roots().listChanged()) {
return this.rootsListChangedNotification();
if (this.isInitialized()) {
return this.rootsListChangedNotification();
}
else {
logger.warn("Client is not initialized, ignore sending a roots list changed notification");
}

}
return Mono.empty();
}
Expand All @@ -413,7 +457,8 @@ public Mono<Void> removeRoot(String rootUri) {
* @return A Mono that completes when the notification is sent
*/
public Mono<Void> rootsListChangedNotification() {
return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED);
return this.withInitializationCheck("Client must be initialized before sending roots list changed notification",
initResult -> this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED));
}

private RequestHandler<McpSchema.ListRootsResult> rootsListRequestHandler() {
Expand Down Expand Up @@ -464,13 +509,12 @@ private RequestHandler<CreateMessageResult> samplingCreateMessageHandler() {
* (false/absent)
*/
public Mono<McpSchema.CallToolResult> callTool(McpSchema.CallToolRequest callToolRequest) {
if (!this.isInitialized()) {
return Mono.error(new McpError("Client must be initialized before calling tools"));
}
if (this.serverCapabilities.tools() == null) {
return Mono.error(new McpError("Server does not provide tools capability"));
}
return this.mcpSession.sendRequest(McpSchema.METHOD_TOOLS_CALL, callToolRequest, CALL_TOOL_RESULT_TYPE_REF);
return withInitializationCheck("Client must be initialized before calling tools", initializedResult -> {
if (this.serverCapabilities.tools() == null) {
return Mono.error(new McpError("Server does not provide tools capability"));
}
return this.mcpSession.sendRequest(McpSchema.METHOD_TOOLS_CALL, callToolRequest, CALL_TOOL_RESULT_TYPE_REF);
});
}

/**
Expand All @@ -491,14 +535,13 @@ public Mono<McpSchema.ListToolsResult> listTools() {
* Optional cursor for pagination if more tools are available
*/
public Mono<McpSchema.ListToolsResult> listTools(String cursor) {
if (!this.isInitialized()) {
return Mono.error(new McpError("Client must be initialized before listing tools"));
}
if (this.serverCapabilities.tools() == null) {
return Mono.error(new McpError("Server does not provide tools capability"));
}
return this.mcpSession.sendRequest(McpSchema.METHOD_TOOLS_LIST, new McpSchema.PaginatedRequest(cursor),
LIST_TOOLS_RESULT_TYPE_REF);
return withInitializationCheck("Client must be initialized before listing tools", initializedResult -> {
if (this.serverCapabilities.tools() == null) {
return Mono.error(new McpError("Server does not provide tools capability"));
}
return this.mcpSession.sendRequest(McpSchema.METHOD_TOOLS_LIST, new McpSchema.PaginatedRequest(cursor),
LIST_TOOLS_RESULT_TYPE_REF);
});
}

/**
Expand All @@ -516,13 +559,14 @@ public Mono<McpSchema.ListToolsResult> listTools(String cursor) {
private NotificationHandler asyncToolsChangeNotificationHandler(
List<Function<List<McpSchema.Tool>, Mono<Void>>> toolsChangeConsumers) {
// TODO: params are not used yet
return params -> listTools().flatMap(listToolsResult -> Flux.fromIterable(toolsChangeConsumers)
.flatMap(consumer -> consumer.apply(listToolsResult.tools()))
.onErrorResume(error -> {
logger.error("Error handling tools list change notification", error);
return Mono.empty();
})
.then());
return params -> this.listTools()
.flatMap(listToolsResult -> Flux.fromIterable(toolsChangeConsumers)
.flatMap(consumer -> consumer.apply(listToolsResult.tools()))
.onErrorResume(error -> {
logger.error("Error handling tools list change notification", error);
return Mono.empty();
})
.then());
}

// --------------------------
Expand Down Expand Up @@ -552,14 +596,13 @@ public Mono<McpSchema.ListResourcesResult> listResources() {
* @return A Mono that completes with the list of resources result
*/
public Mono<McpSchema.ListResourcesResult> listResources(String cursor) {
if (!this.isInitialized()) {
return Mono.error(new McpError("Client must be initialized before listing resources"));
}
if (this.serverCapabilities.resources() == null) {
return Mono.error(new McpError("Server does not provide the resources capability"));
}
return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_LIST, new McpSchema.PaginatedRequest(cursor),
LIST_RESOURCES_RESULT_TYPE_REF);
return withInitializationCheck("Client must be initialized before listing resources", initializedResult -> {
if (this.serverCapabilities.resources() == null) {
return Mono.error(new McpError("Server does not provide the resources capability"));
}
return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_LIST, new McpSchema.PaginatedRequest(cursor),
LIST_RESOURCES_RESULT_TYPE_REF);
});
}

/**
Expand All @@ -577,14 +620,13 @@ public Mono<McpSchema.ReadResourceResult> readResource(McpSchema.Resource resour
* @return A Mono that completes with the resource content
*/
public Mono<McpSchema.ReadResourceResult> readResource(McpSchema.ReadResourceRequest readResourceRequest) {
if (!this.isInitialized()) {
return Mono.error(new McpError("Client must be initialized before reading resources"));
}
if (this.serverCapabilities.resources() == null) {
return Mono.error(new McpError("Server does not provide the resources capability"));
}
return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_READ, readResourceRequest,
READ_RESOURCE_RESULT_TYPE_REF);
return withInitializationCheck("Client must be initialized before reading resources", initializedResult -> {
if (this.serverCapabilities.resources() == null) {
return Mono.error(new McpError("Server does not provide the resources capability"));
}
return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_READ, readResourceRequest,
READ_RESOURCE_RESULT_TYPE_REF);
});
}

/**
Expand All @@ -607,14 +649,14 @@ public Mono<McpSchema.ListResourceTemplatesResult> listResourceTemplates() {
* @return A Mono that completes with the list of resource templates result
*/
public Mono<McpSchema.ListResourceTemplatesResult> listResourceTemplates(String cursor) {
if (!this.isInitialized()) {
return Mono.error(new McpError("Client must be initialized before listing resource templates"));
}
if (this.serverCapabilities.resources() == null) {
return Mono.error(new McpError("Server does not provide the resources capability"));
}
return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST,
new McpSchema.PaginatedRequest(cursor), LIST_RESOURCE_TEMPLATES_RESULT_TYPE_REF);
return withInitializationCheck("Client must be initialized before listing resource templates",
initializedResult -> {
if (this.serverCapabilities.resources() == null) {
return Mono.error(new McpError("Server does not provide the resources capability"));
}
return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST,
new McpSchema.PaginatedRequest(cursor), LIST_RESOURCE_TEMPLATES_RESULT_TYPE_REF);
});
}

/**
Expand All @@ -628,7 +670,9 @@ public Mono<McpSchema.ListResourceTemplatesResult> listResourceTemplates(String
* @return A Mono that completes when the subscription is complete
*/
public Mono<Void> subscribeResource(McpSchema.SubscribeRequest subscribeRequest) {
return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_SUBSCRIBE, subscribeRequest, VOID_TYPE_REFERENCE);
return withInitializationCheck("Client must be initialized before subscribing to resources",
initializedResult -> this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_SUBSCRIBE, subscribeRequest,
VOID_TYPE_REFERENCE));
}

/**
Expand All @@ -638,8 +682,9 @@ public Mono<Void> subscribeResource(McpSchema.SubscribeRequest subscribeRequest)
* @return A Mono that completes when the unsubscription is complete
*/
public Mono<Void> unsubscribeResource(McpSchema.UnsubscribeRequest unsubscribeRequest) {
return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_UNSUBSCRIBE, unsubscribeRequest,
VOID_TYPE_REFERENCE);
return withInitializationCheck("Client must be initialized before unsubscribing from resources",
initializedResult -> this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_UNSUBSCRIBE,
unsubscribeRequest, VOID_TYPE_REFERENCE));
}

private NotificationHandler asyncResourcesChangeNotificationHandler(
Expand Down Expand Up @@ -676,8 +721,9 @@ public Mono<ListPromptsResult> listPrompts() {
* @return A Mono that completes with the list of prompts result
*/
public Mono<ListPromptsResult> listPrompts(String cursor) {
return this.mcpSession.sendRequest(McpSchema.METHOD_PROMPT_LIST, new PaginatedRequest(cursor),
LIST_PROMPTS_RESULT_TYPE_REF);
return withInitializationCheck("Client must be initialized before listing prompts",
initializedResult -> this.mcpSession.sendRequest(McpSchema.METHOD_PROMPT_LIST,
new PaginatedRequest(cursor), LIST_PROMPTS_RESULT_TYPE_REF));
}

/**
Expand All @@ -686,7 +732,9 @@ public Mono<ListPromptsResult> listPrompts(String cursor) {
* @return A Mono that completes with the get prompt result
*/
public Mono<GetPromptResult> getPrompt(GetPromptRequest getPromptRequest) {
return this.mcpSession.sendRequest(McpSchema.METHOD_PROMPT_GET, getPromptRequest, GET_PROMPT_RESULT_TYPE_REF);
return withInitializationCheck("Client must be initialized before getting prompts",
initializedResult -> this.mcpSession.sendRequest(McpSchema.METHOD_PROMPT_GET, getPromptRequest,
GET_PROMPT_RESULT_TYPE_REF));
}

private NotificationHandler asyncPromptsChangeNotificationHandler(
Expand Down Expand Up @@ -732,12 +780,12 @@ private NotificationHandler asyncLoggingNotificationHandler(
public Mono<Void> setLoggingLevel(LoggingLevel loggingLevel) {
Assert.notNull(loggingLevel, "Logging level must not be null");

String levelName = this.transport.unmarshalFrom(loggingLevel, new TypeReference<String>() {
return withInitializationCheck("Client must be initialized before setting logging level", initializedResult -> {
String levelName = this.transport.unmarshalFrom(loggingLevel, new TypeReference<String>() {
});
Map<String, Object> params = Map.of("level", levelName);
return this.mcpSession.sendNotification(McpSchema.METHOD_LOGGING_SET_LEVEL, params);
});

Map<String, Object> params = Map.of("level", levelName);

return this.mcpSession.sendNotification(McpSchema.METHOD_LOGGING_SET_LEVEL, params);
}

/**
Expand Down