Skip to content

Commit a3b42c4

Browse files
authored
fix: sse endpoint build follow js's new URL(url, base) (#197)
* fix: sse endpoint build follow js's `new URL(url, base)` * fix: fix feature denpencies
1 parent 752b438 commit a3b42c4

File tree

5 files changed

+57
-35
lines changed

5 files changed

+57
-35
lines changed

crates/rmcp/Cargo.toml

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,14 @@ reqwest = { version = "0.12", default-features = false, features = [
4040
"stream",
4141
], optional = true }
4242
sse-stream = { version = "0.1.4", optional = true }
43+
http = { version = "1", optional = true }
4344
url = { version = "2.4", optional = true }
4445

4546
# For tower compatibility
4647
tower-service = { version = "0.3", optional = true }
4748

4849
# for child process transport
49-
process-wrap = { version = "8.2", features = ["tokio1"], optional = true}
50+
process-wrap = { version = "8.2", features = ["tokio1"], optional = true }
5051

5152
# for ws transport
5253
# tokio-tungstenite ={ version = "0.26", optional = true }
@@ -75,18 +76,15 @@ reqwest-tls-no-provider = ["__reqwest", "reqwest?/rustls-tls-no-provider"]
7576

7677
axum = ["dep:axum"]
7778
# SSE client
78-
client-side-sse = ["dep:sse-stream"]
79+
client-side-sse = ["dep:sse-stream", "dep:http"]
7980

8081
transport-sse-client = ["client-side-sse", "transport-worker"]
8182

8283
transport-worker = ["dep:tokio-stream"]
8384

8485

8586
# Streamable HTTP client
86-
transport-streamable-http-client = [
87-
"client-side-sse",
88-
"transport-worker",
89-
]
87+
transport-streamable-http-client = ["client-side-sse", "transport-worker"]
9088

9189

9290
transport-async-rw = ["tokio/io-util", "tokio-util/codec"]
@@ -98,6 +96,7 @@ transport-child-process = [
9896
]
9997
transport-sse-server = [
10098
"transport-async-rw",
99+
"transport-worker",
101100
"axum",
102101
"dep:rand",
103102
"dep:tokio-stream",

crates/rmcp/src/transport/common/auth/sse_client.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use http::Uri;
2+
13
use crate::transport::{
24
auth::AuthClient,
35
sse_client::{SseClient, SseTransportError},
@@ -10,7 +12,7 @@ where
1012

1113
async fn post_message(
1214
&self,
13-
uri: std::sync::Arc<str>,
15+
uri: Uri,
1416
message: crate::model::ClientJsonRpcMessage,
1517
mut auth_token: Option<String>,
1618
) -> Result<(), SseTransportError<Self::Error>> {
@@ -25,7 +27,7 @@ where
2527

2628
async fn get_stream(
2729
&self,
28-
uri: std::sync::Arc<str>,
30+
uri: Uri,
2931
last_event_id: Option<String>,
3032
mut auth_token: Option<String>,
3133
) -> Result<

crates/rmcp/src/transport/common/reqwest/sse_client.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use std::sync::Arc;
22

33
use futures::StreamExt;
4+
use http::Uri;
45
use reqwest::header::ACCEPT;
56
use sse_stream::SseStream;
67

@@ -15,11 +16,11 @@ impl SseClient for reqwest::Client {
1516

1617
async fn post_message(
1718
&self,
18-
uri: std::sync::Arc<str>,
19+
uri: Uri,
1920
message: crate::model::ClientJsonRpcMessage,
2021
auth_token: Option<String>,
2122
) -> Result<(), SseTransportError<Self::Error>> {
22-
let mut request_builder = self.post(uri.as_ref()).json(&message);
23+
let mut request_builder = self.post(uri.to_string()).json(&message);
2324
if let Some(auth_header) = auth_token {
2425
request_builder = request_builder.bearer_auth(auth_header);
2526
}
@@ -33,15 +34,15 @@ impl SseClient for reqwest::Client {
3334

3435
async fn get_stream(
3536
&self,
36-
uri: std::sync::Arc<str>,
37+
uri: Uri,
3738
last_event_id: Option<String>,
3839
auth_token: Option<String>,
3940
) -> Result<
4041
crate::transport::common::client_side_sse::BoxedSseResponse,
4142
SseTransportError<Self::Error>,
4243
> {
4344
let mut request_builder = self
44-
.get(uri.as_ref())
45+
.get(uri.to_string())
4546
.header(ACCEPT, EVENT_STREAM_MIME_TYPE);
4647
if let Some(auth_header) = auth_token {
4748
request_builder = request_builder.bearer_auth(auth_header);
@@ -73,7 +74,7 @@ impl SseClientTransport<reqwest::Client> {
7374
SseClientTransport::start_with_client(
7475
reqwest::Client::default(),
7576
SseClientConfig {
76-
uri: uri.into(),
77+
sse_endpoint: uri.into(),
7778
..Default::default()
7879
},
7980
)

crates/rmcp/src/transport/sse_client.rs

Lines changed: 41 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
use std::{pin::Pin, sync::Arc};
33

44
use futures::{StreamExt, future::BoxFuture};
5+
use http::Uri;
56
use reqwest::header::HeaderValue;
67
use sse_stream::Error as SseError;
78
use thiserror::Error;
@@ -32,6 +33,10 @@ pub enum SseTransportError<E: std::error::Error + Send + Sync + 'static> {
3233
#[cfg_attr(docsrs, doc(cfg(feature = "auth")))]
3334
#[error("Auth error: {0}")]
3435
Auth(#[from] crate::transport::auth::AuthError),
36+
#[error("Invalid uri: {0}")]
37+
InvalidUri(#[from] http::uri::InvalidUri),
38+
#[error("Invalid uri parts: {0}")]
39+
InvalidUriParts(#[from] http::uri::InvalidUriParts),
3540
}
3641

3742
impl From<reqwest::Error> for SseTransportError<reqwest::Error> {
@@ -44,21 +49,21 @@ pub trait SseClient: Clone + Send + Sync + 'static {
4449
type Error: std::error::Error + Send + Sync + 'static;
4550
fn post_message(
4651
&self,
47-
uri: Arc<str>,
52+
uri: Uri,
4853
message: ClientJsonRpcMessage,
4954
auth_token: Option<String>,
5055
) -> impl Future<Output = Result<(), SseTransportError<Self::Error>>> + Send + '_;
5156
fn get_stream(
5257
&self,
53-
uri: Arc<str>,
58+
uri: Uri,
5459
last_event_id: Option<String>,
5560
auth_token: Option<String>,
5661
) -> impl Future<Output = Result<BoxedSseResponse, SseTransportError<Self::Error>>> + Send + '_;
5762
}
5863

5964
struct SseClientReconnect<C> {
6065
pub client: C,
61-
pub uri: Arc<str>,
66+
pub uri: Uri,
6267
}
6368

6469
impl<C: SseClient> SseStreamReconnect for SseClientReconnect<C> {
@@ -75,7 +80,7 @@ type ServerMessageStream<C> = Pin<Box<SseAutoReconnectStream<SseClientReconnect<
7580
pub struct SseClientTransport<C: SseClient> {
7681
client: C,
7782
config: SseClientConfig,
78-
post_uri: Arc<str>,
83+
message_endpoint: Uri,
7984
stream: Option<ServerMessageStream<C>>,
8085
}
8186

@@ -89,7 +94,7 @@ impl<C: SseClient> Transport<RoleClient> for SseClientTransport<C> {
8994
item: crate::service::TxJsonRpcMessage<RoleClient>,
9095
) -> impl Future<Output = Result<(), Self::Error>> + Send + 'static {
9196
let client = self.client.clone();
92-
let uri = self.post_uri.clone();
97+
let uri = self.message_endpoint.clone();
9398
async move { client.post_message(uri, item, None).await }
9499
}
95100
async fn close(&mut self) -> Result<(), Self::Error> {
@@ -112,9 +117,11 @@ impl<C: SseClient> SseClientTransport<C> {
112117
client: C,
113118
config: SseClientConfig,
114119
) -> Result<Self, SseTransportError<C::Error>> {
115-
let mut sse_stream = client.get_stream(config.uri.clone(), None, None).await?;
116-
let endpoint = if let Some(endpoint) = config.use_endpoint.clone() {
117-
endpoint
120+
let sse_endpoint = config.sse_endpoint.as_ref().parse::<http::Uri>()?;
121+
122+
let mut sse_stream = client.get_stream(sse_endpoint.clone(), None, None).await?;
123+
let message_endpoint = if let Some(endpoint) = config.use_message_endpoint.clone() {
124+
endpoint.parse::<http::Uri>()?
118125
} else {
119126
// wait the endpoint event
120127
loop {
@@ -125,46 +132,59 @@ impl<C: SseClient> SseClientTransport<C> {
125132
let Some("endpoint") = sse.event.as_deref() else {
126133
continue;
127134
};
128-
break sse.data.unwrap_or_default();
135+
let sse_endpoint = sse.data.unwrap_or_default();
136+
break sse_endpoint.parse::<http::Uri>()?;
129137
}
130138
};
131-
let post_uri: Arc<str> = format!(
132-
"{}/{}",
133-
config.uri.trim_end_matches("/"),
134-
endpoint.trim_start_matches("/")
135-
)
136-
.into();
139+
140+
// sse: <authority><sse_pq> -> <authority><message_pq>
141+
let message_endpoint = {
142+
let mut sse_endpoint_parts = sse_endpoint.clone().into_parts();
143+
sse_endpoint_parts.path_and_query = message_endpoint.into_parts().path_and_query;
144+
Uri::from_parts(sse_endpoint_parts)?
145+
};
137146
let stream = Box::pin(SseAutoReconnectStream::new(
138147
sse_stream,
139148
SseClientReconnect {
140149
client: client.clone(),
141-
uri: config.uri.clone(),
150+
uri: sse_endpoint.clone(),
142151
},
143152
config.retry_policy.clone(),
144153
));
145154
Ok(Self {
146155
client,
147156
config,
148-
post_uri,
157+
message_endpoint,
149158
stream: Some(stream),
150159
})
151160
}
152161
}
153162

154163
#[derive(Debug, Clone)]
155164
pub struct SseClientConfig {
156-
pub uri: Arc<str>,
165+
/// client sse endpoint
166+
///
167+
/// # How this client resolve the message endpoint
168+
/// if sse_endpoint has this format: `<schema><authority?><sse_pq>`,
169+
/// then the message endpoint will be `<schema><authority?><message_pq>`.
170+
///
171+
/// For example, if you config the sse_endpoint as `http://example.com/some_path/sse`,
172+
/// and the server send the message endpoint event as `message?session_id=123`,
173+
/// then the message endpoint will be `http://example.com/message`.
174+
///
175+
/// This follow the rules of JavaScript's [`new URL(url, base)`](https://developer.mozilla.org/zh-CN/docs/Web/API/URL/URL)
176+
pub sse_endpoint: Arc<str>,
157177
pub retry_policy: Arc<dyn SseRetryPolicy>,
158178
/// if this is settled, the client will use this endpoint to send message and skip get the endpoint event
159-
pub use_endpoint: Option<String>,
179+
pub use_message_endpoint: Option<String>,
160180
}
161181

162182
impl Default for SseClientConfig {
163183
fn default() -> Self {
164184
Self {
165-
uri: "".into(),
185+
sse_endpoint: "".into(),
166186
retry_policy: Arc::new(super::common::client_side_sse::FixedInterval::default()),
167-
use_endpoint: None,
187+
use_message_endpoint: None,
168188
}
169189
}
170190
}

examples/clients/src/oauth_client.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ async fn main() -> Result<()> {
147147
let transport = SseClientTransport::start_with_client(
148148
client,
149149
SseClientConfig {
150-
uri: MCP_SSE_URL.into(),
150+
sse_endpoint: MCP_SSE_URL.into(),
151151
..Default::default()
152152
},
153153
)

0 commit comments

Comments
 (0)