tool-call: fix Qwen 2.5 Coder support, add micro benchmarks, support trigger patterns for lazy grammars #12034

Merged (48 commits, Mar 5, 2025)

Commits
b37779b
sampler: turn lazy grammar trigger words to regexes
ochafik Feb 21, 2025
a456911
add scripts/tool_bench.sh & .py
ochafik Feb 21, 2025
14a4388
optionally allow any spaces in json schema grammars (useful for llama…
ochafik Feb 22, 2025
e2ca8be
constrain llama json output regardless of function name if matches at…
ochafik Feb 22, 2025
53266f9
better error when wrong function called
ochafik Feb 22, 2025
7833c16
improve error message in weather test
ochafik Feb 22, 2025
0e1a00e
add more models to tool_bench.sh
ochafik Feb 22, 2025
44740f7
benchmark other sizes of qwen 2.5 coder
ochafik Feb 23, 2025
dd6eb97
rm duplicate in tool_bench.sh
ochafik Feb 23, 2025
0fc6218
add missing <variant> include
ochafik Feb 23, 2025
6fd4972
fix lints
ochafik Feb 23, 2025
2e656f9
improve "bad" qwen triggers
ochafik Feb 23, 2025
fbd3c19
add cast to please some gccs
ochafik Feb 23, 2025
62a1416
ditch server test request retry logic
ochafik Feb 23, 2025
596ff7f
fix flake8 lints
ochafik Feb 23, 2025
fe6968f
nits
ochafik Feb 23, 2025
1caacd5
remove any_spaces grammar option, allow extra line for airy llama jso…
ochafik Feb 23, 2025
789a3e1
Update test_tool_call.py
ochafik Feb 23, 2025
6493a14
test w/ beefier qwen 2.5 coder 3b
ochafik Feb 23, 2025
cc817a0
revert some test_hello_world diffs
ochafik Feb 23, 2025
ead02c6
diff
ochafik Feb 23, 2025
d7acf2c
Update test_tool_call.py
ochafik Feb 23, 2025
0db4073
add requirements for tool_bench
ochafik Feb 23, 2025
0ce606b
fix test_thoughts deepseek test expectation
ochafik Feb 23, 2025
a3cde16
Update README.md
ochafik Feb 23, 2025
79ad623
update relaxed newline space rule in grammar tests
ochafik Feb 23, 2025
3fe208a
support add_generation_prompt query parameter (useful for /apply_temp…
ochafik Feb 25, 2025
fe8c79b
Merge remote-tracking branch 'origin/master' into tool-bench-prod
ochafik Feb 25, 2025
99d2d80
token cast tweak for gcc
ochafik Feb 25, 2025
c7fa19a
fix warning on gcc13 w/ uninitialized variant
ochafik Feb 25, 2025
6e5a830
fix python lints
ochafik Feb 25, 2025
0b5d105
fix gcc13 warning
ochafik Feb 25, 2025
7bcc5af
fix pyright lints in tool_bench.py
ochafik Feb 25, 2025
d1f48d0
Merge remote-tracking branch 'origin/master' into tool-bench-prod
ochafik Feb 25, 2025
fc19192
update readme w/ link to tool call
ochafik Feb 27, 2025
60f28ef
tool-bench: add --ctk, --ctv, --fa flags
ochafik Feb 27, 2025
2470a1c
Merge remote-tracking branch 'origin/master' into tool-bench-prod
ochafik Mar 4, 2025
e6e9c13
common_grammar_trigger: always use string value (+ optional token)
ochafik Feb 27, 2025
5d43b72
add llama_grammar_trigger_pattern
ochafik Mar 4, 2025
1317a35
add common_grammar_trigger.{to_json,from_json}
ochafik Mar 5, 2025
ad3caa3
fix crashing typo
ochafik Mar 5, 2025
a6d7887
avoid returning optional from parse_json
ochafik Mar 5, 2025
20a2f5f
disable slow hello Llama-3.1-8B (chopped unescaped string witin strin…
ochafik Mar 5, 2025
92e9723
fix nit eol at eof
ochafik Mar 5, 2025
01be080
Update src/llama-grammar.cpp
ochafik Mar 5, 2025
00db465
Merge remote-tracking branch 'origin/master' into tool-bench-prod
ochafik Mar 5, 2025
24010fe
avoid ggml_assert in server for grammar triggers inconsistency
ochafik Mar 5, 2025
71719a6
add comment on limits to common_grammar_trigger.to/from json speciali…
ochafik Mar 5, 2025
2 changes: 1 addition & 1 deletion README.md
@@ -25,7 +25,7 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)

- **How to use [MTLResidencySet](https://developer.apple.com/documentation/metal/mtlresidencyset?language=objc) to keep the GPU memory active?** https://github.com/ggml-org/llama.cpp/pull/11427
- **VS Code extension for FIM completions:** https://github.com/ggml-org/llama.vscode
- Universal tool call support in `llama-server`: https://github.com/ggml-org/llama.cpp/pull/9639
- Universal [tool call support](./docs/function-calling.md) in `llama-server` https://github.com/ggml-org/llama.cpp/pull/9639
- Vim/Neovim plugin for FIM completions: https://github.com/ggml-org/llama.vim
- Introducing GGUF-my-LoRA https://github.com/ggml-org/llama.cpp/discussions/10123
- Hugging Face Inference Endpoints now support GGUF out of the box! https://github.com/ggml-org/llama.cpp/discussions/9669
439 changes: 300 additions & 139 deletions common/chat.cpp

Large diffs are not rendered by default.

28 changes: 27 additions & 1 deletion common/common.cpp
@@ -10,7 +10,6 @@
// Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT
#include "json.hpp"
#include "json-schema-to-grammar.h"
#include "llama.h"

#include <algorithm>
@@ -483,6 +482,11 @@ void string_replace_all(std::string & s, const std::string & search, const std::
s = std::move(builder);
}

std::string regex_escape(const std::string & s) {
static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
return std::regex_replace(s, special_chars, "\\$0");
}

std::string string_join(const std::vector<std::string> & values, const std::string & separator) {
std::ostringstream result;
for (size_t i = 0; i < values.size(); ++i) {
@@ -2026,3 +2030,25 @@ common_control_vector_data common_control_vector_load(const std::vector<common_c
return result;
}

template <>
json common_grammar_trigger::to_json() const {
json out {
{"type", (int) type},
{"value", value},
};
if (type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
out["token"] = (int) token;
}
return out;
}

template <>
common_grammar_trigger common_grammar_trigger::from_json(const json & in) {
common_grammar_trigger out;
out.type = (common_grammar_trigger_type) in.at("type").get<int>();
out.value = in.at("value").get<std::string>();
if (out.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
out.token = (llama_token) in.at("token").get<int>();
}
return out;
}
21 changes: 17 additions & 4 deletions common/common.h
@@ -110,9 +110,21 @@ enum common_conversation_mode {
COMMON_CONVERSATION_MODE_AUTO = 2,
};

enum common_grammar_trigger_type {
COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN,
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
};

struct common_grammar_trigger {
std::string word;
bool at_start;
common_grammar_trigger_type type;
std::string value;
llama_token token = LLAMA_TOKEN_NULL;

// T can only be nlohmann::ordered_json
template <class T> T to_json() const;
Collaborator:
Hmm ok I didn't notice that we cannot include json in this file. Then maybe change it to:

template <class T> T to() const;

Then use with to<json> ?

Collaborator (author):

While it does read better at the call site, I think naming it just `to` makes the interface harder for readers to understand, esp. given the unconventional template use (if anything, I would name it serialize / deserialize, or provide operator<< / operator>>). Happy to revisit in a follow-up / to batch-update all the to_json* to something :-)

template <class T> static common_grammar_trigger from_json(const T & in);
};

// sampling parameters
@@ -163,8 +175,7 @@ struct common_params_sampling {

std::string grammar; // optional BNF-like grammar to constrain sampling
bool grammar_lazy = false;
std::vector<common_grammar_trigger> grammar_trigger_words; // optional trigger words to trigger lazy grammar
std::vector<llama_token> grammar_trigger_tokens; // optional trigger tokens to trigger lazy grammar and print trigger special tokens.
std::vector<common_grammar_trigger> grammar_triggers; // optional triggers (for lazy grammars)
std::set<llama_token> preserved_tokens;

std::vector<llama_logit_bias> logit_bias; // logit biases to apply
@@ -458,6 +469,8 @@ std::string string_repeat(const std::string & str, size_t n);

void string_replace_all(std::string & s, const std::string & search, const std::string & replace);

std::string regex_escape(const std::string & s);

template<class T>
static std::vector<T> string_split(const std::string & str, char delim) {
static_assert(!std::is_same<T, std::string>::value, "Please use the specialized version for std::string");
9 changes: 4 additions & 5 deletions common/json-schema-to-grammar.cpp
@@ -264,7 +264,7 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
throw std::runtime_error("At least one of min_value or max_value must be set");
}

const std::string SPACE_RULE = "| \" \" | \"\\n\" [ \\t]{0,20}";
const std::string SPACE_RULE = "| \" \" | \"\\n\"{1,2} [ \\t]{0,20}";

struct BuiltinRule {
std::string content;
@@ -764,11 +764,10 @@ class SchemaConverter {
public:
SchemaConverter(
const std::function<json(const std::string &)> & fetch_json,
bool dotall,
bool compact_spaces)
bool dotall)
: _fetch_json(fetch_json), _dotall(dotall)
{
_rules["space"] = compact_spaces ? "\" \"?" : SPACE_RULE;
_rules["space"] = SPACE_RULE;
}

void resolve_refs(json & schema, const std::string & url) {
@@ -1007,7 +1006,7 @@ std::string json_schema_to_grammar(const json & schema, bool force_gbnf) {
}

std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options) {
SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall, options.compact_spaces);
SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall);
common_grammar_builder builder {
/* .add_rule = */ [&](const std::string & name, const std::string & rule) {
return converter._add_rule(name, rule);
1 change: 0 additions & 1 deletion common/json-schema-to-grammar.h
@@ -16,7 +16,6 @@ struct common_grammar_builder {

struct common_grammar_options {
bool dotall = false;
bool compact_spaces = false;
};

std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options = {});
51 changes: 44 additions & 7 deletions common/sampling.cpp
@@ -160,16 +160,53 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
#endif // LLAMA_USE_LLGUIDANCE
} else {
std::vector<const char *> trigger_words;
trigger_words.reserve(params.grammar_trigger_words.size());
for (const auto & str : params.grammar_trigger_words) {
trigger_words.push_back(str.word.c_str());
std::vector<std::string> patterns_at_start;
std::vector<std::string> patterns_anywhere;
std::vector<llama_token> trigger_tokens;
for (const auto & trigger : params.grammar_triggers) {
switch (trigger.type) {
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
{
const auto & word = trigger.value;
patterns_anywhere.push_back(regex_escape(word));
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START:
{
const auto & pattern = trigger.value;
(trigger.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START ? patterns_at_start : patterns_anywhere).push_back(pattern);
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
{
const auto token = trigger.token;
trigger_tokens.push_back(token);
break;
}
default:
GGML_ASSERT(false && "unknown trigger type");
}
}

std::vector<std::string> trigger_patterns;
if (!patterns_at_start.empty()) {
trigger_patterns.push_back("^(" + string_join(patterns_at_start, "|") + ")[\\s\\S]*");
}
if (!patterns_anywhere.empty()) {
trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*");
}

std::vector<const char *> trigger_patterns_c;
trigger_patterns_c.reserve(trigger_patterns.size());
for (const auto & regex : trigger_patterns) {
trigger_patterns_c.push_back(regex.c_str());
}

grmr = params.grammar_lazy
? llama_sampler_init_grammar_lazy(vocab, params.grammar.c_str(), "root",
trigger_words.data(), trigger_words.size(),
params.grammar_trigger_tokens.data(), params.grammar_trigger_tokens.size())
? llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
trigger_patterns_c.data(), trigger_patterns_c.size(),
trigger_tokens.data(), trigger_tokens.size())
: llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
}

2 changes: 1 addition & 1 deletion examples/json_schema_to_grammar.py
@@ -195,7 +195,7 @@ def __init__(self, content: str, deps: list | None = None):
self.deps = deps or []

# Constraining spaces to prevent model "running away".
SPACE_RULE = '| " " | "\\n" [ \\t]{0,20}'
SPACE_RULE = '| " " | "\\n"{1,2} [ \\t]{0,20}'

PRIMITIVE_RULES = {
'boolean' : BuiltinRule('("true" | "false") space', []),
2 changes: 1 addition & 1 deletion examples/server/public_legacy/json-schema-to-grammar.mjs
@@ -1,5 +1,5 @@
// WARNING: This file was ported from json_schema_to_grammar.py, please fix bugs / add features there first.
const SPACE_RULE = '| " " | "\\n" [ \\t]{0,20}';
const SPACE_RULE = '| " " | "\\n"{1,2} [ \\t]{0,20}';

function _buildRepetition(itemRule, minItems, maxItems, opts={}) {
if (minItems === 0 && maxItems === 1) {
62 changes: 35 additions & 27 deletions examples/server/server.cpp
@@ -131,9 +131,9 @@ struct slot_params {
lora.push_back({{"id", i}, {"scale", this->lora[i].scale}});
}

std::vector<std::string> grammar_trigger_words;
for (const auto & trigger : sampling.grammar_trigger_words) {
grammar_trigger_words.push_back(trigger.word);
auto grammar_triggers = json::array();
for (const auto & trigger : sampling.grammar_triggers) {
grammar_triggers.push_back(trigger.to_json<json>());
}

return json {
@@ -170,8 +170,8 @@
{"n_probs", sampling.n_probs},
{"min_keep", sampling.min_keep},
{"grammar", sampling.grammar},
{"grammar_trigger_words", grammar_trigger_words},
{"grammar_trigger_tokens", sampling.grammar_trigger_tokens},
{"grammar_lazy", sampling.grammar_lazy},
{"grammar_triggers", grammar_triggers},
{"preserved_tokens", sampling.preserved_tokens},
{"chat_format", common_chat_format_name(oaicompat_chat_format)},
{"samplers", samplers},
@@ -356,24 +356,6 @@ struct server_task {
}

{
const auto grammar_triggers = data.find("grammar_triggers");
if (grammar_triggers != data.end()) {
for (const auto & t : *grammar_triggers) {
common_grammar_trigger trigger;
trigger.word = t.at("word");
trigger.at_start = t.at("at_start");

auto ids = common_tokenize(vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
if (ids.size() == 1) {
SRV_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
params.sampling.grammar_trigger_tokens.push_back(ids[0]);
params.sampling.preserved_tokens.insert(ids[0]);
continue;
}
SRV_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
params.sampling.grammar_trigger_words.push_back(trigger);
}
}
const auto preserved_tokens = data.find("preserved_tokens");
if (preserved_tokens != data.end()) {
for (const auto & t : *preserved_tokens) {
@@ -383,12 +365,38 @@ struct server_task {
params.sampling.preserved_tokens.insert(ids[0]);
} else {
// This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens.
SRV_WRN("Not preserved because more than 1 token (wrong chat template override?): %s\n", t.get<std::string>().c_str());
SRV_DBG("Not preserved because more than 1 token: %s\n", t.get<std::string>().c_str());
}
}
}
const auto grammar_triggers = data.find("grammar_triggers");
if (grammar_triggers != data.end()) {
for (const auto & t : *grammar_triggers) {
auto ct = common_grammar_trigger::from_json(t);
if (ct.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
const auto & word = ct.value;
auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true);
if (ids.size() == 1) {
auto token = ids[0];
if (std::find(params.sampling.preserved_tokens.begin(), params.sampling.preserved_tokens.end(), (llama_token) token) == params.sampling.preserved_tokens.end()) {
throw std::runtime_error("Grammar trigger word should be marked as preserved token: " + word);
}
SRV_DBG("Grammar trigger token: %d (`%s`)\n", token, word.c_str());
common_grammar_trigger trigger;
trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN;
trigger.value = (llama_token) token;
params.sampling.grammar_triggers.push_back(trigger);
} else {
SRV_DBG("Grammar trigger word: `%s`\n", word.c_str());
params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word});
}
} else {
params.sampling.grammar_triggers.push_back(ct);
}
}
}
if (params.sampling.grammar_lazy) {
GGML_ASSERT(params.sampling.grammar_trigger_tokens.size() > 0 || params.sampling.grammar_trigger_words.size() > 0);
if (params.sampling.grammar_lazy && params.sampling.grammar_triggers.empty()) {
throw std::runtime_error("Error: no triggers set for lazy grammar!");
}
}

@@ -2045,7 +2053,7 @@ struct server_context {

if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
// Might be better to reject the request with a 400 ?
SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.params.n_predict, slot.n_predict);
SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d\n", slot.params.n_predict, slot.n_predict);
slot.params.n_predict = slot.n_predict;
}

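As a hedged illustration of the new request shape (the values below are hypothetical; the `type` integers follow the `common_grammar_trigger_type` enum order from common/common.h, e.g. 0 = token, 1 = word, 2 = pattern), a lazy-grammar request body to `llama-server` would now look roughly like this, replacing the former `grammar_trigger_words` / `grammar_trigger_tokens` pair:

```
{
  "grammar_lazy": true,
  "grammar": "root ::= tool-call",
  "grammar_triggers": [
    { "type": 1, "value": "<tool_call>" },
    { "type": 2, "value": "\\s*\\{\\s*\"name\"" }
  ],
  "preserved_tokens": ["<tool_call>"]
}
```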