2
2
use std:: { pin:: Pin , sync:: Arc } ;
3
3
4
4
use futures:: { StreamExt , future:: BoxFuture } ;
5
+ use http:: Uri ;
5
6
use reqwest:: header:: HeaderValue ;
6
7
use sse_stream:: Error as SseError ;
7
8
use thiserror:: Error ;
@@ -32,6 +33,10 @@ pub enum SseTransportError<E: std::error::Error + Send + Sync + 'static> {
32
33
#[ cfg_attr( docsrs, doc( cfg( feature = "auth" ) ) ) ]
33
34
#[ error( "Auth error: {0}" ) ]
34
35
Auth ( #[ from] crate :: transport:: auth:: AuthError ) ,
36
+ #[ error( "Invalid uri: {0}" ) ]
37
+ InvalidUri ( #[ from] http:: uri:: InvalidUri ) ,
38
+ #[ error( "Invalid uri parts: {0}" ) ]
39
+ InvalidUriParts ( #[ from] http:: uri:: InvalidUriParts ) ,
35
40
}
36
41
37
42
impl From < reqwest:: Error > for SseTransportError < reqwest:: Error > {
@@ -44,21 +49,21 @@ pub trait SseClient: Clone + Send + Sync + 'static {
44
49
type Error : std:: error:: Error + Send + Sync + ' static ;
45
50
fn post_message (
46
51
& self ,
47
- uri : Arc < str > ,
52
+ uri : Uri ,
48
53
message : ClientJsonRpcMessage ,
49
54
auth_token : Option < String > ,
50
55
) -> impl Future < Output = Result < ( ) , SseTransportError < Self :: Error > > > + Send + ' _ ;
51
56
fn get_stream (
52
57
& self ,
53
- uri : Arc < str > ,
58
+ uri : Uri ,
54
59
last_event_id : Option < String > ,
55
60
auth_token : Option < String > ,
56
61
) -> impl Future < Output = Result < BoxedSseResponse , SseTransportError < Self :: Error > > > + Send + ' _ ;
57
62
}
58
63
59
64
struct SseClientReconnect < C > {
60
65
pub client : C ,
61
- pub uri : Arc < str > ,
66
+ pub uri : Uri ,
62
67
}
63
68
64
69
impl < C : SseClient > SseStreamReconnect for SseClientReconnect < C > {
@@ -75,7 +80,7 @@ type ServerMessageStream<C> = Pin<Box<SseAutoReconnectStream<SseClientReconnect<
75
80
pub struct SseClientTransport < C : SseClient > {
76
81
client : C ,
77
82
config : SseClientConfig ,
78
- post_uri : Arc < str > ,
83
+ message_endpoint : Uri ,
79
84
stream : Option < ServerMessageStream < C > > ,
80
85
}
81
86
@@ -89,7 +94,7 @@ impl<C: SseClient> Transport<RoleClient> for SseClientTransport<C> {
89
94
item : crate :: service:: TxJsonRpcMessage < RoleClient > ,
90
95
) -> impl Future < Output = Result < ( ) , Self :: Error > > + Send + ' static {
91
96
let client = self . client . clone ( ) ;
92
- let uri = self . post_uri . clone ( ) ;
97
+ let uri = self . message_endpoint . clone ( ) ;
93
98
async move { client. post_message ( uri, item, None ) . await }
94
99
}
95
100
async fn close ( & mut self ) -> Result < ( ) , Self :: Error > {
@@ -112,9 +117,11 @@ impl<C: SseClient> SseClientTransport<C> {
112
117
client : C ,
113
118
config : SseClientConfig ,
114
119
) -> Result < Self , SseTransportError < C :: Error > > {
115
- let mut sse_stream = client. get_stream ( config. uri . clone ( ) , None , None ) . await ?;
116
- let endpoint = if let Some ( endpoint) = config. use_endpoint . clone ( ) {
117
- endpoint
120
+ let sse_endpoint = config. sse_endpoint . as_ref ( ) . parse :: < http:: Uri > ( ) ?;
121
+
122
+ let mut sse_stream = client. get_stream ( sse_endpoint. clone ( ) , None , None ) . await ?;
123
+ let message_endpoint = if let Some ( endpoint) = config. use_message_endpoint . clone ( ) {
124
+ endpoint. parse :: < http:: Uri > ( ) ?
118
125
} else {
119
126
// wait the endpoint event
120
127
loop {
@@ -125,46 +132,59 @@ impl<C: SseClient> SseClientTransport<C> {
125
132
let Some ( "endpoint" ) = sse. event . as_deref ( ) else {
126
133
continue ;
127
134
} ;
128
- break sse. data . unwrap_or_default ( ) ;
135
+ let sse_endpoint = sse. data . unwrap_or_default ( ) ;
136
+ break sse_endpoint. parse :: < http:: Uri > ( ) ?;
129
137
}
130
138
} ;
131
- let post_uri: Arc < str > = format ! (
132
- "{}/{}" ,
133
- config. uri. trim_end_matches( "/" ) ,
134
- endpoint. trim_start_matches( "/" )
135
- )
136
- . into ( ) ;
139
+
140
+ // sse: <authority><sse_pq> -> <authority><message_pq>
141
+ let message_endpoint = {
142
+ let mut sse_endpoint_parts = sse_endpoint. clone ( ) . into_parts ( ) ;
143
+ sse_endpoint_parts. path_and_query = message_endpoint. into_parts ( ) . path_and_query ;
144
+ Uri :: from_parts ( sse_endpoint_parts) ?
145
+ } ;
137
146
let stream = Box :: pin ( SseAutoReconnectStream :: new (
138
147
sse_stream,
139
148
SseClientReconnect {
140
149
client : client. clone ( ) ,
141
- uri : config . uri . clone ( ) ,
150
+ uri : sse_endpoint . clone ( ) ,
142
151
} ,
143
152
config. retry_policy . clone ( ) ,
144
153
) ) ;
145
154
Ok ( Self {
146
155
client,
147
156
config,
148
- post_uri ,
157
+ message_endpoint ,
149
158
stream : Some ( stream) ,
150
159
} )
151
160
}
152
161
}
153
162
154
163
#[ derive( Debug , Clone ) ]
155
164
pub struct SseClientConfig {
156
- pub uri : Arc < str > ,
165
+ /// client sse endpoint
166
+ ///
167
+ /// # How this client resolve the message endpoint
168
+ /// if sse_endpoint has this format: `<schema><authority?><sse_pq>`,
169
+ /// then the message endpoint will be `<schema><authority?><message_pq>`.
170
+ ///
171
+ /// For example, if you config the sse_endpoint as `http://example.com/some_path/sse`,
172
+ /// and the server send the message endpoint event as `message?session_id=123`,
173
+ /// then the message endpoint will be `http://example.com/message`.
174
+ ///
175
+ /// This follow the rules of JavaScript's [`new URL(url, base)`](https://developer.mozilla.org/zh-CN/docs/Web/API/URL/URL)
176
+ pub sse_endpoint : Arc < str > ,
157
177
pub retry_policy : Arc < dyn SseRetryPolicy > ,
158
178
/// if this is settled, the client will use this endpoint to send message and skip get the endpoint event
159
- pub use_endpoint : Option < String > ,
179
+ pub use_message_endpoint : Option < String > ,
160
180
}
161
181
162
182
impl Default for SseClientConfig {
163
183
fn default ( ) -> Self {
164
184
Self {
165
- uri : "" . into ( ) ,
185
+ sse_endpoint : "" . into ( ) ,
166
186
retry_policy : Arc :: new ( super :: common:: client_side_sse:: FixedInterval :: default ( ) ) ,
167
- use_endpoint : None ,
187
+ use_message_endpoint : None ,
168
188
}
169
189
}
170
190
}
0 commit comments