Skip to content

Commit 8b60e3a

Browse files
ggerganov authored and tinglou committed
llama : minor grammar refactor (ggml-org#10897)
ggml-ci
1 parent b43996a commit 8b60e3a

File tree

5 files changed

+26
-37
lines changed

5 files changed

+26
-37
lines changed

examples/gbnf-validator/gbnf-validator.cpp

+4-7
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,15 @@
1111
static bool llama_grammar_validate(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) {
1212
const auto cpts = unicode_cpts_from_utf8(input_str);
1313

14-
const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
15-
llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
14+
auto & stacks_cur = llama_grammar_get_stacks(grammar);
1615

1716
size_t pos = 0;
1817
for (const auto & cpt : cpts) {
19-
const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy
20-
21-
llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur);
18+
llama_grammar_accept(grammar, cpt);
2219

2320
if (stacks_cur.empty()) {
2421
error_pos = pos;
2522
error_msg = "Unexpected character '" + unicode_cpt_to_utf8(cpt) + "'";
26-
stacks_cur = stacks_prev;
2723
return false;
2824
}
2925
++pos;
@@ -82,7 +78,8 @@ int main(int argc, char** argv) {
8278

8379
llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root");
8480
if (grammar == nullptr) {
85-
throw std::runtime_error("Failed to initialize llama_grammar");
81+
fprintf(stdout, "Failed to initialize llama_grammar\n");
82+
return 1;
8683
}
8784
// Read the input file
8885
std::string input_str;

src/llama-grammar.cpp

+15-15
Original file line numberDiff line numberDiff line change
@@ -822,15 +822,11 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar)
822822
return grammar->stacks;
823823
}
824824

825-
void llama_grammar_accept(
826-
const llama_grammar_rules & rules,
827-
const llama_grammar_stacks & stacks,
828-
const uint32_t chr,
829-
llama_grammar_stacks & stacks_new) {
830-
stacks_new.clear();
831-
stacks_new.reserve(stacks.size());
825+
void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr) {
826+
llama_grammar_stacks stacks_new;
827+
stacks_new.reserve(grammar->stacks.size());
832828

833-
for (const auto & stack : stacks) {
829+
for (const auto & stack : grammar->stacks) {
834830
if (stack.empty()) {
835831
continue;
836832
}
@@ -844,9 +840,11 @@ void llama_grammar_accept(
844840
if (!llama_grammar_is_end_of_sequence(pos)) {
845841
new_stack.push_back(pos);
846842
}
847-
llama_grammar_advance_stack(rules, new_stack, stacks_new);
843+
llama_grammar_advance_stack(grammar->rules, new_stack, stacks_new);
848844
}
849845
}
846+
847+
grammar->stacks = std::move(stacks_new);
850848
}
851849

852850
llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
@@ -1051,15 +1049,20 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) {
10511049
}
10521050

10531051
struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
1054-
llama_grammar * result = new llama_grammar { grammar.vocab, grammar.rules, grammar.stacks, grammar.partial_utf8, };
1052+
llama_grammar * result = new llama_grammar {
1053+
grammar.vocab,
1054+
grammar.rules,
1055+
grammar.stacks,
1056+
grammar.partial_utf8,
1057+
};
10551058

10561059
// redirect elements in stacks to point to new rules
10571060
for (size_t is = 0; is < result->stacks.size(); is++) {
10581061
for (size_t ie = 0; ie < result->stacks[is].size(); ie++) {
10591062
for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) {
10601063
for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) {
10611064
if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) {
1062-
result->stacks[is][ie] = &result->rules[ir0][ir1];
1065+
result->stacks[is][ie] = &result->rules[ir0][ir1];
10631066
}
10641067
}
10651068
}
@@ -1126,11 +1129,8 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
11261129
const auto decoded = decode_utf8(piece, grammar.partial_utf8);
11271130
const auto & code_points = decoded.first;
11281131

1129-
llama_grammar_stacks stacks_new;
1130-
11311132
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
1132-
llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new);
1133-
grammar.stacks = std::move(stacks_new);
1133+
llama_grammar_accept(&grammar, *it);
11341134
}
11351135

11361136
grammar.partial_utf8 = decoded.second;

src/llama-grammar.h

+2-5
Original file line numberDiff line numberDiff line change
@@ -58,18 +58,15 @@ using llama_grammar_rules = std::vector<llama_grammar_rule>;
5858
using llama_grammar_stacks = std::vector<llama_grammar_stack>;
5959
using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
6060

61+
// TODO: remove, needed for tests atm
6162
const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
6263
llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);
6364

6465
// takes a set of possible pushdown stacks on a grammar, which are required to
6566
// be positioned at a character range (see `llama_grammar_advance_stack`), and
6667
// produces the N possible stacks if the given char is accepted at those
6768
// positions
68-
void llama_grammar_accept(
69-
const llama_grammar_rules & rules,
70-
const llama_grammar_stacks & stacks,
71-
uint32_t chr,
72-
llama_grammar_stacks & stacks_new);
69+
void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr);
7370

7471
std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
7572
const llama_grammar_rules & rules,

tests/test-grammar-integration.cpp

+3-6
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,10 @@ static bool test_build_grammar_fails(const std::string & grammar_str) {
3232
static bool match_string(const std::string & input, llama_grammar * grammar) {
3333
const auto cpts = unicode_cpts_from_utf8(input);
3434

35-
const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
36-
llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
35+
auto & stacks_cur = llama_grammar_get_stacks(grammar);
3736

3837
for (const auto & cpt : cpts) {
39-
const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy
40-
41-
llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur);
38+
llama_grammar_accept(grammar, cpt);
4239

4340
if (stacks_cur.empty()) {
4441
// no stacks means that the grammar failed to match at this point
@@ -63,7 +60,7 @@ static void test(const std::string & test_desc, const std::string & grammar_str,
6360
auto * grammar = build_grammar(grammar_str);
6461

6562
// Save the original grammar stacks so that we can reset after every new string we want to test
66-
const llama_grammar_stacks stacks_org = llama_grammar_get_stacks(grammar);
63+
const llama_grammar_stacks stacks_org = llama_grammar_get_stacks(grammar); // copy
6764

6865
llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
6966

tests/test-llama-grammar.cpp

+2-4
Original file line numberDiff line numberDiff line change
@@ -113,12 +113,10 @@ int main()
113113
}
114114
}
115115

116-
llama_grammar * grammar = NULL;
117116
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
118117

119-
grammar = llama_grammar_init_impl(nullptr, grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
120-
if (grammar == nullptr)
121-
{
118+
llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
119+
if (grammar == nullptr) {
122120
throw std::runtime_error("Failed to initialize llama_grammar");
123121
}
124122

0 commit comments

Comments
 (0)