Skip to content

Enhance Tool Macro Functionality #161

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 147 additions & 37 deletions crates/rmcp-macros/src/tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ use proc_macro2::TokenStream;
use quote::{ToTokens, quote};
use serde_json::json;
use syn::{
Expr, FnArg, Ident, ItemFn, ItemImpl, Lit, MetaList, PatType, Token, Type, Visibility,
parse::Parse, parse_quote, spanned::Spanned,
Expr, FnArg, Ident, ItemFn, ItemImpl, MetaList, PatType, Token, Type, Visibility,Lit,
parse::{Parse, discouraged::Speculative},
parse_quote,
spanned::Spanned,
};

/// Stores tool annotation attributes
Expand Down Expand Up @@ -42,13 +44,17 @@ impl Parse for ToolAnnotationAttrs {
}

#[derive(Default)]
struct ToolImplItemAttrs {
pub(crate) struct ToolImplItemAttrs {
tool_box: Option<Option<Ident>>,
default_build: bool,
description: Option<Expr>,
}

impl Parse for ToolImplItemAttrs {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let mut tool_box = None;
let mut default = true;
let mut description = None;
while !input.is_empty() {
let key: Ident = input.parse()?;
match key.to_string().as_str() {
Expand All @@ -60,6 +66,32 @@ impl Parse for ToolImplItemAttrs {
tool_box = Some(Some(value));
}
}
"default_build" => {
if input.lookahead1().peek(Token![=]) {
input.parse::<Token![=]>()?;
let value: Expr = input.parse()?;
match value.to_token_stream().to_string().as_str() {
"true" => {
default = true;
}
"false" => {
default = false;
}
_ => {
return Err(syn::Error::new(key.span(), "unknown attribute"));
}
}
} else {
default = true;
}
}
"description" => {
if input.lookahead1().peek(Token![=]) {
input.parse::<Token![=]>()?;
let value: Expr = input.parse()?;
description = Some(value);
}
}
_ => {
return Err(syn::Error::new(key.span(), "unknown attribute"));
}
Expand All @@ -70,7 +102,11 @@ impl Parse for ToolImplItemAttrs {
input.parse::<Token![,]>()?;
}

Ok(ToolImplItemAttrs { tool_box })
Ok(ToolImplItemAttrs {
tool_box,
default_build: default,
description,
})
}
}

Expand All @@ -79,6 +115,7 @@ struct ToolFnItemAttrs {
name: Option<Expr>,
description: Option<Expr>,
vis: Option<Visibility>,
aggr: bool,
annotations: Option<ToolAnnotationAttrs>,
}

Expand All @@ -87,12 +124,18 @@ impl Parse for ToolFnItemAttrs {
let mut name = None;
let mut description = None;
let mut vis = None;
let mut aggr = false;
let mut annotations = None;

while !input.is_empty() {
let key: Ident = input.parse()?;
let key_str = key.to_string();
if key_str == AGGREGATED_IDENT {
aggr = true;
continue;
}
input.parse::<Token![=]>()?;
match key.to_string().as_str() {
match key_str.as_str() {
"name" => {
let value: Expr = input.parse()?;
name = Some(value);
Expand Down Expand Up @@ -126,6 +169,7 @@ impl Parse for ToolFnItemAttrs {
name,
description,
vis,
aggr,
annotations,
})
}
Expand Down Expand Up @@ -200,14 +244,20 @@ pub enum ToolItem {

impl Parse for ToolItem {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let lookahead = input.lookahead1();
if lookahead.peek(Token![impl]) {
let item = input.parse::<ItemImpl>()?;
Ok(ToolItem::Impl(item))
} else {
let item = input.parse::<ItemFn>()?;
Ok(ToolItem::Fn(item))
let fork = input.fork();
if let Ok(item) = fork.parse::<ItemImpl>() {
input.advance_to(&fork);
return Ok(ToolItem::Impl(item));
}
let fork = input.fork();
if let Ok(item) = fork.parse::<ItemFn>() {
input.advance_to(&fork);
return Ok(ToolItem::Fn(item));
}
Err(syn::Error::new(
input.span(),
"expected function or impl block",
))
}
}

Expand All @@ -223,7 +273,22 @@ pub(crate) fn tool(attr: TokenStream, input: TokenStream) -> syn::Result<TokenSt
pub(crate) fn tool_impl_item(attr: TokenStream, mut input: ItemImpl) -> syn::Result<TokenStream> {
let tool_impl_attr: ToolImplItemAttrs = syn::parse2(attr)?;
let tool_box_ident = tool_impl_attr.tool_box;

let mut extend_quote = None;
let description = if let Some(expr) = tool_impl_attr.description {
// Use explicitly provided description if available
expr
} else {
// Try to extract documentation comments
let doc_content = input
.attrs
.iter()
.filter_map(extract_doc_line)
.collect::<Vec<_>>()
.join("\n");
parse_quote! {
#doc_content.trim().to_string()
}
};
// get all tool function ident
let mut tool_fn_idents = Vec::new();
for item in &input.items {
Expand Down Expand Up @@ -325,6 +390,37 @@ pub(crate) fn tool_impl_item(attr: TokenStream, mut input: ItemImpl) -> syn::Res
})
}
});

if tool_impl_attr.default_build {
let struct_name = input.self_ty.clone();
let generic = &input.generics;
let extend = quote! {
impl #generic rmcp::handler::server::ServerHandler for #struct_name {
async fn call_tool(
&self,
request: rmcp::model::CallToolRequestParam,
context: rmcp::service::RequestContext<rmcp::RoleServer>,
) -> Result<rmcp::model::CallToolResult, rmcp::Error> {
self.call_tool_inner(request, context).await
}
async fn list_tools(
&self,
request: Option<rmcp::model::PaginatedRequestParam>,
context: rmcp::service::RequestContext<rmcp::RoleServer>,
) -> Result<rmcp::model::ListToolsResult, rmcp::Error> {
self.list_tools_inner(request.unwrap_or_default(), context).await
}
fn get_info(&self) -> rmcp::model::ServerInfo {
rmcp::model::ServerInfo {
instructions: Some(#description.into()),
capabilities: rmcp::model::ServerCapabilities::builder().enable_tools().build(),
..Default::default()
}
}
}
};
extend_quote.replace(extend);
}
} else {
// if there are no generic parameters, use the original tool_box! macro
let this_type_ident = &input.self_ty;
Expand All @@ -333,11 +429,30 @@ pub(crate) fn tool_impl_item(attr: TokenStream, mut input: ItemImpl) -> syn::Res
#(#tool_fn_idents),*
} #ident);
));
if tool_impl_attr.default_build {
let struct_name = input.self_ty.clone();
let generic = &input.generics;
let extend = quote! {
impl #generic rmcp::handler::server::ServerHandler for #struct_name {
rmcp::tool_box!(@derive #ident);

fn get_info(&self) -> rmcp::model::ServerInfo {
rmcp::model::ServerInfo {
instructions: Some(#description.into()),
capabilities: rmcp::model::ServerCapabilities::builder().enable_tools().build(),
..Default::default()
}
}
}
};
extend_quote.replace(extend);
}
}
}

Ok(quote! {
#input
#extend_quote
})
}

Expand Down Expand Up @@ -391,29 +506,7 @@ pub(crate) fn tool_fn_item(attr: TokenStream, mut input_fn: ItemFn) -> syn::Resu
for attr in raw_attrs {
match &attr.meta {
syn::Meta::List(meta_list) => {
if meta_list.path.is_ident(TOOL_IDENT) {
let pat_type = pat_type.clone();
let marker = meta_list.parse_args::<ParamMarker>()?;
match marker {
ParamMarker::Param => {
let Some(arg_ident) = arg_ident.take() else {
return Err(syn::Error::new(
proc_macro2::Span::call_site(),
"input param must have an ident as name",
));
};
caught.replace(Caught::Param(ToolFnParamAttrs {
serde_meta: Vec::new(),
schemars_meta: Vec::new(),
ident: arg_ident,
rust_type: pat_type.ty.clone(),
}));
}
ParamMarker::Aggregated => {
caught.replace(Caught::Aggregated(pat_type.clone()));
}
}
} else if meta_list.path.is_ident(SERDE_IDENT) {
if meta_list.path.is_ident(SERDE_IDENT) {
serde_metas.push(meta_list.clone());
} else if meta_list.path.is_ident(SCHEMARS_IDENT) {
schemars_metas.push(meta_list.clone());
Expand All @@ -426,6 +519,23 @@ pub(crate) fn tool_fn_item(attr: TokenStream, mut input_fn: ItemFn) -> syn::Resu
}
}
}
let pat_type = pat_type.clone();
if tool_macro_attrs.fn_item.aggr {
caught.replace(Caught::Aggregated(pat_type.clone()));
} else {
let Some(arg_ident) = arg_ident.take() else {
return Err(syn::Error::new(
proc_macro2::Span::call_site(),
"input param must have an ident as name",
));
};
caught.replace(Caught::Param(ToolFnParamAttrs {
serde_meta: Vec::new(),
schemars_meta: Vec::new(),
ident: arg_ident,
rust_type: pat_type.ty.clone(),
}));
}
match caught {
Some(Caught::Param(mut param)) => {
param.serde_meta = serde_metas;
Expand Down Expand Up @@ -483,7 +593,6 @@ pub(crate) fn tool_fn_item(attr: TokenStream, mut input_fn: ItemFn) -> syn::Resu
.filter_map(extract_doc_line)
.collect::<Vec<_>>()
.join("\n");

parse_quote! {
#doc_content.trim().to_string()
}
Expand Down Expand Up @@ -759,6 +868,7 @@ mod test {

// The output should contain the description from doc comments
let result_str = result.to_string();
println!("result: {:#}", result_str);
assert!(result_str.contains("This is a test description from doc comments"));
assert!(result_str.contains("with multiple lines"));

Expand Down
2 changes: 1 addition & 1 deletion crates/rmcp/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ tokio-stream = { version = "0.1", optional = true }
uuid = { version = "1", features = ["v4"], optional = true }

# macro
rmcp-macros = { version = "0.1", workspace = true, optional = true }
rmcp-macros = { workspace = true, optional = true }

[features]
default = ["base64", "macros", "server"]
Expand Down
31 changes: 6 additions & 25 deletions crates/rmcp/tests/common/calculator.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
use rmcp::{
ServerHandler,
model::{ServerCapabilities, ServerInfo},
schemars, tool,
};
use rmcp::{schemars, tool};
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
pub struct SumRequest {
#[schemars(description = "the left hand side number")]
Expand All @@ -11,34 +7,19 @@ pub struct SumRequest {
}
#[derive(Debug, Clone, Default)]
pub struct Calculator;
#[tool(tool_box)]
#[tool(tool_box, description = "A simple calculator")]
impl Calculator {
#[tool(description = "Calculate the sum of two numbers")]
fn sum(&self, #[tool(aggr)] SumRequest { a, b }: SumRequest) -> String {
#[tool(description = "Calculate the sum of two numbers", aggr)]
fn sum(&self, SumRequest { a, b }: SumRequest) -> String {
(a + b).to_string()
}

#[tool(description = "Calculate the sub of two numbers")]
fn sub(
&self,
#[tool(param)]
#[schemars(description = "the left hand side number")]
a: i32,
#[tool(param)]
#[schemars(description = "the right hand side number")]
b: i32,
#[schemars(description = "the left hand side number")] a: i32,
#[schemars(description = "the right hand side number")] b: i32,
) -> String {
(a - b).to_string()
}
}

#[tool(tool_box)]
impl ServerHandler for Calculator {
fn get_info(&self) -> ServerInfo {
ServerInfo {
instructions: Some("A simple calculator".into()),
capabilities: ServerCapabilities::builder().enable_tools().build(),
..Default::default()
}
}
}
7 changes: 2 additions & 5 deletions crates/rmcp/tests/test_complex_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,8 @@ impl Demo {
Self
}

#[tool(description = "LLM")]
async fn chat(
&self,
#[tool(aggr)] chat_request: ChatRequest,
) -> Result<CallToolResult, McpError> {
#[tool(description = "LLM", aggr)]
async fn chat(&self, chat_request: ChatRequest) -> Result<CallToolResult, McpError> {
let content = Content::json(chat_request)?;
Ok(CallToolResult::success(vec![content]))
}
Expand Down
2 changes: 1 addition & 1 deletion crates/rmcp/tests/test_tool_macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ pub struct Server {}
impl Server {
/// This tool is used to get the weather of a city.
#[tool(name = "get-weather", description = "Get the weather of a city.", vis = )]
pub async fn get_weather(&self, #[tool(param)] city: String) -> String {
pub async fn get_weather(&self, city: String) -> String {
drop(city);
"rain".to_string()
}
Expand Down
Loading