Skip to content

Commit 076dc2c

Browse files
authored
feat: better http server support (#199)
* feat: better http server support 1. allow user get extensions in tool call. 2. allow user serve sse service without initialization. * fix: fix document test
1 parent c1c4c9a commit 076dc2c

File tree

13 files changed

+95
-96
lines changed

13 files changed

+95
-96
lines changed

crates/rmcp/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ tracing-subscriber = { version = "0.3", features = [
129129
async-trait = "0.1"
130130
[[test]]
131131
name = "test_tool_macros"
132-
required-features = ["server"]
132+
required-features = ["server", "client"]
133133
path = "tests/test_tool_macros.rs"
134134

135135
[[test]]

crates/rmcp/src/handler/client.rs

Lines changed: 4 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use crate::{
22
error::Error as McpError,
33
model::*,
4-
service::{Peer, RequestContext, RoleClient, Service, ServiceRole},
4+
service::{RequestContext, RoleClient, Service, ServiceRole},
55
};
66

77
impl<H: ClientHandler> Service<RoleClient> for H {
@@ -118,47 +118,16 @@ pub trait ClientHandler: Sized + Send + Sync + 'static {
118118
std::future::ready(())
119119
}
120120

121-
fn get_peer(&self) -> Option<Peer<RoleClient>>;
122-
123-
fn set_peer(&mut self, peer: Peer<RoleClient>);
124-
125121
fn get_info(&self) -> ClientInfo {
126122
ClientInfo::default()
127123
}
128124
}
129125

130-
/// Do nothing, just store the peer.
131-
impl ClientHandler for Option<Peer<RoleClient>> {
132-
fn get_peer(&self) -> Option<Peer<RoleClient>> {
133-
self.clone()
134-
}
135-
136-
fn set_peer(&mut self, peer: Peer<RoleClient>) {
137-
*self = Some(peer);
138-
}
139-
}
140-
141-
/// Do nothing, even store the peer.
142-
impl ClientHandler for () {
143-
fn get_peer(&self) -> Option<Peer<RoleClient>> {
144-
None
145-
}
146-
147-
fn set_peer(&mut self, peer: Peer<RoleClient>) {
148-
drop(peer);
149-
}
150-
}
126+
/// Do nothing, with default client info.
127+
impl ClientHandler for () {}
151128

152-
/// Do nothing, even store the peer.
129+
/// Do nothing, with a specific client info.
153130
impl ClientHandler for ClientInfo {
154-
fn get_peer(&self) -> Option<Peer<RoleClient>> {
155-
None
156-
}
157-
158-
fn set_peer(&mut self, peer: Peer<RoleClient>) {
159-
drop(peer);
160-
}
161-
162131
fn get_info(&self) -> ClientInfo {
163132
self.clone()
164133
}

crates/rmcp/src/handler/server.rs

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use crate::{
22
error::Error as McpError,
33
model::*,
4-
service::{Peer, RequestContext, RoleServer, Service, ServiceRole},
4+
service::{RequestContext, RoleServer, Service, ServiceRole},
55
};
66

77
mod resource;
@@ -108,6 +108,10 @@ pub trait ServerHandler: Sized + Send + Sync + 'static {
108108
request: InitializeRequestParam,
109109
context: RequestContext<RoleServer>,
110110
) -> impl Future<Output = Result<InitializeResult, McpError>> + Send + '_ {
111+
if context.peer.peer_info().is_none() {
112+
context.peer.set_peer_info(request);
113+
}
114+
let info = self.get_info();
111115
std::future::ready(Ok(self.get_info()))
112116
}
113117
fn complete(
@@ -210,14 +214,6 @@ pub trait ServerHandler: Sized + Send + Sync + 'static {
210214
std::future::ready(())
211215
}
212216

213-
fn get_peer(&self) -> Option<Peer<RoleServer>> {
214-
None
215-
}
216-
217-
fn set_peer(&mut self, peer: Peer<RoleServer>) {
218-
drop(peer);
219-
}
220-
221217
fn get_info(&self) -> ServerInfo {
222218
ServerInfo::default()
223219
}

crates/rmcp/src/handler/server/tool.rs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,9 @@ impl<'service, S> ToolCallContext<'service, S> {
8686
pub fn name(&self) -> &str {
8787
&self.name
8888
}
89+
pub fn request_context(&self) -> &RequestContext<RoleServer> {
90+
&self.request_context
91+
}
8992
}
9093

9194
pub trait FromToolCallContextPart<'a, S>: Sized {
@@ -284,6 +287,39 @@ impl<'a, S> FromToolCallContextPart<'a, S> for JsonObject {
284287
}
285288
}
286289

290+
impl<'a, S> FromToolCallContextPart<'a, S> for crate::model::Extensions {
291+
fn from_tool_call_context_part(
292+
context: ToolCallContext<'a, S>,
293+
) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> {
294+
let extensions = context.request_context.extensions.clone();
295+
Ok((extensions, context))
296+
}
297+
}
298+
299+
pub struct Extension<T>(pub T);
300+
301+
impl<'a, S, T> FromToolCallContextPart<'a, S> for Extension<T>
302+
where
303+
T: Send + Sync + 'static + Clone,
304+
{
305+
fn from_tool_call_context_part(
306+
context: ToolCallContext<'a, S>,
307+
) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> {
308+
let extension = context
309+
.request_context
310+
.extensions
311+
.get::<T>()
312+
.cloned()
313+
.ok_or_else(|| {
314+
crate::Error::invalid_params(
315+
format!("missing extension {}", std::any::type_name::<T>()),
316+
None,
317+
)
318+
})?;
319+
Ok((Extension(extension), context))
320+
}
321+
}
322+
287323
impl<'s, S> ToolCallContext<'s, S> {
288324
pub fn invoke<H, A>(self, h: H) -> H::Fut
289325
where

crates/rmcp/src/service.rs

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ pub struct Peer<R: ServiceRole> {
303303
tx: mpsc::Sender<PeerSinkMessage<R>>,
304304
request_id_provider: Arc<dyn RequestIdProvider>,
305305
progress_token_provider: Arc<dyn ProgressTokenProvider>,
306-
info: Arc<R::PeerInfo>,
306+
info: Arc<tokio::sync::OnceCell<R::PeerInfo>>,
307307
}
308308

309309
impl<R: ServiceRole> std::fmt::Debug for Peer<R> {
@@ -333,15 +333,15 @@ impl<R: ServiceRole> Peer<R> {
333333
const CLIENT_CHANNEL_BUFFER_SIZE: usize = 1024;
334334
pub(crate) fn new(
335335
request_id_provider: Arc<dyn RequestIdProvider>,
336-
peer_info: R::PeerInfo,
336+
peer_info: Option<R::PeerInfo>,
337337
) -> (Peer<R>, ProxyOutbound<R>) {
338338
let (tx, rx) = mpsc::channel(Self::CLIENT_CHANNEL_BUFFER_SIZE);
339339
(
340340
Self {
341341
tx,
342342
request_id_provider,
343343
progress_token_provider: Arc::new(AtomicU32ProgressTokenProvider::default()),
344-
info: peer_info.into(),
344+
info: Arc::new(tokio::sync::OnceCell::new_with(peer_info)),
345345
},
346346
rx,
347347
)
@@ -402,8 +402,16 @@ impl<R: ServiceRole> Peer<R> {
402402
peer: self.clone(),
403403
})
404404
}
405-
pub fn peer_info(&self) -> &R::PeerInfo {
406-
&self.info
405+
pub fn peer_info(&self) -> Option<&R::PeerInfo> {
406+
self.info.get()
407+
}
408+
409+
pub fn set_peer_info(&self, info: R::PeerInfo) {
410+
if self.info.initialized() {
411+
tracing::warn!("trying to set peer info, which is already initialized");
412+
} else {
413+
let _ = self.info.set(info);
414+
}
407415
}
408416

409417
pub fn is_transport_closed(&self) -> bool {
@@ -469,7 +477,7 @@ pub struct RequestContext<R: ServiceRole> {
469477
pub async fn serve_directly<R, S, T, E, A>(
470478
service: S,
471479
transport: T,
472-
peer_info: R::PeerInfo,
480+
peer_info: Option<R::PeerInfo>,
473481
) -> RunningService<R, S>
474482
where
475483
R: ServiceRole,
@@ -484,7 +492,7 @@ where
484492
pub async fn serve_directly_with_ct<R, S, T, E, A>(
485493
service: S,
486494
transport: T,
487-
peer_info: R::PeerInfo,
495+
peer_info: Option<R::PeerInfo>,
488496
ct: CancellationToken,
489497
) -> RunningService<R, S>
490498
where

crates/rmcp/src/service/client.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ where
174174
error,
175175
context: "send initialized notification".into(),
176176
})?;
177-
let (peer, peer_rx) = Peer::new(id_provider, initialize_result);
177+
let (peer, peer_rx) = Peer::new(id_provider, Some(initialize_result));
178178
Ok(serve_inner(service, transport, peer, peer_rx, ct).await)
179179
}
180180

crates/rmcp/src/service/server.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ where
156156
ClientJsonRpcMessage::request(request, id),
157157
)));
158158
};
159-
let (peer, peer_rx) = Peer::new(id_provider, peer_info.params.clone());
159+
let (peer, peer_rx) = Peer::new(id_provider, Some(peer_info.params.clone()));
160160
let context = RequestContext {
161161
ct: ct.child_token(),
162162
id: id.clone(),

crates/rmcp/src/transport.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
//!
5858
//! // create transport from std io
5959
//! async fn io() -> Result<(), Box<dyn std::error::Error>> {
60-
//! let client = None.serve((tokio::io::stdin(), tokio::io::stdout())).await?;
60+
//! let client = ().serve((tokio::io::stdin(), tokio::io::stdout())).await?;
6161
//! let tools = client.peer().list_tools(Default::default()).await?;
6262
//! println!("{:?}", tools);
6363
//! Ok(())

crates/rmcp/src/transport/sse_server.rs

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use tracing::Instrument;
1818
use crate::{
1919
RoleServer, Service,
2020
model::ClientJsonRpcMessage,
21-
service::{RxJsonRpcMessage, TxJsonRpcMessage},
21+
service::{RxJsonRpcMessage, TxJsonRpcMessage, serve_directly_with_ct},
2222
transport::common::axum::{DEFAULT_AUTO_PING_INTERVAL, SessionId, session_id},
2323
};
2424

@@ -67,7 +67,7 @@ async fn post_event_handler(
6767
parts: Parts,
6868
Json(mut message): Json<ClientJsonRpcMessage>,
6969
) -> Result<StatusCode, StatusCode> {
70-
tracing::debug!(session_id, ?message, "new client message");
70+
tracing::debug!(session_id, ?parts, ?message, "new client message");
7171
let tx = {
7272
let rg = app.txs.read().await;
7373
rg.get(session_id.as_str())
@@ -84,9 +84,10 @@ async fn post_event_handler(
8484

8585
async fn sse_handler(
8686
State(app): State<App>,
87+
parts: Parts,
8788
) -> Result<Sse<impl Stream<Item = Result<Event, io::Error>>>, Response<String>> {
8889
let session = session_id();
89-
tracing::info!(%session, "sse connection");
90+
tracing::info!(%session, ?parts, "sse connection");
9091
use tokio_stream::{StreamExt, wrappers::ReceiverStream};
9192
use tokio_util::sync::PollSender;
9293
let (from_client_tx, from_client_rx) = tokio::sync::mpsc::channel(64);
@@ -300,6 +301,27 @@ impl SseServer {
300301
ct
301302
}
302303

304+
/// This allows you to skip the initialization steps for incoming request.
305+
pub fn with_service_directly<S, F>(mut self, service_provider: F) -> CancellationToken
306+
where
307+
S: Service<RoleServer>,
308+
F: Fn() -> S + Send + 'static,
309+
{
310+
let ct = self.config.ct.clone();
311+
tokio::spawn(async move {
312+
while let Some(transport) = self.next_transport().await {
313+
let service = service_provider();
314+
let ct = self.config.ct.child_token();
315+
tokio::spawn(async move {
316+
let server = serve_directly_with_ct(service, transport, None, ct).await;
317+
server.waiting().await?;
318+
tokio::io::Result::Ok(())
319+
});
320+
}
321+
});
322+
ct
323+
}
324+
303325
pub fn cancel(&self) {
304326
self.config.ct.cancel();
305327
}

crates/rmcp/tests/common/handlers.rs

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,14 @@ use std::{
44
};
55

66
use rmcp::{
7-
ClientHandler, Error as McpError, RoleClient, RoleServer, ServerHandler,
8-
model::*,
9-
service::{Peer, RequestContext},
7+
ClientHandler, Error as McpError, RoleClient, RoleServer, ServerHandler, model::*,
8+
service::RequestContext,
109
};
1110
use serde_json::json;
1211
use tokio::sync::Notify;
1312

1413
#[derive(Clone)]
1514
pub struct TestClientHandler {
16-
pub peer: Option<Peer<RoleClient>>,
1715
pub honor_this_server: bool,
1816
pub honor_all_servers: bool,
1917
pub receive_signal: Arc<Notify>,
@@ -24,7 +22,6 @@ impl TestClientHandler {
2422
#[allow(dead_code)]
2523
pub fn new(honor_this_server: bool, honor_all_servers: bool) -> Self {
2624
Self {
27-
peer: None,
2825
honor_this_server,
2926
honor_all_servers,
3027
receive_signal: Arc::new(Notify::new()),
@@ -40,7 +37,6 @@ impl TestClientHandler {
4037
received_messages: Arc<Mutex<Vec<LoggingMessageNotificationParam>>>,
4138
) -> Self {
4239
Self {
43-
peer: None,
4440
honor_this_server,
4541
honor_all_servers,
4642
receive_signal,
@@ -50,14 +46,6 @@ impl TestClientHandler {
5046
}
5147

5248
impl ClientHandler for TestClientHandler {
53-
fn get_peer(&self) -> Option<Peer<RoleClient>> {
54-
self.peer.clone()
55-
}
56-
57-
fn set_peer(&mut self, peer: Peer<RoleClient>) {
58-
self.peer = Some(peer);
59-
}
60-
6149
async fn create_message(
6250
&self,
6351
params: CreateMessageRequestParam,

crates/rmcp/tests/test_notification.rs

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use std::sync::Arc;
22

33
use rmcp::{
4-
ClientHandler, Peer, RoleClient, ServerHandler, ServiceExt,
4+
ClientHandler, ServerHandler, ServiceExt,
55
model::{
66
ResourceUpdatedNotificationParam, ServerCapabilities, ServerInfo, SubscribeRequestParam,
77
},
@@ -49,7 +49,6 @@ impl ServerHandler for Server {
4949

5050
pub struct Client {
5151
receive_signal: Arc<Notify>,
52-
peer: Option<Peer<RoleClient>>,
5352
}
5453

5554
impl ClientHandler for Client {
@@ -58,14 +57,6 @@ impl ClientHandler for Client {
5857
tracing::info!("Resource updated: {}", uri);
5958
self.receive_signal.notify_one();
6059
}
61-
62-
fn set_peer(&mut self, peer: Peer<RoleClient>) {
63-
self.peer.replace(peer);
64-
}
65-
66-
fn get_peer(&self) -> Option<Peer<RoleClient>> {
67-
self.peer.clone()
68-
}
6960
}
7061

7162
#[tokio::test]
@@ -85,7 +76,6 @@ async fn test_server_notification() -> anyhow::Result<()> {
8576
});
8677
let receive_signal = Arc::new(Notify::new());
8778
let client = Client {
88-
peer: Default::default(),
8979
receive_signal: receive_signal.clone(),
9080
}
9181
.serve(client_transport)

0 commit comments

Comments
 (0)