Skip to content

Commit 4ca00c2

Browse files
authored
fix(client): add error enum while deal client info (#76)
1. wrap the error type for more standardized 2. add more information in error for debug trace 3. wrap helper func for more user-friendly code Signed-off-by: jokemanfire <[email protected]>
1 parent 2c0cafd commit 4ca00c2

File tree

1 file changed

+80
-28
lines changed

1 file changed

+80
-28
lines changed

crates/rmcp/src/service/client.rs

+80-28
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,73 @@
1-
use futures::{SinkExt, StreamExt};
1+
use futures::{SinkExt, Stream, StreamExt};
2+
use thiserror::Error;
23

34
use super::*;
45
use crate::model::{
56
CallToolRequest, CallToolRequestParam, CallToolResult, CancelledNotification,
67
CancelledNotificationParam, ClientInfo, ClientMessage, ClientNotification, ClientRequest,
78
ClientResult, CompleteRequest, CompleteRequestParam, CompleteResult, GetPromptRequest,
89
GetPromptRequestParam, GetPromptResult, InitializeRequest, InitializedNotification,
9-
ListPromptsRequest, ListPromptsResult, ListResourceTemplatesRequest,
10+
JsonRpcResponse, ListPromptsRequest, ListPromptsResult, ListResourceTemplatesRequest,
1011
ListResourceTemplatesResult, ListResourcesRequest, ListResourcesResult, ListToolsRequest,
1112
ListToolsResult, PaginatedRequestParam, PaginatedRequestParamInner, ProgressNotification,
1213
ProgressNotificationParam, ReadResourceRequest, ReadResourceRequestParam, ReadResourceResult,
13-
RootsListChangedNotification, ServerInfo, ServerNotification, ServerRequest, ServerResult,
14-
SetLevelRequest, SetLevelRequestParam, SubscribeRequest, SubscribeRequestParam,
15-
UnsubscribeRequest, UnsubscribeRequestParam,
14+
RequestId, RootsListChangedNotification, ServerInfo, ServerJsonRpcMessage, ServerNotification,
15+
ServerRequest, ServerResult, SetLevelRequest, SetLevelRequestParam, SubscribeRequest,
16+
SubscribeRequestParam, UnsubscribeRequest, UnsubscribeRequestParam,
1617
};
1718

19+
/// It represents the error that may occur when serving the client.
20+
///
21+
/// if you want to handle the error, you can use `serve_client_with_ct` or `serve_client` with `Result<RunningService<RoleClient, S>, ClientError>`
22+
#[derive(Error, Debug)]
23+
pub enum ClientError {
24+
#[error("expect initialized response, but received: {0:?}")]
25+
ExpectedInitResponse(Option<ServerJsonRpcMessage>),
26+
27+
#[error("expect initialized result, but received: {0:?}")]
28+
ExpectedInitResult(Option<ServerResult>),
29+
30+
#[error("conflict initialized response id: expected {0}, got {1}")]
31+
ConflictInitResponseId(RequestId, RequestId),
32+
33+
#[error("connection closed: {0}")]
34+
ConnectionClosed(String),
35+
36+
#[error("IO error: {0}")]
37+
Io(#[from] std::io::Error),
38+
}
39+
40+
/// Helper function to get the next message from the stream
41+
async fn expect_next_message<S>(
42+
stream: &mut S,
43+
context: &str,
44+
) -> Result<ServerJsonRpcMessage, ClientError>
45+
where
46+
S: Stream<Item = ServerJsonRpcMessage> + Unpin,
47+
{
48+
stream
49+
.next()
50+
.await
51+
.ok_or_else(|| ClientError::ConnectionClosed(context.to_string()))
52+
.map_err(|e| ClientError::Io(std::io::Error::new(std::io::ErrorKind::Other, e)))
53+
}
54+
55+
/// Helper function to expect a response from the stream
56+
async fn expect_response<S>(
57+
stream: &mut S,
58+
context: &str,
59+
) -> Result<(ServerResult, RequestId), ClientError>
60+
where
61+
S: Stream<Item = ServerJsonRpcMessage> + Unpin,
62+
{
63+
let msg = expect_next_message(stream, context).await?;
64+
65+
match msg {
66+
ServerJsonRpcMessage::Response(JsonRpcResponse { id, result, .. }) => Ok((result, id)),
67+
_ => Err(ClientError::ExpectedInitResponse(Some(msg))),
68+
}
69+
}
70+
1871
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
1972
pub struct RoleClient;
2073

@@ -74,6 +127,15 @@ where
74127
let mut sink = Box::pin(sink);
75128
let mut stream = Box::pin(stream);
76129
let id_provider = <Arc<AtomicU32RequestIdProvider>>::default();
130+
131+
// Convert ClientError to std::io::Error, then to E
132+
let handle_client_error = |e: ClientError| -> E {
133+
match e {
134+
ClientError::Io(io_err) => io_err.into(),
135+
other => std::io::Error::new(std::io::ErrorKind::Other, format!("{}", other)).into(),
136+
}
137+
};
138+
77139
// service
78140
let id = id_provider.next_request_id();
79141
let init_request = InitializeRequest {
@@ -85,34 +147,24 @@ where
85147
.into_json_rpc_message(),
86148
)
87149
.await?;
88-
let (response, response_id) = stream
89-
.next()
150+
151+
let (response, response_id) = expect_response(&mut stream, "initialize response")
90152
.await
91-
.ok_or(std::io::Error::new(
92-
std::io::ErrorKind::UnexpectedEof,
93-
"expect initialize response",
94-
))?
95-
.into_message()
96-
.into_result()
97-
.ok_or(std::io::Error::new(
98-
std::io::ErrorKind::InvalidData,
99-
"expect initialize result",
100-
))?;
153+
.map_err(handle_client_error)?;
154+
101155
if id != response_id {
102-
return Err(std::io::Error::new(
103-
std::io::ErrorKind::InvalidData,
104-
"conflict initialize response id",
105-
)
106-
.into());
156+
return Err(handle_client_error(ClientError::ConflictInitResponseId(
157+
id,
158+
response_id,
159+
)));
107160
}
108-
let response = response.map_err(std::io::Error::other)?;
161+
109162
let ServerResult::InitializeResult(initialize_result) = response else {
110-
return Err(std::io::Error::new(
111-
std::io::ErrorKind::InvalidData,
112-
"expect initialize result",
113-
)
114-
.into());
163+
return Err(handle_client_error(ClientError::ExpectedInitResult(Some(
164+
response,
165+
))));
115166
};
167+
116168
// send notification
117169
let notification = ClientMessage::Notification(ClientNotification::InitializedNotification(
118170
InitializedNotification {

0 commit comments

Comments
 (0)