1
- use futures:: { SinkExt , StreamExt } ;
1
+ use futures:: { SinkExt , Stream , StreamExt } ;
2
+ use thiserror:: Error ;
2
3
3
4
use super :: * ;
4
5
use crate :: model:: {
5
6
CallToolRequest , CallToolRequestParam , CallToolResult , CancelledNotification ,
6
7
CancelledNotificationParam , ClientInfo , ClientMessage , ClientNotification , ClientRequest ,
7
8
ClientResult , CompleteRequest , CompleteRequestParam , CompleteResult , GetPromptRequest ,
8
9
GetPromptRequestParam , GetPromptResult , InitializeRequest , InitializedNotification ,
9
- ListPromptsRequest , ListPromptsResult , ListResourceTemplatesRequest ,
10
+ JsonRpcResponse , ListPromptsRequest , ListPromptsResult , ListResourceTemplatesRequest ,
10
11
ListResourceTemplatesResult , ListResourcesRequest , ListResourcesResult , ListToolsRequest ,
11
12
ListToolsResult , PaginatedRequestParam , PaginatedRequestParamInner , ProgressNotification ,
12
13
ProgressNotificationParam , ReadResourceRequest , ReadResourceRequestParam , ReadResourceResult ,
13
- RootsListChangedNotification , ServerInfo , ServerNotification , ServerRequest , ServerResult ,
14
- SetLevelRequest , SetLevelRequestParam , SubscribeRequest , SubscribeRequestParam ,
15
- UnsubscribeRequest , UnsubscribeRequestParam ,
14
+ RequestId , RootsListChangedNotification , ServerInfo , ServerJsonRpcMessage , ServerNotification ,
15
+ ServerRequest , ServerResult , SetLevelRequest , SetLevelRequestParam , SubscribeRequest ,
16
+ SubscribeRequestParam , UnsubscribeRequest , UnsubscribeRequestParam ,
16
17
} ;
17
18
19
+ /// It represents the error that may occur when serving the client.
20
+ ///
21
+ /// if you want to handle the error, you can use `serve_client_with_ct` or `serve_client` with `Result<RunningService<RoleClient, S>, ClientError>`
22
+ #[ derive( Error , Debug ) ]
23
+ pub enum ClientError {
24
+ #[ error( "expect initialized response, but received: {0:?}" ) ]
25
+ ExpectedInitResponse ( Option < ServerJsonRpcMessage > ) ,
26
+
27
+ #[ error( "expect initialized result, but received: {0:?}" ) ]
28
+ ExpectedInitResult ( Option < ServerResult > ) ,
29
+
30
+ #[ error( "conflict initialized response id: expected {0}, got {1}" ) ]
31
+ ConflictInitResponseId ( RequestId , RequestId ) ,
32
+
33
+ #[ error( "connection closed: {0}" ) ]
34
+ ConnectionClosed ( String ) ,
35
+
36
+ #[ error( "IO error: {0}" ) ]
37
+ Io ( #[ from] std:: io:: Error ) ,
38
+ }
39
+
40
+ /// Helper function to get the next message from the stream
41
+ async fn expect_next_message < S > (
42
+ stream : & mut S ,
43
+ context : & str ,
44
+ ) -> Result < ServerJsonRpcMessage , ClientError >
45
+ where
46
+ S : Stream < Item = ServerJsonRpcMessage > + Unpin ,
47
+ {
48
+ stream
49
+ . next ( )
50
+ . await
51
+ . ok_or_else ( || ClientError :: ConnectionClosed ( context. to_string ( ) ) )
52
+ . map_err ( |e| ClientError :: Io ( std:: io:: Error :: new ( std:: io:: ErrorKind :: Other , e) ) )
53
+ }
54
+
55
+ /// Helper function to expect a response from the stream
56
+ async fn expect_response < S > (
57
+ stream : & mut S ,
58
+ context : & str ,
59
+ ) -> Result < ( ServerResult , RequestId ) , ClientError >
60
+ where
61
+ S : Stream < Item = ServerJsonRpcMessage > + Unpin ,
62
+ {
63
+ let msg = expect_next_message ( stream, context) . await ?;
64
+
65
+ match msg {
66
+ ServerJsonRpcMessage :: Response ( JsonRpcResponse { id, result, .. } ) => Ok ( ( result, id) ) ,
67
+ _ => Err ( ClientError :: ExpectedInitResponse ( Some ( msg) ) ) ,
68
+ }
69
+ }
70
+
18
71
#[ derive( Debug , Clone , Copy , Default , PartialEq , Eq ) ]
19
72
pub struct RoleClient ;
20
73
@@ -74,6 +127,15 @@ where
74
127
let mut sink = Box :: pin ( sink) ;
75
128
let mut stream = Box :: pin ( stream) ;
76
129
let id_provider = <Arc < AtomicU32RequestIdProvider > >:: default ( ) ;
130
+
131
+ // Convert ClientError to std::io::Error, then to E
132
+ let handle_client_error = |e : ClientError | -> E {
133
+ match e {
134
+ ClientError :: Io ( io_err) => io_err. into ( ) ,
135
+ other => std:: io:: Error :: new ( std:: io:: ErrorKind :: Other , format ! ( "{}" , other) ) . into ( ) ,
136
+ }
137
+ } ;
138
+
77
139
// service
78
140
let id = id_provider. next_request_id ( ) ;
79
141
let init_request = InitializeRequest {
@@ -85,34 +147,24 @@ where
85
147
. into_json_rpc_message ( ) ,
86
148
)
87
149
. await ?;
88
- let ( response , response_id ) = stream
89
- . next ( )
150
+
151
+ let ( response , response_id ) = expect_response ( & mut stream , "initialize response" )
90
152
. await
91
- . ok_or ( std:: io:: Error :: new (
92
- std:: io:: ErrorKind :: UnexpectedEof ,
93
- "expect initialize response" ,
94
- ) ) ?
95
- . into_message ( )
96
- . into_result ( )
97
- . ok_or ( std:: io:: Error :: new (
98
- std:: io:: ErrorKind :: InvalidData ,
99
- "expect initialize result" ,
100
- ) ) ?;
153
+ . map_err ( handle_client_error) ?;
154
+
101
155
if id != response_id {
102
- return Err ( std:: io:: Error :: new (
103
- std:: io:: ErrorKind :: InvalidData ,
104
- "conflict initialize response id" ,
105
- )
106
- . into ( ) ) ;
156
+ return Err ( handle_client_error ( ClientError :: ConflictInitResponseId (
157
+ id,
158
+ response_id,
159
+ ) ) ) ;
107
160
}
108
- let response = response . map_err ( std :: io :: Error :: other ) ? ;
161
+
109
162
let ServerResult :: InitializeResult ( initialize_result) = response else {
110
- return Err ( std:: io:: Error :: new (
111
- std:: io:: ErrorKind :: InvalidData ,
112
- "expect initialize result" ,
113
- )
114
- . into ( ) ) ;
163
+ return Err ( handle_client_error ( ClientError :: ExpectedInitResult ( Some (
164
+ response,
165
+ ) ) ) ) ;
115
166
} ;
167
+
116
168
// send notification
117
169
let notification = ClientMessage :: Notification ( ClientNotification :: InitializedNotification (
118
170
InitializedNotification {
0 commit comments