Skip to content

Commit c8843da

Browse files
committed
Use format to extract toolcalls
1 parent 787fa89 commit c8843da

File tree

3 files changed

+55
-73
lines changed

3 files changed

+55
-73
lines changed

examples/main/main.cpp

+25-55
Original file line numberDiff line numberDiff line change
@@ -106,23 +106,39 @@ class chat_formatter {
106106

107107
std::string operator () (const std::string & role, const std::string & content, [[maybe_unused]] bool use_toolcalls = false) {
108108

109-
common_chat_msg new_msg;
110-
new_msg.role = role;
111-
new_msg.content = content;
112-
113-
common_chat_params cparams;
114109
common_chat_templates_inputs cinputs;
110+
cinputs.use_jinja = params_.use_jinja;
111+
cinputs.add_generation_prompt = (role == "user");
115112
#ifdef LLAMA_USE_TOOLCALL
116113
if (tc_client_ != nullptr && use_toolcalls) {
117114
cinputs.tool_choice = common_chat_tool_choice_parse_oaicompat(tc_client_->tool_choice());
118115
cinputs.tools = common_chat_tools_parse_oaicompat(tc_client_->tool_list());
119116
}
120117
#endif
121-
bool add_ass = role == "user";
122-
auto formatted =
123-
common_chat_format_single(chat_templates_, chat_msgs_, new_msg, add_ass, params_.use_jinja,
124-
&cinputs, &cparams);
118+
for (const auto & msg : chat_msgs_) {
119+
cinputs.messages.push_back(common_chat_msg(msg));
120+
}
121+
122+
common_chat_msg new_msg = common_chat_parse(content, *chat_format_);
123+
new_msg.role = role;
124+
125+
if (! new_msg.tool_calls.empty()) {
126+
nlohmann::json result_array = nlohmann::json::array();
127+
for (const auto & tc : new_msg.tool_calls) {
128+
toolcall::result_set res = tc_client_->call(tc.name, tc.arguments, tc.id);
129+
if (! res.empty()) {
130+
for (const auto & r : res) {
131+
result_array.push_back(r.data);
132+
}
133+
}
134+
}
135+
new_msg.content += result_array.dump(-1);
136+
}
137+
138+
cinputs.messages.push_back(new_msg);
139+
common_chat_params cparams = common_chat_templates_apply(chat_templates_, cinputs);
125140

141+
auto formatted = cparams.prompt;
126142
chat_msgs_.push_back(new_msg);
127143
LOG_DBG("formatted: '%s'\n", formatted.c_str());
128144

@@ -145,42 +161,6 @@ class chat_formatter {
145161
#endif
146162
};
147163

148-
#ifdef LLAMA_USE_TOOLCALL
149-
static bool call_tool(common_chat_format chat_format, const std::string & assistant_msg, llama_context * ctx,
150-
toolcall::client::ptr tc_client, std::vector<llama_token> & embd_inp)
151-
{
152-
bool tool_was_called = false;
153-
common_chat_msg msg = common_chat_parse(assistant_msg, chat_format);
154-
if (! msg.tool_calls.empty()) {
155-
for (const auto & tc : msg.tool_calls) {
156-
nlohmann::json tc_oai_json {
157-
{"type", "function"},
158-
{"function", {
159-
{"name", tc.name},
160-
{"arguments", tc.arguments},
161-
}},
162-
{"id", tc.id},
163-
};
164-
toolcall::result_set res = tc_client->call(tc_oai_json);
165-
if (! res.empty()) {
166-
std::string toolcall_result_str;
167-
for (const auto & r : res) {
168-
toolcall_result_str += ("\n" + r.data); // Although more complex results can be
169-
// returned (resources, images, etc.),
170-
// for now simply append the data. Later
171-
// on support for specific models may
172-
// allow for unpacking Base64 data.
173-
}
174-
auto toolcall_result_tok = common_tokenize(ctx, toolcall_result_str, false, true);
175-
embd_inp.insert(embd_inp.end(), toolcall_result_tok.begin(), toolcall_result_tok.end());
176-
}
177-
tool_was_called = true;
178-
}
179-
}
180-
return tool_was_called;
181-
}
182-
#endif
183-
184164
int main(int argc, char ** argv) {
185165
common_params params;
186166
g_params = &params;
@@ -943,16 +923,6 @@ int main(int argc, char ** argv) {
943923
}
944924
}
945925

946-
#ifdef LLAMA_USE_TOOLCALL
947-
if ((tc_client && n_past > 0) && (waiting_for_first_input || is_interacting)) {
948-
size_t last_len = embd_inp.size();
949-
bool was_toolcall = call_tool(chat_format, assistant_ss.str(), ctx, tc_client, embd_inp);
950-
if (was_toolcall && last_len < embd_inp.size()) {
951-
LOG("%s", common_token_to_piece(ctx, embd_inp[last_len]).c_str());
952-
}
953-
}
954-
#endif
955-
956926
if ((n_past > 0 || waiting_for_first_input) && is_interacting) {
957927
LOG_DBG("waiting for user input\n");
958928

toolcall/client.cpp

+16-13
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,10 @@ bool toolcall::client::tool_list_dirty() const {
3939
return impl_->tool_list_dirty();
4040
}
4141

42-
toolcall::result_set toolcall::client::call(const std::string & request) {
43-
return impl_->call(request);
42+
toolcall::result_set toolcall::client::call(const std::string & name,
43+
const std::string & arguments,
44+
const std::string & id) {
45+
return impl_->call(name, arguments, id);
4446
}
4547

4648
const std::string & toolcall::client::tool_choice() const {
@@ -180,15 +182,6 @@ std::string toolcall::mcp_impl::tool_list() {
180182
return tools_;
181183
}
182184

183-
static mcp::tools_call_request tools_call_request_from_local_json(nlohmann::json id, const std::string & local_json) {
184-
nlohmann::json j = json::parse(local_json);
185-
mcp::tool_arg_list args;
186-
for (const auto & [key, val] : j["parameters"].items()) {
187-
args.push_back({key, val});
188-
}
189-
return mcp::tools_call_request(id, j["name"], args);
190-
}
191-
192185
static toolcall::result_set tools_call_response_to_result(const mcp::tools_call_response & resp) {
193186
toolcall::result_set result;
194187
for (const auto & res : resp.tool_result()) {
@@ -199,7 +192,10 @@ static toolcall::result_set tools_call_response_to_result(const mcp::tools_call_
199192
return std::move(result);
200193
}
201194

202-
toolcall::result_set toolcall::mcp_impl::call(const std::string & request) {
195+
toolcall::result_set toolcall::mcp_impl::call(const std::string & name,
196+
const std::string & arguments,
197+
const std::string & id)
198+
{
203199
using on_response = toolcall::callback<mcp::tools_call_response>;
204200

205201
if (transport_ == nullptr) {
@@ -213,7 +209,14 @@ toolcall::result_set toolcall::mcp_impl::call(const std::string & request) {
213209
response = tools_call_response_to_result(resp);
214210
tools_populating_.notify_one();
215211
};
216-
transport_->send(tools_call_request_from_local_json(next_id_++, request), set_response);
212+
std::string req_id = id.empty() ? std::to_string(next_id_++) : id;
213+
mcp::tool_arg_list req_args;
214+
auto json_args = json::parse(arguments); // TODO check errors
215+
for (const auto & [key, val] : json_args.items()) {
216+
req_args.push_back({key, val});
217+
}
218+
219+
transport_->send(mcp::tools_call_request(req_id, name, req_args), set_response);
217220
tools_populating_.wait_for(lock, std::chrono::seconds(15), [&response] { return ! response.empty(); });
218221

219222
return response;

toolcall/toolcall-client.h

+14-5
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ namespace toolcall
2727

2828
client(std::unique_ptr<client_impl> impl) : impl_(std::move(impl)) {}
2929

30-
result_set call(const std::string & request);
30+
result_set call(const std::string & name,
31+
const std::string & arguments,
32+
const std::string & id = "");
3133

3234
std::string tool_list();
3335
bool tool_list_dirty() const;
@@ -55,7 +57,9 @@ namespace toolcall
5557
return tool_list_dirty_;
5658
}
5759

58-
virtual result_set call(const std::string & request) = 0;
60+
virtual result_set call(const std::string & name,
61+
const std::string & arguments,
62+
const std::string & id = "") = 0;
5963

6064
const std::string & tool_choice() const { return tool_choice_; }
6165

@@ -76,9 +80,11 @@ namespace toolcall
7680
return tools_;
7781
}
7882

79-
virtual result_set call(const std::string & request) override {
83+
virtual result_set call(const std::string & /* name */,
84+
const std::string & /* arguments */,
85+
const std::string & /* id = "" */) override {
8086
return result_set {
81-
{"text", request, "text/plain", std::nullopt, false}
87+
{"text", "", "text/plain", std::nullopt, false}
8288
};
8389
}
8490

@@ -93,7 +99,10 @@ namespace toolcall
9399
mcp_impl(std::vector<std::string> argv, std::string tool_choice);
94100

95101
virtual std::string tool_list() override;
96-
virtual result_set call(const std::string & request) override;
102+
103+
virtual result_set call(const std::string & name,
104+
const std::string & arguments,
105+
const std::string & id = "") override;
97106

98107
virtual void initialize() override;
99108

0 commit comments

Comments
 (0)