Skip to content

Commit 34fc44d

Browse files
Merge pull request #1 from ggerganov/gg/grammar-refactor
llama : minor llama_grammar refactoring
2 parents 2aa6dd2 + 17b3a3e commit 34fc44d

File tree

5 files changed

+33
-51
lines changed

5 files changed

+33
-51
lines changed

examples/gbnf-validator/gbnf-validator.cpp

+4-8
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,15 @@
1111
static bool llama_grammar_validate(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) {
1212
const auto cpts = unicode_cpts_from_utf8(input_str);
1313

14-
const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
15-
llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
16-
llama_grammar_stacks_cache & stacks_cache = llama_grammar_get_stacks_cache(grammar);
14+
auto & stacks_cur = llama_grammar_get_stacks(grammar);
1715

1816
size_t pos = 0;
1917
for (const auto & cpt : cpts) {
20-
const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy
21-
22-
llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur, stacks_cache);
18+
llama_grammar_accept(grammar, cpt);
2319

2420
if (stacks_cur.empty()) {
2521
error_pos = pos;
2622
error_msg = "Unexpected character '" + unicode_cpt_to_utf8(cpt) + "'";
27-
stacks_cur = stacks_prev;
2823
return false;
2924
}
3025
++pos;
@@ -83,7 +78,8 @@ int main(int argc, char** argv) {
8378

8479
llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root");
8580
if (grammar == nullptr) {
86-
throw std::runtime_error("Failed to initialize llama_grammar");
81+
fprintf(stdout, "Failed to initialize llama_grammar\n");
82+
return 1;
8783
}
8884
// Read the input file
8985
std::string input_str;

src/llama-grammar.cpp

+19-23
Original file line numberDiff line numberDiff line change
@@ -764,7 +764,7 @@ static void llama_grammar_advance_stack_memo(
764764
if (it != stacks_cache.end()) {
765765
advanced_stacks = it->second;
766766
} else {
767-
// Advance stacks with memoization
767+
// Advance stacks with memoization
768768
llama_grammar_advance_stack_memo_impl(rules, stack, advanced_stacks, stacks_cache);
769769
stacks_cache.insert(make_pair(stack, advanced_stacks));
770770
}
@@ -917,20 +917,11 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar)
917917
return grammar->stacks;
918918
}
919919

920-
llama_grammar_stacks_cache & llama_grammar_get_stacks_cache(struct llama_grammar * grammar) {
921-
return grammar->stacks_cache;
922-
}
923-
924-
void llama_grammar_accept(
925-
const llama_grammar_rules & rules,
926-
const llama_grammar_stacks & stacks,
927-
const uint32_t chr,
928-
llama_grammar_stacks & stacks_new,
929-
llama_grammar_stacks_cache & stacks_cache) {
930-
stacks_new.clear();
931-
stacks_new.reserve(stacks.size());
920+
void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr) {
921+
llama_grammar_stacks stacks_new;
922+
stacks_new.reserve(grammar->stacks.size());
932923

933-
for (const auto & stack : stacks) {
924+
for (const auto & stack : grammar->stacks) {
934925
if (stack.empty()) {
935926
continue;
936927
}
@@ -944,9 +935,11 @@ void llama_grammar_accept(
944935
if (!llama_grammar_is_end_of_sequence(pos)) {
945936
new_stack.push_back(pos);
946937
}
947-
llama_grammar_advance_stack_memo(rules, new_stack, stacks_new, stacks_cache);
938+
llama_grammar_advance_stack_memo(grammar->rules, new_stack, stacks_new, grammar->stacks_cache);
948939
}
949940
}
941+
942+
grammar->stacks = std::move(stacks_new);
950943
}
951944

952945
llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
@@ -1062,7 +1055,7 @@ struct llama_grammar * llama_grammar_init_impl(
10621055
// Important: vec_rules has to be moved here, not copied, because stacks contains
10631056
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
10641057
// then the pointers would be invalidated when the local vec_rules goes out of scope.
1065-
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, std::move(stacks_cache), };
1058+
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), std::move(stacks_cache), {}, };
10661059
}
10671060

10681061
struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
@@ -1141,7 +1134,7 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab,
11411134
// Important: vec_rules has to be moved here, not copied, because stacks contains
11421135
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
11431136
// then the pointers would be invalidated when the local vec_rules goes out of scope.
1144-
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, std::move(stacks_cache), };
1137+
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), std::move(stacks_cache), {}, };
11451138
}
11461139

11471140
void llama_grammar_free_impl(struct llama_grammar * grammar) {
@@ -1153,15 +1146,21 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) {
11531146
}
11541147

11551148
struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
1156-
llama_grammar * result = new llama_grammar { grammar.vocab, grammar.rules, grammar.stacks, grammar.partial_utf8, };
1149+
llama_grammar * result = new llama_grammar {
1150+
grammar.vocab,
1151+
grammar.rules,
1152+
grammar.stacks,
1153+
grammar.stacks_cache,
1154+
grammar.partial_utf8,
1155+
};
11571156

11581157
// redirect elements in stacks to point to new rules
11591158
for (size_t is = 0; is < result->stacks.size(); is++) {
11601159
for (size_t ie = 0; ie < result->stacks[is].size(); ie++) {
11611160
for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) {
11621161
for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) {
11631162
if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) {
1164-
result->stacks[is][ie] = &result->rules[ir0][ir1];
1163+
result->stacks[is][ie] = &result->rules[ir0][ir1];
11651164
}
11661165
}
11671166
}
@@ -1228,11 +1227,8 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
12281227
const auto decoded = decode_utf8(piece, grammar.partial_utf8);
12291228
const auto & code_points = decoded.first;
12301229

1231-
llama_grammar_stacks stacks_new;
1232-
12331230
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
1234-
llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new, grammar.stacks_cache);
1235-
grammar.stacks = std::move(stacks_new);
1231+
llama_grammar_accept(&grammar, *it);
12361232
}
12371233

12381234
grammar.partial_utf8 = decoded.second;

src/llama-grammar.h

+5-9
Original file line numberDiff line numberDiff line change
@@ -71,20 +71,15 @@ struct VectorPointerHash {
7171

7272
using llama_grammar_stacks_cache = std::unordered_map<llama_grammar_stack, llama_grammar_stacks, VectorPointerHash>;
7373

74+
// TODO: remove, needed for tests atm
7475
const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
7576
llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);
76-
llama_grammar_stacks_cache & llama_grammar_get_stacks_cache( struct llama_grammar * grammar);
7777

7878
// takes the grammar's current set of possible pushdown stacks, which are required to
7979
// be positioned at a character range (see `llama_grammar_advance_stack`), and
8080
// replaces them in place with the N possible stacks if the given char is accepted
8181
// at those positions
82-
void llama_grammar_accept(
83-
const llama_grammar_rules & rules,
84-
const llama_grammar_stacks & stacks,
85-
uint32_t chr,
86-
llama_grammar_stacks & stacks_new,
87-
llama_grammar_stacks_cache & stacks_cache);
82+
void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr);
8883

8984
std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
9085
const llama_grammar_rules & rules,
@@ -128,10 +123,11 @@ struct llama_grammar {
128123
const llama_grammar_rules rules; // TODO: shared ptr
129124
llama_grammar_stacks stacks;
130125

131-
// buffer for partially generated UTF-8 sequence from accepted tokens
132-
llama_partial_utf8 partial_utf8;
133126
// cache N possible stacks from a stack
134127
llama_grammar_stacks_cache stacks_cache;
128+
129+
// buffer for partially generated UTF-8 sequence from accepted tokens
130+
llama_partial_utf8 partial_utf8;
135131
};
136132

137133
//

tests/test-grammar-integration.cpp

+3-7
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,10 @@ static bool test_build_grammar_fails(const std::string & grammar_str) {
3232
static bool match_string(const std::string & input, llama_grammar * grammar) {
3333
const auto cpts = unicode_cpts_from_utf8(input);
3434

35-
const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
36-
llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
37-
llama_grammar_stacks_cache & stacks_cache = llama_grammar_get_stacks_cache(grammar);
35+
auto & stacks_cur = llama_grammar_get_stacks(grammar);
3836

3937
for (const auto & cpt : cpts) {
40-
const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy
41-
42-
llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur, stacks_cache);
38+
llama_grammar_accept(grammar, cpt);
4339

4440
if (stacks_cur.empty()) {
4541
// no stacks means that the grammar failed to match at this point
@@ -64,7 +60,7 @@ static void test(const std::string & test_desc, const std::string & grammar_str,
6460
auto * grammar = build_grammar(grammar_str);
6561

6662
// Save the original grammar stacks so that we can reset after every new string we want to test
67-
const llama_grammar_stacks stacks_org = llama_grammar_get_stacks(grammar);
63+
const llama_grammar_stacks stacks_org = llama_grammar_get_stacks(grammar); // copy
6864

6965
llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
7066

tests/test-llama-grammar.cpp

+2-4
Original file line numberDiff line numberDiff line change
@@ -113,12 +113,10 @@ int main()
113113
}
114114
}
115115

116-
llama_grammar * grammar = NULL;
117116
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
118117

119-
grammar = llama_grammar_init_impl(nullptr, grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
120-
if (grammar == nullptr)
121-
{
118+
llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
119+
if (grammar == nullptr) {
122120
throw std::runtime_error("Failed to initialize llama_grammar");
123121
}
124122

0 commit comments

Comments
 (0)