-
Notifications
You must be signed in to change notification settings - Fork 176
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(client): Improve initialization state handling in McpAsyncClient #39
Changes from 1 commit
b8ae1e3
ffe5448
1970cc6
8f0b24b
9105c31
6345bf4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,6 +9,8 @@ | |
import java.util.List; | ||
import java.util.Map; | ||
import java.util.concurrent.ConcurrentHashMap; | ||
import java.util.concurrent.TimeoutException; | ||
import java.util.concurrent.atomic.AtomicBoolean; | ||
import java.util.function.Function; | ||
|
||
import com.fasterxml.jackson.core.type.TypeReference; | ||
|
@@ -35,6 +37,7 @@ | |
import org.slf4j.LoggerFactory; | ||
import reactor.core.publisher.Flux; | ||
import reactor.core.publisher.Mono; | ||
import reactor.core.publisher.Sinks; | ||
|
||
/** | ||
* The Model Context Protocol (MCP) client implementation that provides asynchronous | ||
|
@@ -79,6 +82,12 @@ public class McpAsyncClient { | |
private static TypeReference<Void> VOID_TYPE_REFERENCE = new TypeReference<>() { | ||
}; | ||
|
||
protected final Sinks.One<McpSchema.InitializeResult> initializedSink = Sinks.one(); | ||
|
||
private AtomicBoolean initialized = new AtomicBoolean(false); | ||
|
||
private final Duration initializedTimeout; | ||
|
||
/** | ||
* The MCP session implementation that manages bidirectional JSON-RPC communication | ||
* between clients and servers. | ||
|
@@ -149,6 +158,7 @@ public class McpAsyncClient { | |
this.clientCapabilities = features.clientCapabilities(); | ||
this.transport = transport; | ||
this.roots = new ConcurrentHashMap<>(features.roots()); | ||
this.initializedTimeout = requestTimeout.multipliedBy(2); | ||
|
||
// Request Handlers | ||
Map<String, RequestHandler<?>> requestHandlers = new HashMap<>(); | ||
|
@@ -253,8 +263,8 @@ public Mono<McpSchema.InitializeResult> initialize() { | |
|
||
McpSchema.InitializeRequest initializeRequest = new McpSchema.InitializeRequest(// @formatter:off | ||
latestVersion, | ||
this.clientCapabilities, | ||
this.clientInfo); // @formatter:on | ||
this.clientCapabilities, | ||
this.clientInfo); // @formatter:on | ||
|
||
Mono<McpSchema.InitializeResult> result = this.mcpSession.sendRequest(McpSchema.METHOD_INITIALIZE, | ||
initializeRequest, new TypeReference<McpSchema.InitializeResult>() { | ||
|
@@ -273,10 +283,11 @@ public Mono<McpSchema.InitializeResult> initialize() { | |
return Mono.error(new McpError( | ||
"Unsupported protocol version from the server: " + initializeResult.protocolVersion())); | ||
} | ||
else { | ||
return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_INITIALIZED, null) | ||
.thenReturn(initializeResult); | ||
} | ||
|
||
return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_INITIALIZED, null).doOnSuccess(v -> { | ||
this.initialized.set(true); | ||
this.initializedSink.tryEmitValue(initializeResult); | ||
}).thenReturn(initializeResult); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Awesome, this will make the client more robust :) |
||
}); | ||
} | ||
|
||
|
@@ -301,7 +312,7 @@ public McpSchema.Implementation getServerInfo() { | |
* @return true if the client-server connection is initialized | ||
*/ | ||
public boolean isInitialized() { | ||
return this.serverCapabilities != null; | ||
return this.initialized.get(); | ||
} | ||
|
||
/** | ||
|
@@ -335,6 +346,26 @@ public Mono<Void> closeGracefully() { | |
return this.mcpSession.closeGracefully(); | ||
} | ||
|
||
// -------------------------- | ||
// Utility Methods | ||
// -------------------------- | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's a not some "utility" but a fundamental prerequisite to almost any operation, perhaps we can group it under "initialization methods"? |
||
|
||
/** | ||
* Utility method to handle the common pattern of checking initialization before | ||
* executing an operation. | ||
* @param <T> The type of the result Mono | ||
* @param errorMessage The error message to use if the client is not initialized | ||
* @param operation The operation to execute if the client is initialized | ||
* @return A Mono that completes with the result of the operation | ||
*/ | ||
private <T> Mono<T> withInitializationCheck(String errorMessage, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "Client must be initialized before " is repeated in each use, consider replacing |
||
Function<McpSchema.InitializeResult, Mono<T>> operation) { | ||
return this.initializedSink.asMono() | ||
.timeout(this.initializedTimeout) | ||
.onErrorResume(TimeoutException.class, ex -> Mono.error(new McpError(errorMessage))) | ||
.flatMap(operation); | ||
} | ||
|
||
// -------------------------- | ||
// Basic Utilites | ||
// -------------------------- | ||
|
@@ -344,8 +375,10 @@ public Mono<Void> closeGracefully() { | |
* @return A Mono that completes with the server's ping response | ||
*/ | ||
public Mono<Object> ping() { | ||
return this.mcpSession.sendRequest(McpSchema.METHOD_PING, null, new TypeReference<Object>() { | ||
}); | ||
return withInitializationCheck("Client must be initialized before pinging the server", | ||
initializedResult -> this.mcpSession.sendRequest(McpSchema.METHOD_PING, null, | ||
new TypeReference<Object>() { | ||
})); | ||
} | ||
|
||
// -------------------------- | ||
|
@@ -375,7 +408,12 @@ public Mono<Void> addRoot(Root root) { | |
logger.debug("Added root: {}", root); | ||
|
||
if (this.clientCapabilities.roots().listChanged()) { | ||
return this.rootsListChangedNotification(); | ||
if (this.isInitialized()) { | ||
return this.rootsListChangedNotification(); | ||
} | ||
else { | ||
logger.warn("Client is not initialized, ignore sending a roots list changed notification"); | ||
} | ||
} | ||
return Mono.empty(); | ||
} | ||
|
@@ -400,7 +438,13 @@ public Mono<Void> removeRoot(String rootUri) { | |
if (removed != null) { | ||
logger.debug("Removed Root: {}", rootUri); | ||
if (this.clientCapabilities.roots().listChanged()) { | ||
return this.rootsListChangedNotification(); | ||
if (this.isInitialized()) { | ||
return this.rootsListChangedNotification(); | ||
} | ||
else { | ||
logger.warn("Client is not initialized, ignore sending a roots list changed notification"); | ||
} | ||
|
||
} | ||
return Mono.empty(); | ||
} | ||
|
@@ -413,7 +457,8 @@ public Mono<Void> removeRoot(String rootUri) { | |
* @return A Mono that completes when the notification is sent | ||
*/ | ||
public Mono<Void> rootsListChangedNotification() { | ||
return this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED); | ||
return this.withInitializationCheck("Client must be initialized before sending roots list changed notification", | ||
initResult -> this.mcpSession.sendNotification(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED)); | ||
} | ||
|
||
private RequestHandler<McpSchema.ListRootsResult> rootsListRequestHandler() { | ||
|
@@ -464,13 +509,12 @@ private RequestHandler<CreateMessageResult> samplingCreateMessageHandler() { | |
* (false/absent) | ||
*/ | ||
public Mono<McpSchema.CallToolResult> callTool(McpSchema.CallToolRequest callToolRequest) { | ||
if (!this.isInitialized()) { | ||
return Mono.error(new McpError("Client must be initialized before calling tools")); | ||
} | ||
if (this.serverCapabilities.tools() == null) { | ||
return Mono.error(new McpError("Server does not provide tools capability")); | ||
} | ||
return this.mcpSession.sendRequest(McpSchema.METHOD_TOOLS_CALL, callToolRequest, CALL_TOOL_RESULT_TYPE_REF); | ||
return withInitializationCheck("Client must be initialized before calling tools", initializedResult -> { | ||
if (this.serverCapabilities.tools() == null) { | ||
return Mono.error(new McpError("Server does not provide tools capability")); | ||
} | ||
return this.mcpSession.sendRequest(McpSchema.METHOD_TOOLS_CALL, callToolRequest, CALL_TOOL_RESULT_TYPE_REF); | ||
}); | ||
} | ||
|
||
/** | ||
|
@@ -491,14 +535,13 @@ public Mono<McpSchema.ListToolsResult> listTools() { | |
* Optional cursor for pagination if more tools are available | ||
*/ | ||
public Mono<McpSchema.ListToolsResult> listTools(String cursor) { | ||
if (!this.isInitialized()) { | ||
return Mono.error(new McpError("Client must be initialized before listing tools")); | ||
} | ||
if (this.serverCapabilities.tools() == null) { | ||
return Mono.error(new McpError("Server does not provide tools capability")); | ||
} | ||
return this.mcpSession.sendRequest(McpSchema.METHOD_TOOLS_LIST, new McpSchema.PaginatedRequest(cursor), | ||
LIST_TOOLS_RESULT_TYPE_REF); | ||
return withInitializationCheck("Client must be initialized before listing tools", initializedResult -> { | ||
if (this.serverCapabilities.tools() == null) { | ||
return Mono.error(new McpError("Server does not provide tools capability")); | ||
} | ||
return this.mcpSession.sendRequest(McpSchema.METHOD_TOOLS_LIST, new McpSchema.PaginatedRequest(cursor), | ||
LIST_TOOLS_RESULT_TYPE_REF); | ||
}); | ||
} | ||
|
||
/** | ||
|
@@ -516,13 +559,14 @@ public Mono<McpSchema.ListToolsResult> listTools(String cursor) { | |
private NotificationHandler asyncToolsChangeNotificationHandler( | ||
List<Function<List<McpSchema.Tool>, Mono<Void>>> toolsChangeConsumers) { | ||
// TODO: params are not used yet | ||
return params -> listTools().flatMap(listToolsResult -> Flux.fromIterable(toolsChangeConsumers) | ||
.flatMap(consumer -> consumer.apply(listToolsResult.tools())) | ||
.onErrorResume(error -> { | ||
logger.error("Error handling tools list change notification", error); | ||
return Mono.empty(); | ||
}) | ||
.then()); | ||
return params -> this.listTools() | ||
.flatMap(listToolsResult -> Flux.fromIterable(toolsChangeConsumers) | ||
.flatMap(consumer -> consumer.apply(listToolsResult.tools())) | ||
.onErrorResume(error -> { | ||
logger.error("Error handling tools list change notification", error); | ||
return Mono.empty(); | ||
}) | ||
.then()); | ||
} | ||
|
||
// -------------------------- | ||
|
@@ -552,14 +596,13 @@ public Mono<McpSchema.ListResourcesResult> listResources() { | |
* @return A Mono that completes with the list of resources result | ||
*/ | ||
public Mono<McpSchema.ListResourcesResult> listResources(String cursor) { | ||
if (!this.isInitialized()) { | ||
return Mono.error(new McpError("Client must be initialized before listing resources")); | ||
} | ||
if (this.serverCapabilities.resources() == null) { | ||
return Mono.error(new McpError("Server does not provide the resources capability")); | ||
} | ||
return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_LIST, new McpSchema.PaginatedRequest(cursor), | ||
LIST_RESOURCES_RESULT_TYPE_REF); | ||
return withInitializationCheck("Client must be initialized before listing resources", initializedResult -> { | ||
if (this.serverCapabilities.resources() == null) { | ||
return Mono.error(new McpError("Server does not provide the resources capability")); | ||
} | ||
return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_LIST, new McpSchema.PaginatedRequest(cursor), | ||
LIST_RESOURCES_RESULT_TYPE_REF); | ||
}); | ||
} | ||
|
||
/** | ||
|
@@ -577,14 +620,13 @@ public Mono<McpSchema.ReadResourceResult> readResource(McpSchema.Resource resour | |
* @return A Mono that completes with the resource content | ||
*/ | ||
public Mono<McpSchema.ReadResourceResult> readResource(McpSchema.ReadResourceRequest readResourceRequest) { | ||
if (!this.isInitialized()) { | ||
return Mono.error(new McpError("Client must be initialized before reading resources")); | ||
} | ||
if (this.serverCapabilities.resources() == null) { | ||
return Mono.error(new McpError("Server does not provide the resources capability")); | ||
} | ||
return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_READ, readResourceRequest, | ||
READ_RESOURCE_RESULT_TYPE_REF); | ||
return withInitializationCheck("Client must be initialized before reading resources", initializedResult -> { | ||
if (this.serverCapabilities.resources() == null) { | ||
return Mono.error(new McpError("Server does not provide the resources capability")); | ||
} | ||
return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_READ, readResourceRequest, | ||
READ_RESOURCE_RESULT_TYPE_REF); | ||
}); | ||
} | ||
|
||
/** | ||
|
@@ -607,14 +649,14 @@ public Mono<McpSchema.ListResourceTemplatesResult> listResourceTemplates() { | |
* @return A Mono that completes with the list of resource templates result | ||
*/ | ||
public Mono<McpSchema.ListResourceTemplatesResult> listResourceTemplates(String cursor) { | ||
if (!this.isInitialized()) { | ||
return Mono.error(new McpError("Client must be initialized before listing resource templates")); | ||
} | ||
if (this.serverCapabilities.resources() == null) { | ||
return Mono.error(new McpError("Server does not provide the resources capability")); | ||
} | ||
return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, | ||
new McpSchema.PaginatedRequest(cursor), LIST_RESOURCE_TEMPLATES_RESULT_TYPE_REF); | ||
return withInitializationCheck("Client must be initialized before listing resource templates", | ||
initializedResult -> { | ||
if (this.serverCapabilities.resources() == null) { | ||
return Mono.error(new McpError("Server does not provide the resources capability")); | ||
} | ||
return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, | ||
new McpSchema.PaginatedRequest(cursor), LIST_RESOURCE_TEMPLATES_RESULT_TYPE_REF); | ||
}); | ||
} | ||
|
||
/** | ||
|
@@ -628,7 +670,9 @@ public Mono<McpSchema.ListResourceTemplatesResult> listResourceTemplates(String | |
* @return A Mono that completes when the subscription is complete | ||
*/ | ||
public Mono<Void> subscribeResource(McpSchema.SubscribeRequest subscribeRequest) { | ||
return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_SUBSCRIBE, subscribeRequest, VOID_TYPE_REFERENCE); | ||
return withInitializationCheck("Client must be initialized before subscribing to resources", | ||
initializedResult -> this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_SUBSCRIBE, subscribeRequest, | ||
VOID_TYPE_REFERENCE)); | ||
} | ||
|
||
/** | ||
|
@@ -638,8 +682,9 @@ public Mono<Void> subscribeResource(McpSchema.SubscribeRequest subscribeRequest) | |
* @return A Mono that completes when the unsubscription is complete | ||
*/ | ||
public Mono<Void> unsubscribeResource(McpSchema.UnsubscribeRequest unsubscribeRequest) { | ||
return this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_UNSUBSCRIBE, unsubscribeRequest, | ||
VOID_TYPE_REFERENCE); | ||
return withInitializationCheck("Client must be initialized before unsubscribing from resources", | ||
initializedResult -> this.mcpSession.sendRequest(McpSchema.METHOD_RESOURCES_UNSUBSCRIBE, | ||
unsubscribeRequest, VOID_TYPE_REFERENCE)); | ||
} | ||
|
||
private NotificationHandler asyncResourcesChangeNotificationHandler( | ||
|
@@ -676,8 +721,9 @@ public Mono<ListPromptsResult> listPrompts() { | |
* @return A Mono that completes with the list of prompts result | ||
*/ | ||
public Mono<ListPromptsResult> listPrompts(String cursor) { | ||
return this.mcpSession.sendRequest(McpSchema.METHOD_PROMPT_LIST, new PaginatedRequest(cursor), | ||
LIST_PROMPTS_RESULT_TYPE_REF); | ||
return withInitializationCheck("Client must be initialized before listing prompts", | ||
initializedResult -> this.mcpSession.sendRequest(McpSchema.METHOD_PROMPT_LIST, | ||
new PaginatedRequest(cursor), LIST_PROMPTS_RESULT_TYPE_REF)); | ||
} | ||
|
||
/** | ||
|
@@ -686,7 +732,9 @@ public Mono<ListPromptsResult> listPrompts(String cursor) { | |
* @return A Mono that completes with the get prompt result | ||
*/ | ||
public Mono<GetPromptResult> getPrompt(GetPromptRequest getPromptRequest) { | ||
return this.mcpSession.sendRequest(McpSchema.METHOD_PROMPT_GET, getPromptRequest, GET_PROMPT_RESULT_TYPE_REF); | ||
return withInitializationCheck("Client must be initialized before getting prompts", | ||
initializedResult -> this.mcpSession.sendRequest(McpSchema.METHOD_PROMPT_GET, getPromptRequest, | ||
GET_PROMPT_RESULT_TYPE_REF)); | ||
} | ||
|
||
private NotificationHandler asyncPromptsChangeNotificationHandler( | ||
|
@@ -732,12 +780,12 @@ private NotificationHandler asyncLoggingNotificationHandler( | |
public Mono<Void> setLoggingLevel(LoggingLevel loggingLevel) { | ||
Assert.notNull(loggingLevel, "Logging level must not be null"); | ||
|
||
String levelName = this.transport.unmarshalFrom(loggingLevel, new TypeReference<String>() { | ||
return withInitializationCheck("Client must be initialized before setting logging level", initializedResult -> { | ||
String levelName = this.transport.unmarshalFrom(loggingLevel, new TypeReference<String>() { | ||
}); | ||
Map<String, Object> params = Map.of("level", levelName); | ||
return this.mcpSession.sendNotification(McpSchema.METHOD_LOGGING_SET_LEVEL, params); | ||
}); | ||
|
||
Map<String, Object> params = Map.of("level", levelName); | ||
|
||
return this.mcpSession.sendNotification(McpSchema.METHOD_LOGGING_SET_LEVEL, params); | ||
} | ||
|
||
/** | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should it be this value and undocumented? I fear it would be worthwhile to add logging with a note after how long the client gave up. Also, this will have to be revisited once reconnects are considered.