From ef99bafe359fbd0425f57001455325abcfe3b9c1 Mon Sep 17 00:00:00 2001 From: 4t145 Date: Tue, 20 May 2025 05:24:28 +0800 Subject: [PATCH 1/3] remove lifetime marker --- crates/rmcp/src/handler/server/tool.rs | 143 +++++++++++-------------- 1 file changed, 63 insertions(+), 80 deletions(-) diff --git a/crates/rmcp/src/handler/server/tool.rs b/crates/rmcp/src/handler/server/tool.rs index 30d8872..aa89c2c 100644 --- a/crates/rmcp/src/handler/server/tool.rs +++ b/crates/rmcp/src/handler/server/tool.rs @@ -63,16 +63,16 @@ pub fn parse_json_object(input: JsonObject) -> Result { +pub struct ToolCallContext { request_context: RequestContext, - service: &'service S, + service: Arc, name: Cow<'static, str>, arguments: Option, } -impl<'service, S> ToolCallContext<'service, S> { +impl ToolCallContext { pub fn new( - service: &'service S, + service: Arc, CallToolRequestParam { name, arguments }: CallToolRequestParam, request_context: RequestContext, ) -> Self { @@ -91,10 +91,8 @@ impl<'service, S> ToolCallContext<'service, S> { } } -pub trait FromToolCallContextPart<'a, S>: Sized { - fn from_tool_call_context_part( - context: ToolCallContext<'a, S>, - ) -> Result<(Self, ToolCallContext<'a, S>), crate::Error>; +pub trait FromToolCallContextPart: Sized { + fn from_tool_call_context_part(context: &mut ToolCallContext) -> Result; } pub trait IntoCallToolResult { @@ -162,12 +160,12 @@ impl IntoCallToolResult for Result { } } -pub trait CallToolHandler<'a, S, A> { - type Fut: Future> + Send + 'a; - fn call(self, context: ToolCallContext<'a, S>) -> Self::Fut; +pub trait CallToolHandler { + type Fut: Future> + Send; + fn call(self, context: ToolCallContext) -> Self::Fut; } -pub type DynCallToolHandler = dyn Fn(ToolCallContext<'_, S>) -> BoxFuture<'_, Result> +pub type DynCallToolHandler = dyn Fn(ToolCallContext) -> BoxFuture<'static, Result> + Send + Sync; /// Parameter Extractor @@ -189,51 +187,32 @@ impl JsonSchema for Parameters

{ } } -/// Callee Extractor -pub struct Callee<'a, S>(pub &'a S); - -impl<'a, S> FromToolCallContextPart<'a, S> for CancellationToken { - fn from_tool_call_context_part( - context: ToolCallContext<'a, S>, - ) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> { - Ok((context.request_context.ct.clone(), context)) - } -} - -impl<'a, S> FromToolCallContextPart<'a, S> for Callee<'a, S> { +impl FromToolCallContextPart for CancellationToken { fn from_tool_call_context_part( - context: ToolCallContext<'a, S>, - ) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> { - Ok((Callee(context.service), context)) + context: &mut ToolCallContext, + ) -> Result { + Ok(context.request_context.ct.clone()) } } pub struct ToolName(pub Cow<'static, str>); -impl<'a, S> FromToolCallContextPart<'a, S> for ToolName { +impl FromToolCallContextPart for ToolName { fn from_tool_call_context_part( - context: ToolCallContext<'a, S>, - ) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> { - Ok((Self(context.name.clone()), context)) + context: &mut ToolCallContext, + ) -> Result { + Ok(Self(context.name.clone())) } } -impl<'a, S> FromToolCallContextPart<'a, S> for &'a S { - fn from_tool_call_context_part( - context: ToolCallContext<'a, S>, - ) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> { - Ok((context.service, context)) - } -} - -impl<'a, S, K, V> FromToolCallContextPart<'a, S> for Parameter +impl FromToolCallContextPart for Parameter where K: ConstString, V: DeserializeOwned, { fn from_tool_call_context_part( - context: ToolCallContext<'a, S>, - ) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> { + context: &mut ToolCallContext, + ) -> Result { let arguments = context .arguments .as_ref() @@ -255,17 +234,17 @@ where None, ) })?; - Ok((Parameter(K::default(), value), context)) + Ok(Parameter(K::default(), value)) } } -impl<'a, S, P> FromToolCallContextPart<'a, S> for Parameters

+impl FromToolCallContextPart for Parameters

where P: DeserializeOwned, { fn from_tool_call_context_part( - mut context: ToolCallContext<'a, S>, - ) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> { + context: &mut ToolCallContext, + ) -> Result { let arguments = context.arguments.take().unwrap_or_default(); let value: P = serde_json::from_value(serde_json::Value::Object(arguments)).map_err(|e| { @@ -274,37 +253,37 @@ where None, ) })?; - Ok((Parameters(value), context)) + Ok(Parameters(value)) } } -impl<'a, S> FromToolCallContextPart<'a, S> for JsonObject { +impl FromToolCallContextPart for JsonObject { fn from_tool_call_context_part( - mut context: ToolCallContext<'a, S>, - ) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> { + context: &mut ToolCallContext, + ) -> Result { let object = context.arguments.take().unwrap_or_default(); - Ok((object, context)) + Ok(object) } } -impl<'a, S> FromToolCallContextPart<'a, S> for crate::model::Extensions { +impl FromToolCallContextPart for crate::model::Extensions { fn from_tool_call_context_part( - context: ToolCallContext<'a, S>, - ) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> { + context: &mut ToolCallContext, + ) -> Result { let extensions = context.request_context.extensions.clone(); - Ok((extensions, context)) + Ok(extensions) } } pub struct Extension(pub T); -impl<'a, S, T> FromToolCallContextPart<'a, S> for Extension +impl FromToolCallContextPart for Extension where T: Send + Sync + 'static + Clone, { fn from_tool_call_context_part( - context: ToolCallContext<'a, S>, - ) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> { + context: &mut ToolCallContext, + ) -> Result { let extension = context .request_context .extensions @@ -316,14 +295,14 @@ where None, ) })?; - Ok((Extension(extension), context)) + Ok(Extension(extension)) } } -impl<'s, S> ToolCallContext<'s, S> { +impl< S> ToolCallContext { pub fn invoke(self, h: H) -> H::Fut where - H: CallToolHandler<'s, S, A>, + H: CallToolHandler, { h.call(self) } @@ -333,6 +312,10 @@ impl<'s, S> ToolCallContext<'s, S> { pub struct AsyncAdapter(PhantomData<(fn(P) -> Fut, fn(Fut) -> R)>); pub struct SyncAdapter(PhantomData R>); +#[allow(clippy::type_complexity)] +pub struct AsyncMethodAdapter(PhantomData<(fn(P) -> Fut, fn(Fut) -> R)>); +pub struct SyncMethodAdapter(PhantomData R>); + macro_rules! impl_for { ($($T: ident)*) => { impl_for!([] [$($T)*]); @@ -346,26 +329,26 @@ macro_rules! impl_for { impl_for!([$($Tn)* $Tn_1] [$($Rest)*]); }; (@impl $($Tn: ident)*) => { - impl<'s, $($Tn,)* S, F, Fut, R> CallToolHandler<'s, S, AsyncAdapter<($($Tn,)*), Fut, R>> for F + impl<$($Tn,)* S, F, Fut, R> CallToolHandler> for F where $( - $Tn: FromToolCallContextPart<'s, S> + 's, + $Tn: FromToolCallContextPart , )* - F: FnOnce($($Tn,)*) -> Fut + Send + 's, - Fut: Future + Send + 's, - R: IntoCallToolResult + Send + 's, + F: FnOnce($($Tn,)*) -> Fut + Send + , + Fut: Future + Send + , + R: IntoCallToolResult + Send + , S: Send + Sync, { type Fut = IntoCallToolResultFut; #[allow(unused_variables, non_snake_case)] fn call( self, - context: ToolCallContext<'s, S>, + mut context: ToolCallContext, ) -> Self::Fut { $( - let result = $Tn::from_tool_call_context_part(context); - let ($Tn, context) = match result { - Ok((value, context)) => (value, context), + let result = $Tn::from_tool_call_context_part(&mut context); + let $Tn = match result { + Ok(value) => value, Err(e) => return IntoCallToolResultFut::Ready { result: std::future::ready(Err(e)), }, @@ -378,25 +361,25 @@ macro_rules! impl_for { } } - impl<'s, $($Tn,)* S, F, R> CallToolHandler<'s, S, SyncAdapter<($($Tn,)*), R>> for F + impl<$($Tn,)* S, F, R> CallToolHandler> for F where $( - $Tn: FromToolCallContextPart<'s, S> + 's, + $Tn: FromToolCallContextPart + , )* - F: FnOnce($($Tn,)*) -> R + Send + 's, - R: IntoCallToolResult + Send + 's, + F: FnOnce($($Tn,)*) -> R + Send + , + R: IntoCallToolResult + Send + , S: Send + Sync, { type Fut = Ready>; #[allow(unused_variables, non_snake_case)] fn call( self, - context: ToolCallContext<'s, S>, + mut context: ToolCallContext, ) -> Self::Fut { $( - let result = $Tn::from_tool_call_context_part(context); - let ($Tn, context) = match result { - Ok((value, context)) => (value, context), + let result = $Tn::from_tool_call_context_part(&mut context); + let $Tn = match result { + Ok(value) => value, Err(e) => return std::future::ready(Err(e)), }; )* @@ -415,7 +398,7 @@ pub struct ToolBoxItem { impl ToolBoxItem { pub fn new(attr: crate::model::Tool, call: C) -> Self where - C: Fn(ToolCallContext<'_, S>) -> BoxFuture<'_, Result> + C: Fn(ToolCallContext) -> BoxFuture<'static, Result> + Send + Sync + 'static, @@ -452,7 +435,7 @@ impl ToolBox { pub async fn call( &self, - context: ToolCallContext<'_, S>, + context: ToolCallContext, ) -> Result { let item = self .map From dd4d2d6dca0f8a4549d6bddf148a1bfc95e02dd1 Mon Sep 17 00:00:00 2001 From: 4t145 Date: Tue, 20 May 2025 16:41:03 +0800 Subject: [PATCH 2/3] draft: axum style router --- crates/rmcp-macros/src/tool.rs | 22 +- crates/rmcp/src/handler/server.rs | 1 + crates/rmcp/src/handler/server/router.rs | 94 +++++++ crates/rmcp/src/handler/server/router/tool.rs | 212 ++++++++++++++ crates/rmcp/src/handler/server/tool.rs | 261 +++++++++++------- crates/rmcp/tests/test_tool_routers.rs | 98 +++++++ 6 files changed, 573 insertions(+), 115 deletions(-) create mode 100644 crates/rmcp/src/handler/server/router.rs create mode 100644 crates/rmcp/src/handler/server/router/tool.rs create mode 100644 crates/rmcp/tests/test_tool_routers.rs diff --git a/crates/rmcp-macros/src/tool.rs b/crates/rmcp-macros/src/tool.rs index c17ef2f..2d74ade 100644 --- a/crates/rmcp-macros/src/tool.rs +++ b/crates/rmcp-macros/src/tool.rs @@ -542,22 +542,22 @@ pub(crate) fn tool_fn_item(attr: TokenStream, mut input_fn: ItemFn) -> syn::Resu // generate wrapped tool function let tool_call_fn = { // wrapper function have the same sig: - // async fn #tool_tool_call(context: rmcp::handler::server::tool::ToolCallContext<'_, Self>) + // async fn #tool_tool_call(context: rmcp::handler::server::tool::ToolCallContext) // -> std::result::Result // // and the block part should be like: // { // use rmcp::handler::server::tool::*; - // let (t0, context) = ::from_tool_call_context_part(context)?; - // let (t1, context) = ::from_tool_call_context_part(context)?; + // let t0 = ::from_tool_call_context_part(&mut context)?; + // let t1 = ::from_tool_call_context_part(&mut context)?; // ... - // let (tn, context) = ::from_tool_call_context_part(context)?; + // let tn = ::from_tool_call_context_part(&mut context)?; // // for params // ... expand helper types here - // let (__rmcp_tool_req, context) = rmcp::model::JsonObject::from_tool_call_context_part(context)?; + // let __rmcp_tool_req = rmcp::model::JsonObject::from_tool_call_context_part(&mut context)?; // let __#TOOL_ToolCallParam { param_0, param_1, param_2, .. } = parse_json_object(__rmcp_tool_req)?; // // for aggr - // let (Parameters(aggr), context) = >::from_tool_call_context_part(context)?; + // let Parameters(aggr) = >::from_tool_call_context_part(&mut context)?; // Self::#tool_ident(to, param_0, t1, param_1, ..., param_2, tn, aggr).await.into_call_tool_result() // // } @@ -584,14 +584,14 @@ pub(crate) fn tool_fn_item(attr: TokenStream, mut input_fn: ItemFn) -> syn::Resu let pat = &pat_type.pat; let ty = &pat_type.ty; quote! { - let (#pat, context) = <#ty>::from_tool_call_context_part(context)?; + let #pat = <#ty>::from_tool_call_context_part(&mut context)?; } } FnArg::Receiver(r) => { let ty = r.ty.clone(); let pat = receiver_ident(); quote! { - let (#pat, context) = <#ty>::from_tool_call_context_part(context)?; + let #pat = <#ty>::from_tool_call_context_part(&mut context)?; } } }; @@ -605,7 +605,7 @@ pub(crate) fn tool_fn_item(attr: TokenStream, mut input_fn: ItemFn) -> syn::Resu ToolParams::Aggregated { rust_type } => { let PatType { pat, ty, .. } = rust_type; quote! { - let (Parameters(#pat), context) = >::from_tool_call_context_part(context)?; + let Parameters(#pat) = >::from_tool_call_context_part(&mut context)?; } } ToolParams::Params { attrs } => { @@ -615,7 +615,7 @@ pub(crate) fn tool_fn_item(attr: TokenStream, mut input_fn: ItemFn) -> syn::Resu let params_ident = attrs.iter().map(|attr| &attr.ident).collect::>(); quote! { #param_type - let (__rmcp_tool_req, context) = rmcp::model::JsonObject::from_tool_call_context_part(context)?; + let __rmcp_tool_req = rmcp::model::JsonObject::from_tool_call_context_part(&mut context)?; let #temp_param_type_name { #(#params_ident,)* } = parse_json_object(__rmcp_tool_req)?; @@ -669,7 +669,7 @@ pub(crate) fn tool_fn_item(attr: TokenStream, mut input_fn: ItemFn) -> syn::Resu .collect::>(); quote! { #(#raw_fn_attr)* - #raw_fn_vis async fn #tool_call_fn_ident(context: rmcp::handler::server::tool::ToolCallContext<'_, Self>) + #raw_fn_vis async fn #tool_call_fn_ident(context: rmcp::handler::server::tool::ToolCallContext) -> std::result::Result { use rmcp::handler::server::tool::*; #trivial_arg_extraction_part diff --git a/crates/rmcp/src/handler/server.rs b/crates/rmcp/src/handler/server.rs index 83e9e57..9d55a38 100644 --- a/crates/rmcp/src/handler/server.rs +++ b/crates/rmcp/src/handler/server.rs @@ -7,6 +7,7 @@ use crate::{ mod resource; pub mod tool; pub mod wrapper; +pub mod router; impl Service for H { async fn handle_request( &self, diff --git a/crates/rmcp/src/handler/server/router.rs b/crates/rmcp/src/handler/server/router.rs new file mode 100644 index 0000000..9026b1f --- /dev/null +++ b/crates/rmcp/src/handler/server/router.rs @@ -0,0 +1,94 @@ +use std::sync::Arc; + +use tool::{IntoToolRoute, ToolRoute}; + +use crate::{ + RoleServer, Service, + model::{ClientRequest, ListToolsResult, ServerResult}, +}; + +use super::ServerHandler; + +pub mod tool; + +pub struct Router { + pub tool_router: tool::ToolRouter, + pub service: Arc, +} + +impl Router +where + S: ServerHandler, +{ + pub fn new(service: S) -> Self { + Self { + tool_router: tool::ToolRouter::new(), + service: Arc::new(service), + } + } + + pub fn with_tool(mut self, route: R) -> Self + where + R: IntoToolRoute, + { + self.tool_router.add(route.into_tool_route()); + self + } + + pub fn with_tools(mut self, routes: impl IntoIterator>) -> Self + { + for route in routes { + self.tool_router.add(route); + } + self + } +} + +impl Service for Router +where + S: ServerHandler, +{ + async fn handle_notification( + &self, + notification: ::PeerNot, + ) -> Result<(), crate::Error> { + self.service.handle_notification(notification).await + } + async fn handle_request( + &self, + request: ::PeerReq, + context: crate::service::RequestContext, + ) -> Result<::Resp, crate::Error> { + match request { + ClientRequest::CallToolRequest(request) => { + if self.tool_router.has(request.params.name.as_ref()) + || !self.tool_router.transparent_when_not_found + { + let tool_call_context = crate::handler::server::tool::ToolCallContext::new( + self.service.clone(), + request.params, + context, + ); + let result = self.tool_router.call(tool_call_context).await?; + Ok(ServerResult::CallToolResult(result)) + } else { + self.service + .handle_request(ClientRequest::CallToolRequest(request), context) + .await + } + } + ClientRequest::ListToolsRequest(_) => { + let tools = self.tool_router.list_all(); + Ok(ServerResult::ListToolsResult(ListToolsResult { + tools, + next_cursor: None, + })) + } + rest => self.service.handle_request(rest, context).await, + } + } + + fn get_info(&self) -> ::Info { + self.service.get_info() + } +} diff --git a/crates/rmcp/src/handler/server/router/tool.rs b/crates/rmcp/src/handler/server/router/tool.rs new file mode 100644 index 0000000..8d7a8e8 --- /dev/null +++ b/crates/rmcp/src/handler/server/router/tool.rs @@ -0,0 +1,212 @@ +use std::borrow::Cow; + +use futures::future::BoxFuture; +use schemars::JsonSchema; + +use crate::model::{CallToolResult, Tool, ToolAnnotations}; + +use crate::handler::server::tool::{ + CallToolHandler, DynCallToolHandler, ToolCallContext, schema_for_type, +}; + +pub struct ToolRoute { + #[allow(clippy::type_complexity)] + pub call: Box>, + pub attr: crate::model::Tool, +} + +impl ToolRoute { + pub fn new(attr: impl Into, call: C) -> Self + where + C: CallToolHandler + Send + Sync + Clone + 'static, + >::Fut: 'static, + { + Self { + call: Box::new(move |context: ToolCallContext| { + let call = call.clone(); + Box::pin(async move { context.invoke(call).await }) + }), + attr: attr.into(), + } + } + pub fn new_dyn(attr: impl Into, call: C) -> Self + where + C: Fn(ToolCallContext) -> BoxFuture<'static, Result> + + Send + + Sync + + 'static, + { + Self { + call: Box::new(call), + attr: attr.into(), + } + } + pub fn name(&self) -> &str { + &self.attr.name + } +} + +pub trait IntoToolRoute { + fn into_tool_route(self) -> ToolRoute; +} + +impl IntoToolRoute for (T, C) +where + S: Send + Sync + 'static, + C: CallToolHandler + Send + Sync + Clone + 'static, + T: Into, + >::Fut: 'static, +{ + fn into_tool_route(self) -> ToolRoute { + ToolRoute::new(self.0.into(), self.1) + } +} + +impl IntoToolRoute for ToolRoute +where + S: Send + Sync + 'static, +{ + fn into_tool_route(self) -> ToolRoute { + self + } +} + +pub struct ToolAttrGenerateFunctionAdapter; +impl IntoToolRoute for F +where + S: Send + Sync + 'static, + F: Fn() -> ToolRoute, +{ + fn into_tool_route(self) -> ToolRoute { + (self)() + } +} + +pub trait CallToolHandlerExt: Sized +where + Self: CallToolHandler + Send + Sync + Clone + 'static, + >::Fut: 'static, +{ + fn name(self, name: impl Into>) -> WithToolAttr; +} + +impl CallToolHandlerExt for C +where + C: CallToolHandler + Send + Sync + Clone + 'static, + >::Fut: 'static, +{ + fn name(self, name: impl Into>) -> WithToolAttr { + WithToolAttr { + attr: Tool::new( + name.into(), + "", + schema_for_type::(), + ), + call: self, + _marker: std::marker::PhantomData, + } + } +} + +pub struct WithToolAttr +where + C: CallToolHandler + Send + Sync + Clone + 'static, + >::Fut: 'static, +{ + pub attr: crate::model::Tool, + pub call: C, + pub _marker: std::marker::PhantomData, +} + +impl IntoToolRoute for WithToolAttr +where + C: CallToolHandler + Send + Sync + Clone + 'static, + >::Fut: 'static, + S: Send + Sync + 'static, +{ + fn into_tool_route(self) -> ToolRoute { + ToolRoute::new(self.attr, self.call) + } +} + +impl WithToolAttr +where + C: CallToolHandler + Send + Sync + Clone + 'static, + >::Fut: 'static, +{ + pub fn description(mut self, description: impl Into>) -> Self { + self.attr.description = Some(description.into()); + self + } + pub fn parameters(mut self) -> Self { + self.attr.input_schema = schema_for_type::().into(); + self + } + pub fn parameters_value(mut self, schema: serde_json::Value) -> Self { + self.attr.input_schema = crate::model::object(schema).into(); + self + } + pub fn annotation(mut self, annotation: impl Into) -> Self { + self.attr.annotations = Some(annotation.into()); + self + } +} + +#[derive(Default)] +pub struct ToolRouter { + #[allow(clippy::type_complexity)] + pub map: std::collections::HashMap, ToolRoute>, + + pub transparent_when_not_found: bool, +} + +impl IntoIterator for ToolRouter { + type Item = ToolRoute; + type IntoIter = std::collections::hash_map::IntoValues, ToolRoute>; + + fn into_iter(self) -> Self::IntoIter { + self.map.into_values() + } +} + +impl ToolRouter +where + S: Send + Sync + 'static, +{ + pub fn new() -> Self { + Self { + map: std::collections::HashMap::new(), + transparent_when_not_found: false, + } + } + pub fn with(mut self, attr: crate::model::Tool, call: C) -> Self + where + C: CallToolHandler + Send + Sync + Clone + 'static, + >::Fut: 'static, + { + self.add(ToolRoute::new(attr, call)); + self + } + + pub fn add(&mut self, item: ToolRoute) { + self.map.insert(item.attr.name.clone(), item); + } + + pub fn remove(&mut self, name: &str) { + self.map.remove(name); + } + pub fn has(&self, name: &str) -> bool { + self.map.contains_key(name) + } + pub async fn call(&self, context: ToolCallContext) -> Result { + let item = self + .map + .get(context.name()) + .ok_or_else(|| crate::Error::invalid_params("tool not found", None))?; + (item.call)(context).await + } + + pub fn list_all(&self) -> Vec { + self.map.values().map(|item| item.attr.clone()).collect() + } +} diff --git a/crates/rmcp/src/handler/server/tool.rs b/crates/rmcp/src/handler/server/tool.rs index aa89c2c..e385468 100644 --- a/crates/rmcp/src/handler/server/tool.rs +++ b/crates/rmcp/src/handler/server/tool.rs @@ -12,6 +12,7 @@ use crate::{ model::{CallToolRequestParam, CallToolResult, ConstString, IntoContents, JsonObject}, service::RequestContext, }; + /// A shortcut for generating a JSON schema for a type. pub fn schema_for_type() -> JsonObject { let mut settings = schemars::r#gen::SchemaSettings::default(); @@ -64,10 +65,10 @@ pub fn parse_json_object(input: JsonObject) -> Result { - request_context: RequestContext, - service: Arc, - name: Cow<'static, str>, - arguments: Option, + pub request_context: RequestContext, + pub service: Arc, + pub name: Cow<'static, str>, + pub arguments: Option, } impl ToolCallContext { @@ -168,6 +169,14 @@ pub trait CallToolHandler { pub type DynCallToolHandler = dyn Fn(ToolCallContext) -> BoxFuture<'static, Result> + Send + Sync; + +impl FromToolCallContextPart for Arc { + fn from_tool_call_context_part(context: &mut ToolCallContext) -> Result { + let service = context.service.clone(); + Ok(service) + } +} + /// Parameter Extractor pub struct Parameter(pub K, pub V); @@ -188,9 +197,7 @@ impl JsonSchema for Parameters

{ } impl FromToolCallContextPart for CancellationToken { - fn from_tool_call_context_part( - context: &mut ToolCallContext, - ) -> Result { + fn from_tool_call_context_part(context: &mut ToolCallContext) -> Result { Ok(context.request_context.ct.clone()) } } @@ -198,9 +205,7 @@ impl FromToolCallContextPart for CancellationToken { pub struct ToolName(pub Cow<'static, str>); impl FromToolCallContextPart for ToolName { - fn from_tool_call_context_part( - context: &mut ToolCallContext, - ) -> Result { + fn from_tool_call_context_part(context: &mut ToolCallContext) -> Result { Ok(Self(context.name.clone())) } } @@ -210,9 +215,7 @@ where K: ConstString, V: DeserializeOwned, { - fn from_tool_call_context_part( - context: &mut ToolCallContext, - ) -> Result { + fn from_tool_call_context_part(context: &mut ToolCallContext) -> Result { let arguments = context .arguments .as_ref() @@ -312,8 +315,8 @@ impl< S> ToolCallContext { pub struct AsyncAdapter(PhantomData<(fn(P) -> Fut, fn(Fut) -> R)>); pub struct SyncAdapter(PhantomData R>); -#[allow(clippy::type_complexity)] -pub struct AsyncMethodAdapter(PhantomData<(fn(P) -> Fut, fn(Fut) -> R)>); +// #[allow(clippy::type_complexity)] +// pub struct AsyncMethodAdapter(PhantomData<(fn(P) -> Fut, fn(Fut) -> R)>); pub struct SyncMethodAdapter(PhantomData R>); macro_rules! impl_for { @@ -329,6 +332,33 @@ macro_rules! impl_for { impl_for!([$($Tn)* $Tn_1] [$($Rest)*]); }; (@impl $($Tn: ident)*) => { + // impl<'s, $($Tn,)* S, F, Fut, R> CallToolHandler> for F + // where + // $( + // $Tn: FromToolCallContextPart , + // )* + // F: FnOnce(&'s S, $($Tn,)*) -> Fut + Send + 'static, + // Fut: Future + Send, + // R: IntoCallToolResult + Send, + // S: Send + Sync + 'static, + // { + // type Fut = BoxFuture<'static, Result>; + // #[allow(unused_variables, non_snake_case, unused_mut)] + // fn call( + // self, + // mut context: ToolCallContext, + // ) -> Self::Fut { + // Box::pin(async move { + // $( + // let $Tn = $Tn::from_tool_call_context_part(&mut context)?; + // )* + // let service = context.service.as_ref(); + // self(service, $($Tn,)*).await.into_call_tool_result() + + // }) + // } + // } + impl<$($Tn,)* S, F, Fut, R> CallToolHandler> for F where $( @@ -340,7 +370,7 @@ macro_rules! impl_for { S: Send + Sync, { type Fut = IntoCallToolResultFut; - #[allow(unused_variables, non_snake_case)] + #[allow(unused_variables, non_snake_case, unused_mut)] fn call( self, mut context: ToolCallContext, @@ -361,6 +391,32 @@ macro_rules! impl_for { } } + impl<$($Tn,)* S, F, R> CallToolHandler> for F + where + $( + $Tn: FromToolCallContextPart + , + )* + F: FnOnce(&S, $($Tn,)*) -> R + Send + , + R: IntoCallToolResult + Send + , + S: Send + Sync, + { + type Fut = Ready>; + #[allow(unused_variables, non_snake_case, unused_mut)] + fn call( + self, + mut context: ToolCallContext, + ) -> Self::Fut { + $( + let result = $Tn::from_tool_call_context_part(&mut context); + let $Tn = match result { + Ok(value) => value, + Err(e) => return std::future::ready(Err(e)), + }; + )* + std::future::ready(self(context.service.as_ref(), $($Tn,)*).into_call_tool_result()) + } + } + impl<$($Tn,)* S, F, R> CallToolHandler> for F where $( @@ -371,7 +427,7 @@ macro_rules! impl_for { S: Send + Sync, { type Fut = Ready>; - #[allow(unused_variables, non_snake_case)] + #[allow(unused_variables, non_snake_case, unused_mut)] fn call( self, mut context: ToolCallContext, @@ -413,90 +469,87 @@ impl ToolBoxItem { } } -#[derive(Default)] -pub struct ToolBox { - #[allow(clippy::type_complexity)] - pub map: std::collections::HashMap, ToolBoxItem>, -} - -impl ToolBox { - pub fn new() -> Self { - Self { - map: std::collections::HashMap::new(), - } - } - pub fn add(&mut self, item: ToolBoxItem) { - self.map.insert(item.attr.name.clone(), item); - } - - pub fn remove(&mut self, name: &str) { - self.map.remove(name); - } - - pub async fn call( - &self, - context: ToolCallContext, - ) -> Result { - let item = self - .map - .get(context.name()) - .ok_or_else(|| crate::Error::invalid_params("tool not found", None))?; - (item.call)(context).await - } - - pub fn list(&self) -> Vec { - self.map.values().map(|item| item.attr.clone()).collect() - } -} - -#[cfg(feature = "macros")] -#[cfg_attr(docsrs, doc(cfg(feature = "macros")))] -#[macro_export] -macro_rules! tool_box { - (@pin_add $callee: ident, $attr: expr, $f: expr) => { - $callee.add(ToolBoxItem::new($attr, |context| Box::pin($f(context)))); - }; - ($server: ident { $($tool: ident),* $(,)?} ) => { - $crate::tool_box!($server { $($tool),* } tool_box); - }; - ($server: ident { $($tool: ident),* $(,)?} $tool_box: ident) => { - fn $tool_box() -> &'static $crate::handler::server::tool::ToolBox<$server> { - use $crate::handler::server::tool::{ToolBox, ToolBoxItem}; - static TOOL_BOX: std::sync::OnceLock> = std::sync::OnceLock::new(); - TOOL_BOX.get_or_init(|| { - let mut tool_box = ToolBox::new(); - $crate::paste!{ - $( - $crate::tool_box!(@pin_add tool_box, $server::[< $tool _tool_attr>](), $server::[<$tool _tool_call>]); - )* - } - tool_box - }) - } - }; - (@derive) => { - $crate::tool_box!(@derive tool_box); - }; - - (@derive $tool_box:ident) => { - async fn list_tools( - &self, - _: Option<$crate::model::PaginatedRequestParam>, - _: $crate::service::RequestContext<$crate::service::RoleServer>, - ) -> Result<$crate::model::ListToolsResult, $crate::Error> { - Ok($crate::model::ListToolsResult { - next_cursor: None, - tools: Self::tool_box().list(), - }) - } - - async fn call_tool( - &self, - call_tool_request_param: $crate::model::CallToolRequestParam, - context: $crate::service::RequestContext<$crate::service::RoleServer>, - ) -> Result<$crate::model::CallToolResult, $crate::Error> { - let context = $crate::handler::server::tool::ToolCallContext::new(self, call_tool_request_param, context); - Self::$tool_box().call(context).await - } - } -} +// #[derive(Default)] +// pub struct ToolBox { +// #[allow(clippy::type_complexity)] +// pub map: std::collections::HashMap, ToolBoxItem>, +// } + +// impl ToolBox { +// pub fn new() -> Self { +// Self { +// map: std::collections::HashMap::new(), +// } +// } +// pub fn add(&mut self, item: ToolBoxItem) { +// self.map.insert(item.attr.name.clone(), item); +// } + +// pub fn remove(&mut self, name: &str) { +// self.map.remove(name); +// } + +// pub async fn call(&self, context: ToolCallContext) -> Result { +// let item = self +// .map +// .get(context.name()) +// .ok_or_else(|| crate::Error::invalid_params("tool not found", None))?; +// (item.call)(context).await +// } + +// pub fn list(&self) -> Vec { +// self.map.values().map(|item| item.attr.clone()).collect() +// } +// } + +// #[cfg(feature = "macros")] +// #[cfg_attr(docsrs, doc(cfg(feature = "macros")))] +// #[macro_export] +// macro_rules! tool_box { +// (@pin_add $callee: ident, $attr: expr, $f: expr) => { +// $callee.add(ToolBoxItem::new($attr, |context| Box::pin($f(context)))); +// }; +// ($server: ident { $($tool: ident),* $(,)?} ) => { +// $crate::tool_box!($server { $($tool),* } tool_box); +// }; +// ($server: ident { $($tool: ident),* $(,)?} $tool_box: ident) => { +// fn $tool_box() -> &'static $crate::handler::server::tool::ToolBox<$server> { +// use $crate::handler::server::tool::{ToolBox, ToolBoxItem}; +// static TOOL_BOX: std::sync::OnceLock> = std::sync::OnceLock::new(); +// TOOL_BOX.get_or_init(|| { +// let mut tool_box = ToolBox::new(); +// $crate::paste!{ +// $( +// $crate::tool_box!(@pin_add tool_box, $server::[< $tool _tool_attr>](), $server::[<$tool _tool_call>]); +// )* +// } +// tool_box +// }) +// } +// }; +// (@derive) => { +// $crate::tool_box!(@derive tool_box); +// }; + +// (@derive $tool_box:ident) => { +// async fn list_tools( +// &self, +// _: Option<$crate::model::PaginatedRequestParam>, +// _: $crate::service::RequestContext<$crate::service::RoleServer>, +// ) -> Result<$crate::model::ListToolsResult, $crate::Error> { +// Ok($crate::model::ListToolsResult { +// next_cursor: None, +// tools: Self::tool_box().list(), +// }) +// } + +// async fn call_tool( +// &self, +// call_tool_request_param: $crate::model::CallToolRequestParam, +// context: $crate::service::RequestContext<$crate::service::RoleServer>, +// ) -> Result<$crate::model::CallToolResult, $crate::Error> { +// let context = $crate::handler::server::tool::ToolCallContext::new(self, call_tool_request_param, context); +// Self::$tool_box().call(context).await +// } +// } +// } diff --git a/crates/rmcp/tests/test_tool_routers.rs b/crates/rmcp/tests/test_tool_routers.rs new file mode 100644 index 0000000..8abcc2d --- /dev/null +++ b/crates/rmcp/tests/test_tool_routers.rs @@ -0,0 +1,98 @@ +use std::{collections::HashMap, sync::Arc}; + +use rmcp::{ + RoleServer, ServerHandler, Service, + handler::server::{ + router::{ + Router, + tool::{CallToolHandlerExt, ToolRoute, ToolRouter}, + }, + tool::{Parameters, schema_for_type}, + }, + model::{Extensions, Tool}, +}; + +#[derive(Debug, Default)] +pub struct TestHandler { + pub _marker: std::marker::PhantomData, +} + +impl ServerHandler for TestHandler {} +#[derive(Debug, schemars::JsonSchema, serde::Deserialize, serde::Serialize)] +pub struct Request { + pub fields: HashMap, +} + +#[derive(Debug, schemars::JsonSchema, serde::Deserialize, serde::Serialize)] +pub struct Sum { + pub a: i32, + pub b: i32, +} + +impl TestHandler { + async fn async_method(self: Arc, Parameters(Request { fields }): Parameters) { + drop(fields) + } + fn sync_method(&self, Parameters(Request { fields }): Parameters) { + drop(fields) + } +} + +fn sync_function(Parameters(Request { fields }): Parameters) { + drop(fields) +} + +// #[rmcp(tool(description = "async method", parameters = Request, name = "async_method"))] +// ^ +// |_____ this is a macro will generates a function with the same name but return ToolRoute +fn async_function( + _callee: Arc>, + Parameters(Request { fields }): Parameters, +) { + drop(fields) +} + +fn attr_generator_fn() -> ToolRoute> { + ToolRoute::new( + Tool::new( + "sync_method_from_generator_fn", + "a sync method tool", + schema_for_type::(), + ), + TestHandler::sync_method, + ) +} + +fn assert_service>(service: S) { + drop(service); +} + +#[test] +fn test_tool_router() { + let test_handler = TestHandler::<()>::default(); + fn tool(name: &'static str) -> Tool { + Tool::new(name, name, schema_for_type::()) + } + let tool_router = ToolRouter::>::new() + .with(tool("sync_method"), TestHandler::sync_method) + .with(tool("async_method"), TestHandler::async_method) + .with(tool("sync_function"), sync_function) + .with(tool("async_function"), async_function); + + let router = Router::new(test_handler) + .with_tool( + TestHandler::sync_method + .name("sync_method") + .description("a sync method tool") + .parameters::(), + ) + .with_tool( + (|Parameters(Sum { a, b }): Parameters| (a + b).to_string()) + .name("add") + .parameters::(), + ) + .with_tool(attr_generator_fn) + .with_tools(tool_router); + + assert_service(router); +} From 8b638ea1013a0b920e3cd0dd5e5524acdd99f3ff Mon Sep 17 00:00:00 2001 From: 4t145 Date: Mon, 26 May 2025 10:16:43 +0800 Subject: [PATCH 3/3] sync --- crates/rmcp/src/handler/server/router/promt.rs | 0 crates/rmcp/src/handler/server/tool.rs | 2 +- crates/rmcp/tests/test_tool_routers.rs | 5 ++--- 3 files changed, 3 insertions(+), 4 deletions(-) create mode 100644 crates/rmcp/src/handler/server/router/promt.rs diff --git a/crates/rmcp/src/handler/server/router/promt.rs b/crates/rmcp/src/handler/server/router/promt.rs new file mode 100644 index 0000000..e69de29 diff --git a/crates/rmcp/src/handler/server/tool.rs b/crates/rmcp/src/handler/server/tool.rs index e385468..95b4e62 100644 --- a/crates/rmcp/src/handler/server/tool.rs +++ b/crates/rmcp/src/handler/server/tool.rs @@ -312,7 +312,7 @@ impl< S> ToolCallContext { } #[allow(clippy::type_complexity)] -pub struct AsyncAdapter(PhantomData<(fn(P) -> Fut, fn(Fut) -> R)>); +pub struct AsyncAdapter(PhantomData fn(Fut) -> R>); pub struct SyncAdapter(PhantomData R>); // #[allow(clippy::type_complexity)] diff --git a/crates/rmcp/tests/test_tool_routers.rs b/crates/rmcp/tests/test_tool_routers.rs index 8abcc2d..d972a47 100644 --- a/crates/rmcp/tests/test_tool_routers.rs +++ b/crates/rmcp/tests/test_tool_routers.rs @@ -52,14 +52,14 @@ fn async_function( drop(fields) } -fn attr_generator_fn() -> ToolRoute> { +fn attr_generator_fn() -> ToolRoute { ToolRoute::new( Tool::new( "sync_method_from_generator_fn", "a sync method tool", schema_for_type::(), ), - TestHandler::sync_method, + sync_function, ) } @@ -93,6 +93,5 @@ fn test_tool_router() { ) .with_tool(attr_generator_fn) .with_tools(tool_router); - assert_service(router); }