diff --git a/crates/rmcp-macros/src/tool.rs b/crates/rmcp-macros/src/tool.rs index c17ef2f8..40f761a4 100644 --- a/crates/rmcp-macros/src/tool.rs +++ b/crates/rmcp-macros/src/tool.rs @@ -4,8 +4,10 @@ use proc_macro2::TokenStream; use quote::{ToTokens, quote}; use serde_json::json; use syn::{ - Expr, FnArg, Ident, ItemFn, ItemImpl, Lit, MetaList, PatType, Token, Type, Visibility, - parse::Parse, parse_quote, spanned::Spanned, + Expr, FnArg, Ident, ItemFn, ItemImpl, MetaList, PatType, Token, Type, Visibility,Lit, + parse::{Parse, discouraged::Speculative}, + parse_quote, + spanned::Spanned, }; /// Stores tool annotation attributes @@ -42,13 +44,17 @@ impl Parse for ToolAnnotationAttrs { } #[derive(Default)] -struct ToolImplItemAttrs { +pub(crate) struct ToolImplItemAttrs { tool_box: Option>, + default_build: bool, + description: Option, } impl Parse for ToolImplItemAttrs { fn parse(input: syn::parse::ParseStream) -> syn::Result { let mut tool_box = None; + let mut default = true; + let mut description = None; while !input.is_empty() { let key: Ident = input.parse()?; match key.to_string().as_str() { @@ -60,6 +66,32 @@ impl Parse for ToolImplItemAttrs { tool_box = Some(Some(value)); } } + "default_build" => { + if input.lookahead1().peek(Token![=]) { + input.parse::()?; + let value: Expr = input.parse()?; + match value.to_token_stream().to_string().as_str() { + "true" => { + default = true; + } + "false" => { + default = false; + } + _ => { + return Err(syn::Error::new(key.span(), "unknown attribute")); + } + } + } else { + default = true; + } + } + "description" => { + if input.lookahead1().peek(Token![=]) { + input.parse::()?; + let value: Expr = input.parse()?; + description = Some(value); + } + } _ => { return Err(syn::Error::new(key.span(), "unknown attribute")); } @@ -70,7 +102,11 @@ impl Parse for ToolImplItemAttrs { input.parse::()?; } - Ok(ToolImplItemAttrs { tool_box }) + Ok(ToolImplItemAttrs { + tool_box, + default_build: default, + description, + }) } } @@ -79,6 +115,7 @@ struct ToolFnItemAttrs { name: Option, description: Option, vis: Option, + aggr: bool, annotations: Option, } @@ -87,12 +124,18 @@ impl Parse for ToolFnItemAttrs { let mut name = None; let mut description = None; let mut vis = None; + let mut aggr = false; let mut annotations = None; while !input.is_empty() { let key: Ident = input.parse()?; + let key_str = key.to_string(); + if key_str == AGGREGATED_IDENT { + aggr = true; + continue; + } input.parse::()?; - match key.to_string().as_str() { + match key_str.as_str() { "name" => { let value: Expr = input.parse()?; name = Some(value); @@ -126,6 +169,7 @@ impl Parse for ToolFnItemAttrs { name, description, vis, + aggr, annotations, }) } @@ -200,14 +244,20 @@ pub enum ToolItem { impl Parse for ToolItem { fn parse(input: syn::parse::ParseStream) -> syn::Result { - let lookahead = input.lookahead1(); - if lookahead.peek(Token![impl]) { - let item = input.parse::()?; - Ok(ToolItem::Impl(item)) - } else { - let item = input.parse::()?; - Ok(ToolItem::Fn(item)) + let fork = input.fork(); + if let Ok(item) = fork.parse::() { + input.advance_to(&fork); + return Ok(ToolItem::Impl(item)); + } + let fork = input.fork(); + if let Ok(item) = fork.parse::() { + input.advance_to(&fork); + return Ok(ToolItem::Fn(item)); } + Err(syn::Error::new( + input.span(), + "expected function or impl block", + )) } } @@ -223,7 +273,22 @@ pub(crate) fn tool(attr: TokenStream, input: TokenStream) -> syn::Result syn::Result { let tool_impl_attr: ToolImplItemAttrs = syn::parse2(attr)?; let tool_box_ident = tool_impl_attr.tool_box; - + let mut extend_quote = None; + let description = if let Some(expr) = tool_impl_attr.description { + // Use explicitly provided description if available + expr + } else { + // Try to extract documentation comments + let doc_content = input + .attrs + .iter() + .filter_map(extract_doc_line) + .collect::>() + .join("\n"); + parse_quote! { + #doc_content.trim().to_string() + } + }; // get all tool function ident let mut tool_fn_idents = Vec::new(); for item in &input.items { @@ -325,6 +390,37 @@ pub(crate) fn tool_impl_item(attr: TokenStream, mut input: ItemImpl) -> syn::Res }) } }); + + if tool_impl_attr.default_build { + let struct_name = input.self_ty.clone(); + let generic = &input.generics; + let extend = quote! { + impl #generic rmcp::handler::server::ServerHandler for #struct_name { + async fn call_tool( + &self, + request: rmcp::model::CallToolRequestParam, + context: rmcp::service::RequestContext, + ) -> Result { + self.call_tool_inner(request, context).await + } + async fn list_tools( + &self, + request: Option, + context: rmcp::service::RequestContext, + ) -> Result { + self.list_tools_inner(request.unwrap_or_default(), context).await + } + fn get_info(&self) -> rmcp::model::ServerInfo { + rmcp::model::ServerInfo { + instructions: Some(#description.into()), + capabilities: rmcp::model::ServerCapabilities::builder().enable_tools().build(), + ..Default::default() + } + } + } + }; + extend_quote.replace(extend); + } } else { // if there are no generic parameters, use the original tool_box! macro let this_type_ident = &input.self_ty; @@ -333,11 +429,30 @@ pub(crate) fn tool_impl_item(attr: TokenStream, mut input: ItemImpl) -> syn::Res #(#tool_fn_idents),* } #ident); )); + if tool_impl_attr.default_build { + let struct_name = input.self_ty.clone(); + let generic = &input.generics; + let extend = quote! { + impl #generic rmcp::handler::server::ServerHandler for #struct_name { + rmcp::tool_box!(@derive #ident); + + fn get_info(&self) -> rmcp::model::ServerInfo { + rmcp::model::ServerInfo { + instructions: Some(#description.into()), + capabilities: rmcp::model::ServerCapabilities::builder().enable_tools().build(), + ..Default::default() + } + } + } + }; + extend_quote.replace(extend); + } } } Ok(quote! { #input + #extend_quote }) } @@ -391,29 +506,7 @@ pub(crate) fn tool_fn_item(attr: TokenStream, mut input_fn: ItemFn) -> syn::Resu for attr in raw_attrs { match &attr.meta { syn::Meta::List(meta_list) => { - if meta_list.path.is_ident(TOOL_IDENT) { - let pat_type = pat_type.clone(); - let marker = meta_list.parse_args::()?; - match marker { - ParamMarker::Param => { - let Some(arg_ident) = arg_ident.take() else { - return Err(syn::Error::new( - proc_macro2::Span::call_site(), - "input param must have an ident as name", - )); - }; - caught.replace(Caught::Param(ToolFnParamAttrs { - serde_meta: Vec::new(), - schemars_meta: Vec::new(), - ident: arg_ident, - rust_type: pat_type.ty.clone(), - })); - } - ParamMarker::Aggregated => { - caught.replace(Caught::Aggregated(pat_type.clone())); - } - } - } else if meta_list.path.is_ident(SERDE_IDENT) { + if meta_list.path.is_ident(SERDE_IDENT) { serde_metas.push(meta_list.clone()); } else if meta_list.path.is_ident(SCHEMARS_IDENT) { schemars_metas.push(meta_list.clone()); @@ -426,6 +519,23 @@ pub(crate) fn tool_fn_item(attr: TokenStream, mut input_fn: ItemFn) -> syn::Resu } } } + let pat_type = pat_type.clone(); + if tool_macro_attrs.fn_item.aggr { + caught.replace(Caught::Aggregated(pat_type.clone())); + } else { + let Some(arg_ident) = arg_ident.take() else { + return Err(syn::Error::new( + proc_macro2::Span::call_site(), + "input param must have an ident as name", + )); + }; + caught.replace(Caught::Param(ToolFnParamAttrs { + serde_meta: Vec::new(), + schemars_meta: Vec::new(), + ident: arg_ident, + rust_type: pat_type.ty.clone(), + })); + } match caught { Some(Caught::Param(mut param)) => { param.serde_meta = serde_metas; @@ -483,7 +593,6 @@ pub(crate) fn tool_fn_item(attr: TokenStream, mut input_fn: ItemFn) -> syn::Resu .filter_map(extract_doc_line) .collect::>() .join("\n"); - parse_quote! { #doc_content.trim().to_string() } @@ -759,6 +868,7 @@ mod test { // The output should contain the description from doc comments let result_str = result.to_string(); + println!("result: {:#}", result_str); assert!(result_str.contains("This is a test description from doc comments")); assert!(result_str.contains("with multiple lines")); diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index 486719d4..532dcd2f 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -59,7 +59,7 @@ tokio-stream = { version = "0.1", optional = true } uuid = { version = "1", features = ["v4"], optional = true } # macro -rmcp-macros = { version = "0.1", workspace = true, optional = true } +rmcp-macros = { workspace = true, optional = true } [features] default = ["base64", "macros", "server"] diff --git a/crates/rmcp/tests/common/calculator.rs b/crates/rmcp/tests/common/calculator.rs index e179f258..7efd3f8e 100644 --- a/crates/rmcp/tests/common/calculator.rs +++ b/crates/rmcp/tests/common/calculator.rs @@ -1,8 +1,4 @@ -use rmcp::{ - ServerHandler, - model::{ServerCapabilities, ServerInfo}, - schemars, tool, -}; +use rmcp::{schemars, tool}; #[derive(Debug, serde::Deserialize, schemars::JsonSchema)] pub struct SumRequest { #[schemars(description = "the left hand side number")] @@ -11,34 +7,19 @@ pub struct SumRequest { } #[derive(Debug, Clone, Default)] pub struct Calculator; -#[tool(tool_box)] +#[tool(tool_box, description = "A simple calculator")] impl Calculator { - #[tool(description = "Calculate the sum of two numbers")] - fn sum(&self, #[tool(aggr)] SumRequest { a, b }: SumRequest) -> String { + #[tool(description = "Calculate the sum of two numbers", aggr)] + fn sum(&self, SumRequest { a, b }: SumRequest) -> String { (a + b).to_string() } #[tool(description = "Calculate the sub of two numbers")] fn sub( &self, - #[tool(param)] - #[schemars(description = "the left hand side number")] - a: i32, - #[tool(param)] - #[schemars(description = "the right hand side number")] - b: i32, + #[schemars(description = "the left hand side number")] a: i32, + #[schemars(description = "the right hand side number")] b: i32, ) -> String { (a - b).to_string() } } - -#[tool(tool_box)] -impl ServerHandler for Calculator { - fn get_info(&self) -> ServerInfo { - ServerInfo { - instructions: Some("A simple calculator".into()), - capabilities: ServerCapabilities::builder().enable_tools().build(), - ..Default::default() - } - } -} diff --git a/crates/rmcp/tests/test_complex_schema.rs b/crates/rmcp/tests/test_complex_schema.rs index b9370fce..b0a52b96 100644 --- a/crates/rmcp/tests/test_complex_schema.rs +++ b/crates/rmcp/tests/test_complex_schema.rs @@ -30,11 +30,8 @@ impl Demo { Self } - #[tool(description = "LLM")] - async fn chat( - &self, - #[tool(aggr)] chat_request: ChatRequest, - ) -> Result { + #[tool(description = "LLM", aggr)] + async fn chat(&self, chat_request: ChatRequest) -> Result { let content = Content::json(chat_request)?; Ok(CallToolResult::success(vec![content])) } diff --git a/crates/rmcp/tests/test_tool_macros.rs b/crates/rmcp/tests/test_tool_macros.rs index 2e7e214c..9556dc7c 100644 --- a/crates/rmcp/tests/test_tool_macros.rs +++ b/crates/rmcp/tests/test_tool_macros.rs @@ -37,7 +37,7 @@ pub struct Server {} impl Server { /// This tool is used to get the weather of a city. #[tool(name = "get-weather", description = "Get the weather of a city.", vis = )] - pub async fn get_weather(&self, #[tool(param)] city: String) -> String { + pub async fn get_weather(&self, city: String) -> String { drop(city); "rain".to_string() } diff --git a/examples/servers/src/common/calculator.rs b/examples/servers/src/common/calculator.rs index 68beecc0..0e856785 100644 --- a/examples/servers/src/common/calculator.rs +++ b/examples/servers/src/common/calculator.rs @@ -1,9 +1,4 @@ -use rmcp::{ - ServerHandler, - handler::server::wrapper::Json, - model::{ServerCapabilities, ServerInfo}, - schemars, tool, -}; +use rmcp::{handler::server::wrapper::Json, schemars, tool}; #[derive(Debug, serde::Deserialize, schemars::JsonSchema)] pub struct SumRequest { @@ -13,34 +8,19 @@ pub struct SumRequest { } #[derive(Debug, Clone)] pub struct Calculator; -#[tool(tool_box)] +#[tool(tool_box, description = "A simple calculator")] impl Calculator { - #[tool(description = "Calculate the sum of two numbers")] - fn sum(&self, #[tool(aggr)] SumRequest { a, b }: SumRequest) -> String { + #[tool(description = "Calculate the sum of two numbers", aggr)] + fn sum(&self, SumRequest { a, b }: SumRequest) -> String { (a + b).to_string() } #[tool(description = "Calculate the difference of two numbers")] fn sub( &self, - #[tool(param)] - #[schemars(description = "the left hand side number")] - a: i32, - #[tool(param)] - #[schemars(description = "the right hand side number")] - b: i32, + #[schemars(description = "the left hand side number")] a: i32, + #[schemars(description = "the right hand side number")] b: i32, ) -> Json { Json(a - b) } } - -#[tool(tool_box)] -impl ServerHandler for Calculator { - fn get_info(&self) -> ServerInfo { - ServerInfo { - instructions: Some("A simple calculator".into()), - capabilities: ServerCapabilities::builder().enable_tools().build(), - ..Default::default() - } - } -} diff --git a/examples/servers/src/common/counter.rs b/examples/servers/src/common/counter.rs index 12aa8a4a..8f53c810 100644 --- a/examples/servers/src/common/counter.rs +++ b/examples/servers/src/common/counter.rs @@ -18,7 +18,7 @@ pub struct Counter { counter: Arc>, } -#[tool(tool_box)] +#[tool(tool_box, default_build = false)] impl Counter { #[allow(dead_code)] pub fn new() -> Self { @@ -65,18 +65,13 @@ impl Counter { #[tool(description = "Repeat what you say")] fn echo( &self, - #[tool(param)] - #[schemars(description = "Repeat what you say")] - saying: String, + #[schemars(description = "Repeat what you say")] saying: String, ) -> Result { Ok(CallToolResult::success(vec![Content::text(saying)])) } - #[tool(description = "Calculate the sum of two numbers")] - fn sum( - &self, - #[tool(aggr)] StructRequest { a, b }: StructRequest, - ) -> Result { + #[tool(description = "Calculate the sum of two numbers", aggr)] + fn sum(&self, StructRequest { a, b }: StructRequest) -> Result { Ok(CallToolResult::success(vec![Content::text( (a + b).to_string(), )])) diff --git a/examples/servers/src/common/generic_service.rs b/examples/servers/src/common/generic_service.rs index 433a4308..a45d3674 100644 --- a/examples/servers/src/common/generic_service.rs +++ b/examples/servers/src/common/generic_service.rs @@ -1,10 +1,6 @@ use std::sync::Arc; -use rmcp::{ - ServerHandler, - model::{ServerCapabilities, ServerInfo}, - schemars, tool, -}; +use rmcp::{schemars, tool}; #[allow(dead_code)] pub trait DataService: Send + Sync + 'static { @@ -42,8 +38,9 @@ pub struct GenericService { data_service: Arc, } -#[tool(tool_box)] +#[tool(tool_box, description = "generic data service")] impl GenericService { + #[allow(dead_code)] pub fn new(data_service: DS) -> Self { Self { data_service: Arc::new(data_service), @@ -56,18 +53,8 @@ impl GenericService { } #[tool(description = "set memory to service")] - pub async fn set_data(&self, #[tool(param)] data: String) -> String { + pub async fn set_data(&self, data: String) -> String { let new_data = data.clone(); format!("Current memory: {}", new_data) } } - -impl ServerHandler for GenericService { - fn get_info(&self) -> ServerInfo { - ServerInfo { - instructions: Some("generic data service".into()), - capabilities: ServerCapabilities::builder().enable_tools().build(), - ..Default::default() - } - } -} diff --git a/examples/simple-chat-client/Cargo.toml b/examples/simple-chat-client/Cargo.toml index 9c84915c..20dc0192 100644 --- a/examples/simple-chat-client/Cargo.toml +++ b/examples/simple-chat-client/Cargo.toml @@ -18,5 +18,5 @@ rmcp = { workspace = true, features = [ "transport-child-process", "transport-sse-client", "reqwest" -], no-default-features = true } +], default-features = false } clap = { version = "4.0", features = ["derive"] } diff --git a/examples/transport/src/common/calculator.rs b/examples/transport/src/common/calculator.rs index 99b7314a..6f99a08c 100644 --- a/examples/transport/src/common/calculator.rs +++ b/examples/transport/src/common/calculator.rs @@ -9,20 +9,16 @@ pub struct SumRequest { #[derive(Debug, Clone)] pub struct Calculator; impl Calculator { - #[tool(description = "Calculate the sum of two numbers")] - fn sum(&self, #[tool(aggr)] SumRequest { a, b }: SumRequest) -> String { + #[tool(description = "Calculate the sum of two numbers", aggr)] + fn sum(&self, SumRequest { a, b }: SumRequest) -> String { (a + b).to_string() } #[tool(description = "Calculate the sub of two numbers")] fn sub( &self, - #[tool(param)] - #[schemars(description = "the left hand side number")] - a: i32, - #[tool(param)] - #[schemars(description = "the right hand side number")] - b: i32, + #[schemars(description = "the left hand side number")] a: i32, + #[schemars(description = "the right hand side number")] b: i32, ) -> String { (a - b).to_string() } diff --git a/examples/wasi/src/calculator.rs b/examples/wasi/src/calculator.rs index f1c35eea..e28b4fc1 100644 --- a/examples/wasi/src/calculator.rs +++ b/examples/wasi/src/calculator.rs @@ -13,20 +13,16 @@ pub struct SumRequest { #[derive(Debug, Clone)] pub struct Calculator; impl Calculator { - #[tool(description = "Calculate the sum of two numbers")] - fn sum(&self, #[tool(aggr)] SumRequest { a, b }: SumRequest) -> String { + #[tool(description = "Calculate the sum of two numbers", aggr)] + fn sum(&self, SumRequest { a, b }: SumRequest) -> String { (a + b).to_string() } #[tool(description = "Calculate the sub of two numbers")] fn sub( &self, - #[tool(param)] - #[schemars(description = "the left hand side number")] - a: i32, - #[tool(param)] - #[schemars(description = "the right hand side number")] - b: i32, + #[schemars(description = "the left hand side number")] a: i32, + #[schemars(description = "the right hand side number")] b: i32, ) -> String { (a - b).to_string() }