@@ -764,7 +764,7 @@ static void llama_grammar_advance_stack_memo(
764
764
if (it != stacks_cache.end ()) {
765
765
advanced_stacks = it->second ;
766
766
} else {
767
- // Advance stacks with memorization
767
+ // Advance stacks with memoization
768
768
llama_grammar_advance_stack_memo_impl (rules, stack, advanced_stacks, stacks_cache);
769
769
stacks_cache.insert (make_pair (stack, advanced_stacks));
770
770
}
@@ -917,20 +917,11 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar)
917
917
return grammar->stacks ;
918
918
}
919
919
920
- llama_grammar_stacks_cache & llama_grammar_get_stacks_cache (struct llama_grammar * grammar) {
921
- return grammar->stacks_cache ;
922
- }
923
-
924
- void llama_grammar_accept (
925
- const llama_grammar_rules & rules,
926
- const llama_grammar_stacks & stacks,
927
- const uint32_t chr,
928
- llama_grammar_stacks & stacks_new,
929
- llama_grammar_stacks_cache & stacks_cache) {
930
- stacks_new.clear ();
931
- stacks_new.reserve (stacks.size ());
920
+ void llama_grammar_accept (struct llama_grammar * grammar, uint32_t chr) {
921
+ llama_grammar_stacks stacks_new;
922
+ stacks_new.reserve (grammar->stacks .size ());
932
923
933
- for (const auto & stack : stacks) {
924
+ for (const auto & stack : grammar-> stacks ) {
934
925
if (stack.empty ()) {
935
926
continue ;
936
927
}
@@ -944,9 +935,11 @@ void llama_grammar_accept(
944
935
if (!llama_grammar_is_end_of_sequence (pos)) {
945
936
new_stack.push_back (pos);
946
937
}
947
- llama_grammar_advance_stack_memo (rules, new_stack, stacks_new, stacks_cache);
938
+ llama_grammar_advance_stack_memo (grammar-> rules , new_stack, stacks_new, grammar-> stacks_cache );
948
939
}
949
940
}
941
+
942
+ grammar->stacks = std::move (stacks_new);
950
943
}
951
944
952
945
llama_grammar_candidates llama_grammar_reject_candidates_for_stack (
@@ -1062,7 +1055,7 @@ struct llama_grammar * llama_grammar_init_impl(
1062
1055
// Important: vec_rules has to be moved here, not copied, because stacks contains
1063
1056
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
1064
1057
// then the pointers would be invalidated when the local vec_rules goes out of scope.
1065
- return new llama_grammar { vocab, std::move (vec_rules), std::move (stacks), {}, std::move (stacks_cache), };
1058
+ return new llama_grammar { vocab, std::move (vec_rules), std::move (stacks), std::move (stacks_cache), {} , };
1066
1059
}
1067
1060
1068
1061
struct llama_grammar * llama_grammar_init_impl (const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
@@ -1141,7 +1134,7 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab,
1141
1134
// Important: vec_rules has to be moved here, not copied, because stacks contains
1142
1135
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
1143
1136
// then the pointers would be invalidated when the local vec_rules goes out of scope.
1144
- return new llama_grammar { vocab, std::move (vec_rules), std::move (stacks), {}, std::move (stacks_cache), };
1137
+ return new llama_grammar { vocab, std::move (vec_rules), std::move (stacks), std::move (stacks_cache), {} , };
1145
1138
}
1146
1139
1147
1140
void llama_grammar_free_impl (struct llama_grammar * grammar) {
@@ -1153,15 +1146,21 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) {
1153
1146
}
1154
1147
1155
1148
struct llama_grammar * llama_grammar_clone_impl (const struct llama_grammar & grammar) {
1156
- llama_grammar * result = new llama_grammar { grammar.vocab , grammar.rules , grammar.stacks , grammar.partial_utf8 , };
1149
+ llama_grammar * result = new llama_grammar {
1150
+ grammar.vocab ,
1151
+ grammar.rules ,
1152
+ grammar.stacks ,
1153
+ grammar.stacks_cache ,
1154
+ grammar.partial_utf8 ,
1155
+ };
1157
1156
1158
1157
// redirect elements in stacks to point to new rules
1159
1158
for (size_t is = 0 ; is < result->stacks .size (); is++) {
1160
1159
for (size_t ie = 0 ; ie < result->stacks [is].size (); ie++) {
1161
1160
for (size_t ir0 = 0 ; ir0 < grammar.rules .size (); ir0++) {
1162
1161
for (size_t ir1 = 0 ; ir1 < grammar.rules [ir0].size (); ir1++) {
1163
1162
if (grammar.stacks [is][ie] == &grammar.rules [ir0][ir1]) {
1164
- result->stacks [is][ie] = &result->rules [ir0][ir1];
1163
+ result->stacks [is][ie] = &result->rules [ir0][ir1];
1165
1164
}
1166
1165
}
1167
1166
}
@@ -1228,11 +1227,8 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
1228
1227
const auto decoded = decode_utf8 (piece, grammar.partial_utf8 );
1229
1228
const auto & code_points = decoded.first ;
1230
1229
1231
- llama_grammar_stacks stacks_new;
1232
-
1233
1230
for (auto it = code_points.begin (), end = code_points.end () - 1 ; it != end; ++it) {
1234
- llama_grammar_accept (grammar.rules , grammar.stacks , *it, stacks_new, grammar.stacks_cache );
1235
- grammar.stacks = std::move (stacks_new);
1231
+ llama_grammar_accept (&grammar, *it);
1236
1232
}
1237
1233
1238
1234
grammar.partial_utf8 = decoded.second ;
0 commit comments