Skip to content

Commit 6a83ecb

Browse files
committed
Revert " llama : adds llama-grammar memoization stacks (ggml-org#4218) ggml-org#9833"
This reverts commit 4cbf5c392af62252a69e17143e8a81d771ca6f8a.
1 parent e78bed5 commit 6a83ecb

File tree

3 files changed

+31
-139
lines changed

3 files changed

+31
-139
lines changed

examples/gbnf-validator/gbnf-validator.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,19 @@
1111
static bool llama_grammar_validate(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) {
1212
const auto cpts = unicode_cpts_from_utf8(input_str);
1313

14-
auto & stacks_cur = llama_grammar_get_stacks(grammar);
14+
const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
15+
llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
1516

1617
size_t pos = 0;
1718
for (const auto & cpt : cpts) {
18-
llama_grammar_accept(grammar, cpt);
19+
const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy
20+
21+
llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur);
1922

2023
if (stacks_cur.empty()) {
2124
error_pos = pos;
2225
error_msg = "Unexpected character '" + unicode_cpt_to_utf8(cpt) + "'";
26+
stacks_cur = stacks_prev;
2327
return false;
2428
}
2529
++pos;
@@ -78,8 +82,7 @@ int main(int argc, char** argv) {
7882

7983
llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root");
8084
if (grammar == nullptr) {
81-
fprintf(stdout, "Failed to initialize llama_grammar\n");
82-
return 1;
85+
throw std::runtime_error("Failed to initialize llama_grammar");
8386
}
8487
// Read the input file
8588
std::string input_str;

src/llama-grammar.cpp

Lines changed: 19 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -682,101 +682,6 @@ static bool llama_grammar_match_partial_char(
682682
return !is_positive_char;
683683
}
684684

685-
// transforms a grammar pushdown stack into N possible stacks, all ending
686-
// at a character range (terminal element)
687-
// additionally memoizes the stack to its possible stacks by mapping
688-
// < llama_grammar_stack, llama_grammar_stacks >
689-
690-
static void llama_grammar_advance_stack_memo(
691-
const llama_grammar_rules & rules,
692-
const llama_grammar_stack & stack,
693-
llama_grammar_stacks & new_stacks,
694-
llama_grammar_stacks_cache & stacks_cache);
695-
696-
static void llama_grammar_advance_stack_memo_impl(
697-
const llama_grammar_rules & rules,
698-
const llama_grammar_stack & stack,
699-
llama_grammar_stacks & new_stacks,
700-
llama_grammar_stacks_cache & stacks_cache) {
701-
if (stack.empty()) {
702-
if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
703-
new_stacks.emplace_back(stack);
704-
}
705-
return;
706-
}
707-
708-
const llama_grammar_element * pos = stack.back();
709-
710-
switch (pos->type) {
711-
case LLAMA_GRETYPE_RULE_REF: {
712-
const size_t rule_id = static_cast<size_t>(pos->value);
713-
const llama_grammar_element * subpos = rules[rule_id].data();
714-
do {
715-
// init new stack without the top (pos)
716-
llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
717-
if (!llama_grammar_is_end_of_sequence(pos + 1)) {
718-
// if this rule ref is followed by another element, add that to stack
719-
new_stack.push_back(pos + 1);
720-
}
721-
if (!llama_grammar_is_end_of_sequence(subpos)) {
722-
// if alternate is nonempty, add to stack
723-
new_stack.push_back(subpos);
724-
}
725-
llama_grammar_advance_stack_memo(rules, new_stack, new_stacks, stacks_cache);
726-
while (!llama_grammar_is_end_of_sequence(subpos)) {
727-
// scan to end of alternate def
728-
subpos++;
729-
}
730-
if (subpos->type == LLAMA_GRETYPE_ALT) {
731-
// there's another alternate def of this rule to process
732-
subpos++;
733-
} else {
734-
break;
735-
}
736-
} while (true);
737-
break;
738-
}
739-
case LLAMA_GRETYPE_CHAR:
740-
case LLAMA_GRETYPE_CHAR_NOT:
741-
case LLAMA_GRETYPE_CHAR_ANY:
742-
if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
743-
// only add the stack if it's not a duplicate of one we already have
744-
new_stacks.emplace_back(stack);
745-
}
746-
break;
747-
default:
748-
// end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range
749-
// (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
750-
// those
751-
GGML_ABORT("fatal error");
752-
}
753-
}
754-
755-
static void llama_grammar_advance_stack_memo(
756-
const llama_grammar_rules & rules,
757-
const llama_grammar_stack & stack,
758-
llama_grammar_stacks & new_stacks,
759-
llama_grammar_stacks_cache & stacks_cache) {
760-
761-
llama_grammar_stacks advanced_stacks;
762-
// Look if stack is already in memory
763-
auto it = stacks_cache.find(stack);
764-
if (it != stacks_cache.end()) {
765-
advanced_stacks = it->second;
766-
} else {
767-
// Advance stacks with memoization
768-
llama_grammar_advance_stack_memo_impl(rules, stack, advanced_stacks, stacks_cache);
769-
stacks_cache.insert(make_pair(stack, advanced_stacks));
770-
}
771-
// Add the advanced stacks to new_stacks avoiding duplicates
772-
for (const auto & new_stack : advanced_stacks) {
773-
if (std::find(new_stacks.begin(), new_stacks.end(), new_stack) == new_stacks.end()) {
774-
new_stacks.emplace_back(new_stack);
775-
}
776-
}
777-
778-
}
779-
780685
// transforms a grammar pushdown stack into N possible stacks, all ending
781686
// at a character range (terminal element)
782687
static void llama_grammar_advance_stack(
@@ -917,11 +822,15 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar)
917822
return grammar->stacks;
918823
}
919824

920-
void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr) {
921-
llama_grammar_stacks stacks_new;
922-
stacks_new.reserve(grammar->stacks.size());
825+
void llama_grammar_accept(
826+
const llama_grammar_rules & rules,
827+
const llama_grammar_stacks & stacks,
828+
const uint32_t chr,
829+
llama_grammar_stacks & stacks_new) {
830+
stacks_new.clear();
831+
stacks_new.reserve(stacks.size());
923832

924-
for (const auto & stack : grammar->stacks) {
833+
for (const auto & stack : stacks) {
925834
if (stack.empty()) {
926835
continue;
927836
}
@@ -935,11 +844,9 @@ void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr) {
935844
if (!llama_grammar_is_end_of_sequence(pos)) {
936845
new_stack.push_back(pos);
937846
}
938-
llama_grammar_advance_stack_memo(grammar->rules, new_stack, stacks_new, grammar->stacks_cache);
847+
llama_grammar_advance_stack(rules, new_stack, stacks_new);
939848
}
940849
}
941-
942-
grammar->stacks = std::move(stacks_new);
943850
}
944851

945852
llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
@@ -1031,15 +938,14 @@ struct llama_grammar * llama_grammar_init_impl(
1031938

1032939
// loop over alternates of start rule to build initial stacks
1033940
llama_grammar_stacks stacks;
1034-
llama_grammar_stacks_cache stacks_cache;
1035941
pos = vec_rules[start_rule_index].data();
1036942
do {
1037943
llama_grammar_stack stack;
1038944
if (!llama_grammar_is_end_of_sequence(pos)) {
1039945
// if alternate is nonempty, add to stack
1040946
stack.push_back(pos);
1041947
}
1042-
llama_grammar_advance_stack_memo(vec_rules, stack, stacks, stacks_cache);
948+
llama_grammar_advance_stack(vec_rules, stack, stacks);
1043949
while (!llama_grammar_is_end_of_sequence(pos)) {
1044950
// scan to end of alternate def
1045951
pos++;
@@ -1055,7 +961,7 @@ struct llama_grammar * llama_grammar_init_impl(
1055961
// Important: vec_rules has to be moved here, not copied, because stacks contains
1056962
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
1057963
// then the pointers would be invalidated when the local vec_rules goes out of scope.
1058-
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), std::move(stacks_cache), {}, };
964+
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
1059965
}
1060966

1061967
struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
@@ -1110,15 +1016,14 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab,
11101016

11111017
// loop over alternates of start rule to build initial stacks
11121018
llama_grammar_stacks stacks;
1113-
llama_grammar_stacks_cache stacks_cache;
11141019
pos = vec_rules[start_rule_index].data();
11151020
do {
11161021
llama_grammar_stack stack;
11171022
if (!llama_grammar_is_end_of_sequence(pos)) {
11181023
// if alternate is nonempty, add to stack
11191024
stack.push_back(pos);
11201025
}
1121-
llama_grammar_advance_stack_memo(vec_rules, stack, stacks, stacks_cache);
1026+
llama_grammar_advance_stack(vec_rules, stack, stacks);
11221027
while (!llama_grammar_is_end_of_sequence(pos)) {
11231028
// scan to end of alternate def
11241029
pos++;
@@ -1134,7 +1039,7 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab,
11341039
// Important: vec_rules has to be moved here, not copied, because stacks contains
11351040
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
11361041
// then the pointers would be invalidated when the local vec_rules goes out of scope.
1137-
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), std::move(stacks_cache), {}, };
1042+
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
11381043
}
11391044

11401045
void llama_grammar_free_impl(struct llama_grammar * grammar) {
@@ -1146,21 +1051,15 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) {
11461051
}
11471052

11481053
struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
1149-
llama_grammar * result = new llama_grammar {
1150-
grammar.vocab,
1151-
grammar.rules,
1152-
grammar.stacks,
1153-
grammar.stacks_cache,
1154-
grammar.partial_utf8,
1155-
};
1054+
llama_grammar * result = new llama_grammar { grammar.vocab, grammar.rules, grammar.stacks, grammar.partial_utf8, };
11561055

11571056
// redirect elements in stacks to point to new rules
11581057
for (size_t is = 0; is < result->stacks.size(); is++) {
11591058
for (size_t ie = 0; ie < result->stacks[is].size(); ie++) {
11601059
for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) {
11611060
for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) {
11621061
if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) {
1163-
result->stacks[is][ie] = &result->rules[ir0][ir1];
1062+
result->stacks[is][ie] = &result->rules[ir0][ir1];
11641063
}
11651064
}
11661065
}
@@ -1227,8 +1126,11 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
12271126
const auto decoded = decode_utf8(piece, grammar.partial_utf8);
12281127
const auto & code_points = decoded.first;
12291128

1129+
llama_grammar_stacks stacks_new;
1130+
12301131
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
1231-
llama_grammar_accept(&grammar, *it);
1132+
llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new);
1133+
grammar.stacks = std::move(stacks_new);
12321134
}
12331135

12341136
grammar.partial_utf8 = decoded.second;

src/llama-grammar.h

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
#include "llama-impl.h"
44

55
#include <map>
6-
#include <unordered_map>
76

87
struct llama_vocab;
98

@@ -59,27 +58,18 @@ using llama_grammar_rules = std::vector<llama_grammar_rule>;
5958
using llama_grammar_stacks = std::vector<llama_grammar_stack>;
6059
using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
6160

62-
struct VectorPointerHash {
63-
size_t operator()(const llama_grammar_stack & v) const {
64-
size_t seed = v.size();
65-
for (const auto* ptr : v) {
66-
seed ^= std::hash<const llama_grammar_element*>()(ptr) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
67-
}
68-
return seed;
69-
}
70-
};
71-
72-
using llama_grammar_stacks_cache = std::unordered_map<llama_grammar_stack, llama_grammar_stacks, VectorPointerHash>;
73-
74-
// TODO: remove, needed for tests atm
7561
const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
7662
llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);
7763

7864
// takes a set of possible pushdown stacks on a grammar, which are required to
7965
// be positioned at a character range (see `llama_grammar_advance_stack`), and
8066
// produces the N possible stacks if the given char is accepted at those
8167
// positions
82-
void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr);
68+
void llama_grammar_accept(
69+
const llama_grammar_rules & rules,
70+
const llama_grammar_stacks & stacks,
71+
uint32_t chr,
72+
llama_grammar_stacks & stacks_new);
8373

8474
std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
8575
const llama_grammar_rules & rules,
@@ -123,9 +113,6 @@ struct llama_grammar {
123113
const llama_grammar_rules rules; // TODO: shared ptr
124114
llama_grammar_stacks stacks;
125115

126-
// cache N possible stacks from a stack
127-
llama_grammar_stacks_cache stacks_cache;
128-
129116
// buffer for partially generated UTF-8 sequence from accepted tokens
130117
llama_partial_utf8 partial_utf8;
131118
};

0 commit comments

Comments
 (0)