-
Notifications
You must be signed in to change notification settings - Fork 96
/
Copy pathlib.rs
152 lines (128 loc) · 5.19 KB
/
lib.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
use convert_case::{Case, Casing};
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use std::collections::HashMap;
use syn::{
parse::Parse, parse::ParseStream, parse_macro_input, punctuated::Punctuated, Expr, ExprLit,
FnArg, ItemFn, Lit, Meta, Pat, PatType, Token,
};
/// Arguments accepted by the `#[tool(...)]` attribute, after parsing.
///
/// Every entry is optional at the attribute site, so anything the user left
/// out is represented as `None` or as a missing map entry.
struct MacroArgs {
    /// Value of `name = "..."`, if it was supplied.
    name: Option<String>,
    /// Value of `description = "..."`, if it was supplied.
    description: Option<String>,
    /// Entries of `params(ident = "...", ...)`: the identifier written on the
    /// left-hand side mapped to the string literal on the right-hand side.
    param_descriptions: HashMap<String, String>,
}
impl Parse for MacroArgs {
fn parse(input: ParseStream) -> syn::Result<Self> {
let mut name = None;
let mut description = None;
let mut param_descriptions = HashMap::new();
let meta_list: Punctuated<Meta, Token![,]> = Punctuated::parse_terminated(input)?;
for meta in meta_list {
match meta {
Meta::NameValue(nv) => {
let ident = nv.path.get_ident().unwrap().to_string();
if let Expr::Lit(ExprLit {
lit: Lit::Str(lit_str),
..
}) = nv.value
{
match ident.as_str() {
"name" => name = Some(lit_str.value()),
"description" => description = Some(lit_str.value()),
_ => {}
}
}
}
Meta::List(list) if list.path.is_ident("params") => {
let nested: Punctuated<Meta, Token![,]> =
list.parse_args_with(Punctuated::parse_terminated)?;
for meta in nested {
if let Meta::NameValue(nv) = meta {
if let Expr::Lit(ExprLit {
lit: Lit::Str(lit_str),
..
}) = nv.value
{
let param_name = nv.path.get_ident().unwrap().to_string();
param_descriptions.insert(param_name, lit_str.value());
}
}
}
}
_ => {}
}
}
Ok(MacroArgs {
name,
description,
param_descriptions,
})
}
}
#[proc_macro_attribute]
pub fn tool(args: TokenStream, input: TokenStream) -> TokenStream {
let args = parse_macro_input!(args as MacroArgs);
let input_fn = parse_macro_input!(input as ItemFn);
// Extract function details
let fn_name = &input_fn.sig.ident;
let fn_name_str = fn_name.to_string();
// Generate PascalCase struct name from the function name
let struct_name = format_ident!("{}", { fn_name_str.to_case(Case::Pascal) });
// Use provided name or function name as default
let tool_name = args.name.unwrap_or(fn_name_str);
let tool_description = args.description.unwrap_or_default();
// Extract parameter names, types, and descriptions
let mut param_defs = Vec::new();
let mut param_names = Vec::new();
for arg in input_fn.sig.inputs.iter() {
if let FnArg::Typed(PatType { pat, ty, .. }) = arg {
if let Pat::Ident(param_ident) = &**pat {
let param_name = ¶m_ident.ident;
let param_name_str = param_name.to_string();
let description = args
.param_descriptions
.get(¶m_name_str)
.map(|s| s.as_str())
.unwrap_or("");
param_names.push(param_name);
param_defs.push(quote! {
#[schemars(description = #description)]
#param_name: #ty
});
}
}
}
// Generate the implementation
let params_struct_name = format_ident!("{}Parameters", struct_name);
let expanded = quote! {
#[derive(serde::Deserialize, schemars::JsonSchema)]
struct #params_struct_name {
#(#param_defs,)*
}
#input_fn
#[derive(Default)]
struct #struct_name;
#[async_trait::async_trait]
impl mcp_spec::handler::ToolHandler for #struct_name {
fn name(&self) -> &'static str {
#tool_name
}
fn description(&self) -> &'static str {
#tool_description
}
fn schema(&self) -> serde_json::Value {
mcp_spec::handler::generate_schema::<#params_struct_name>()
.expect("Failed to generate schema")
}
async fn call(&self, params: serde_json::Value) -> Result<serde_json::Value, mcp_spec::handler::ToolError> {
let params: #params_struct_name = serde_json::from_value(params)
.map_err(|e| mcp_spec::handler::ToolError::InvalidParameters(e.to_string()))?;
// Extract parameters and call the function
let result = #fn_name(#(params.#param_names,)*).await
.map_err(|e| mcp_spec::handler::ToolError::ExecutionError(e.to_string()))?;
Ok(serde_json::to_value(result).expect("should serialize"))
}
}
};
TokenStream::from(expanded)
}