Skip to content

Commit cbaadc9

Browse files
authored
grammars: 1.5x faster inference w/ complex grammars (vector reserves / reuses) (ggml-org#6609)
* grammars: reserve rejects & next candidates
* grammars: reuse new_stacks
* grammars: fix missing sig change in llama.h
* grammars: fix test (api changed)
* grammars: update gbnf-validator.cpp
* grammars: simpler syntax (no swap)
1 parent 1bbdaf6 commit cbaadc9

File tree

4 files changed

+17
-12
lines changed

4 files changed

+17
-12
lines changed

examples/gbnf-validator/gbnf-validator.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ static bool llama_sample_grammar_string(struct llama_grammar * grammar, const st
1717
size_t pos = 0;
1818
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
1919
auto prev_stacks = grammar->stacks;
20-
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
20+
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
2121
if (grammar->stacks.empty()) {
2222
error_pos = pos;
2323
error_msg = "Unexpected character '" + unicode_cpt_to_utf8(*it) + "'";

llama.cpp

+10-6
Original file line numberDiff line numberDiff line change
@@ -11912,12 +11912,13 @@ static void llama_grammar_advance_stack(
1191211912
// be positioned at a character range (see `llama_grammar_advance_stack`), and
1191311913
// produces the N possible stacks if the given char is accepted at those
1191411914
// positions
11915-
std::vector<std::vector<const llama_grammar_element *>> llama_grammar_accept(
11915+
void llama_grammar_accept(
1191611916
const std::vector<std::vector<llama_grammar_element>> & rules,
1191711917
const std::vector<std::vector<const llama_grammar_element *>> & stacks,
11918-
const uint32_t chr) {
11918+
const uint32_t chr,
11919+
std::vector<std::vector<const llama_grammar_element *>> & new_stacks) {
1191911920

11920-
std::vector<std::vector<const llama_grammar_element *>> new_stacks;
11921+
new_stacks.clear();
1192111922

1192211923
for (const auto & stack : stacks) {
1192311924
if (stack.empty()) {
@@ -11936,8 +11937,6 @@ std::vector<std::vector<const llama_grammar_element *>> llama_grammar_accept(
1193611937
llama_grammar_advance_stack(rules, new_stack, new_stacks);
1193711938
}
1193811939
}
11939-
11940-
return new_stacks;
1194111940
}
1194211941

1194311942
static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates(
@@ -11951,6 +11950,7 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_
1195111950
const std::vector<llama_grammar_candidate> & candidates) {
1195211951

1195311952
std::vector<llama_grammar_candidate> rejects;
11953+
rejects.reserve(candidates.size());
1195411954

1195511955
if (stack.empty()) {
1195611956
for (const auto & tok : candidates) {
@@ -11964,6 +11964,8 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_
1196411964
const llama_grammar_element * stack_pos = stack.back();
1196511965

1196611966
std::vector<llama_grammar_candidate> next_candidates;
11967+
next_candidates.reserve(candidates.size());
11968+
1196711969
for (const auto & tok : candidates) {
1196811970
if (*tok.code_points == 0) {
1196911971
// reached end of full codepoints in token, reject iff it ended in a partial sequence
@@ -12771,8 +12773,10 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
1277112773
// Note terminating 0 in decoded string
1277212774
const auto decoded = decode_utf8(piece, grammar->partial_utf8);
1277312775
const auto & code_points = decoded.first;
12776+
std::vector<std::vector<const llama_grammar_element *>> tmp_new_stacks;
1277412777
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
12775-
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
12778+
llama_grammar_accept(grammar->rules, grammar->stacks, *it, tmp_new_stacks);
12779+
grammar->stacks = tmp_new_stacks;
1277612780
}
1277712781
grammar->partial_utf8 = decoded.second;
1277812782
GGML_ASSERT(!grammar->stacks.empty());

llama.h

+3-2
Original file line numberDiff line numberDiff line change
@@ -1097,10 +1097,11 @@ const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal
10971097
struct llama_context * ctx
10981098
);
10991099

1100-
std::vector<std::vector<const llama_grammar_element *>> llama_grammar_accept(
1100+
void llama_grammar_accept(
11011101
const std::vector<std::vector<llama_grammar_element>> & rules,
11021102
const std::vector<std::vector<const llama_grammar_element *>> & stacks,
1103-
const uint32_t chr);
1103+
const uint32_t chr,
1104+
std::vector<std::vector<const llama_grammar_element *>> & new_stacks);
11041105

11051106
std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
11061107
const std::string & src,

tests/test-grammar-integration.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ number ::= [0-9]+)""";
3838

3939
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
4040
auto prev_stacks = grammar->stacks;
41-
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
41+
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
4242
assert(!grammar->stacks.empty());
4343
}
4444

@@ -138,7 +138,7 @@ ws ::= [ \t\n\r]?)""";
138138
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
139139
++pos;
140140
auto prev_stacks = grammar->stacks;
141-
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
141+
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
142142

143143
// Expect that each code point will not cause the grammar to fail
144144
if (grammar->stacks.empty()) {
@@ -173,7 +173,7 @@ ws ::= [ \t\n\r]?)""";
173173

174174
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
175175
auto prev_stacks = grammar->stacks;
176-
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
176+
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
177177
if (grammar->stacks.empty()) {
178178
parse_failed = true;
179179
break;

0 commit comments

Comments
 (0)