@@ -682,101 +682,6 @@ static bool llama_grammar_match_partial_char(
682
682
return !is_positive_char;
683
683
}
684
684
685
- // transforms a grammar pushdown stack into N possible stacks, all ending
686
- // at a character range (terminal element)
687
- // additionally memoizes the stack to its possible stacks by mapping
688
- // < llama_grammar_stack, llama_grammar_stacks >
689
-
690
- static void llama_grammar_advance_stack_memo (
691
- const llama_grammar_rules & rules,
692
- const llama_grammar_stack & stack,
693
- llama_grammar_stacks & new_stacks,
694
- llama_grammar_stacks_cache & stacks_cache);
695
-
696
- static void llama_grammar_advance_stack_memo_impl (
697
- const llama_grammar_rules & rules,
698
- const llama_grammar_stack & stack,
699
- llama_grammar_stacks & new_stacks,
700
- llama_grammar_stacks_cache & stacks_cache) {
701
- if (stack.empty ()) {
702
- if (std::find (new_stacks.begin (), new_stacks.end (), stack) == new_stacks.end ()) {
703
- new_stacks.emplace_back (stack);
704
- }
705
- return ;
706
- }
707
-
708
- const llama_grammar_element * pos = stack.back ();
709
-
710
- switch (pos->type ) {
711
- case LLAMA_GRETYPE_RULE_REF: {
712
- const size_t rule_id = static_cast <size_t >(pos->value );
713
- const llama_grammar_element * subpos = rules[rule_id].data ();
714
- do {
715
- // init new stack without the top (pos)
716
- llama_grammar_stack new_stack (stack.begin (), stack.end () - 1 );
717
- if (!llama_grammar_is_end_of_sequence (pos + 1 )) {
718
- // if this rule ref is followed by another element, add that to stack
719
- new_stack.push_back (pos + 1 );
720
- }
721
- if (!llama_grammar_is_end_of_sequence (subpos)) {
722
- // if alternate is nonempty, add to stack
723
- new_stack.push_back (subpos);
724
- }
725
- llama_grammar_advance_stack_memo (rules, new_stack, new_stacks, stacks_cache);
726
- while (!llama_grammar_is_end_of_sequence (subpos)) {
727
- // scan to end of alternate def
728
- subpos++;
729
- }
730
- if (subpos->type == LLAMA_GRETYPE_ALT) {
731
- // there's another alternate def of this rule to process
732
- subpos++;
733
- } else {
734
- break ;
735
- }
736
- } while (true );
737
- break ;
738
- }
739
- case LLAMA_GRETYPE_CHAR:
740
- case LLAMA_GRETYPE_CHAR_NOT:
741
- case LLAMA_GRETYPE_CHAR_ANY:
742
- if (std::find (new_stacks.begin (), new_stacks.end (), stack) == new_stacks.end ()) {
743
- // only add the stack if it's not a duplicate of one we already have
744
- new_stacks.emplace_back (stack);
745
- }
746
- break ;
747
- default :
748
- // end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range
749
- // (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
750
- // those
751
- GGML_ABORT (" fatal error" );
752
- }
753
- }
754
-
755
- static void llama_grammar_advance_stack_memo (
756
- const llama_grammar_rules & rules,
757
- const llama_grammar_stack & stack,
758
- llama_grammar_stacks & new_stacks,
759
- llama_grammar_stacks_cache & stacks_cache) {
760
-
761
- llama_grammar_stacks advanced_stacks;
762
- // Look if stack is already in memory
763
- auto it = stacks_cache.find (stack);
764
- if (it != stacks_cache.end ()) {
765
- advanced_stacks = it->second ;
766
- } else {
767
- // Advance stacks with memoization
768
- llama_grammar_advance_stack_memo_impl (rules, stack, advanced_stacks, stacks_cache);
769
- stacks_cache.insert (make_pair (stack, advanced_stacks));
770
- }
771
- // Add the advanced stacks to new_stacks avoiding duplicates
772
- for (const auto & new_stack : advanced_stacks) {
773
- if (std::find (new_stacks.begin (), new_stacks.end (), new_stack) == new_stacks.end ()) {
774
- new_stacks.emplace_back (new_stack);
775
- }
776
- }
777
-
778
- }
779
-
780
685
// transforms a grammar pushdown stack into N possible stacks, all ending
781
686
// at a character range (terminal element)
782
687
static void llama_grammar_advance_stack (
@@ -917,11 +822,15 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar)
917
822
return grammar->stacks ;
918
823
}
919
824
920
- void llama_grammar_accept (struct llama_grammar * grammar, uint32_t chr) {
921
- llama_grammar_stacks stacks_new;
922
- stacks_new.reserve (grammar->stacks .size ());
825
+ void llama_grammar_accept (
826
+ const llama_grammar_rules & rules,
827
+ const llama_grammar_stacks & stacks,
828
+ const uint32_t chr,
829
+ llama_grammar_stacks & stacks_new) {
830
+ stacks_new.clear ();
831
+ stacks_new.reserve (stacks.size ());
923
832
924
- for (const auto & stack : grammar-> stacks ) {
833
+ for (const auto & stack : stacks) {
925
834
if (stack.empty ()) {
926
835
continue ;
927
836
}
@@ -935,11 +844,9 @@ void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr) {
935
844
if (!llama_grammar_is_end_of_sequence (pos)) {
936
845
new_stack.push_back (pos);
937
846
}
938
- llama_grammar_advance_stack_memo (grammar-> rules , new_stack, stacks_new, grammar-> stacks_cache );
847
+ llama_grammar_advance_stack ( rules, new_stack, stacks_new);
939
848
}
940
849
}
941
-
942
- grammar->stacks = std::move (stacks_new);
943
850
}
944
851
945
852
llama_grammar_candidates llama_grammar_reject_candidates_for_stack (
@@ -1031,15 +938,14 @@ struct llama_grammar * llama_grammar_init_impl(
1031
938
1032
939
// loop over alternates of start rule to build initial stacks
1033
940
llama_grammar_stacks stacks;
1034
- llama_grammar_stacks_cache stacks_cache;
1035
941
pos = vec_rules[start_rule_index].data ();
1036
942
do {
1037
943
llama_grammar_stack stack;
1038
944
if (!llama_grammar_is_end_of_sequence (pos)) {
1039
945
// if alternate is nonempty, add to stack
1040
946
stack.push_back (pos);
1041
947
}
1042
- llama_grammar_advance_stack_memo (vec_rules, stack, stacks, stacks_cache );
948
+ llama_grammar_advance_stack (vec_rules, stack, stacks);
1043
949
while (!llama_grammar_is_end_of_sequence (pos)) {
1044
950
// scan to end of alternate def
1045
951
pos++;
@@ -1055,7 +961,7 @@ struct llama_grammar * llama_grammar_init_impl(
1055
961
// Important: vec_rules has to be moved here, not copied, because stacks contains
1056
962
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
1057
963
// then the pointers would be invalidated when the local vec_rules goes out of scope.
1058
- return new llama_grammar { vocab, std::move (vec_rules), std::move (stacks), std::move (stacks_cache), {}, };
964
+ return new llama_grammar { vocab, std::move (vec_rules), std::move (stacks), {}, };
1059
965
}
1060
966
1061
967
struct llama_grammar * llama_grammar_init_impl (const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
@@ -1110,15 +1016,14 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab,
1110
1016
1111
1017
// loop over alternates of start rule to build initial stacks
1112
1018
llama_grammar_stacks stacks;
1113
- llama_grammar_stacks_cache stacks_cache;
1114
1019
pos = vec_rules[start_rule_index].data ();
1115
1020
do {
1116
1021
llama_grammar_stack stack;
1117
1022
if (!llama_grammar_is_end_of_sequence (pos)) {
1118
1023
// if alternate is nonempty, add to stack
1119
1024
stack.push_back (pos);
1120
1025
}
1121
- llama_grammar_advance_stack_memo (vec_rules, stack, stacks, stacks_cache );
1026
+ llama_grammar_advance_stack (vec_rules, stack, stacks);
1122
1027
while (!llama_grammar_is_end_of_sequence (pos)) {
1123
1028
// scan to end of alternate def
1124
1029
pos++;
@@ -1134,7 +1039,7 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab,
1134
1039
// Important: vec_rules has to be moved here, not copied, because stacks contains
1135
1040
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
1136
1041
// then the pointers would be invalidated when the local vec_rules goes out of scope.
1137
- return new llama_grammar { vocab, std::move (vec_rules), std::move (stacks), std::move (stacks_cache), {}, };
1042
+ return new llama_grammar { vocab, std::move (vec_rules), std::move (stacks), {}, };
1138
1043
}
1139
1044
1140
1045
void llama_grammar_free_impl (struct llama_grammar * grammar) {
@@ -1146,21 +1051,15 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) {
1146
1051
}
1147
1052
1148
1053
struct llama_grammar * llama_grammar_clone_impl (const struct llama_grammar & grammar) {
1149
- llama_grammar * result = new llama_grammar {
1150
- grammar.vocab ,
1151
- grammar.rules ,
1152
- grammar.stacks ,
1153
- grammar.stacks_cache ,
1154
- grammar.partial_utf8 ,
1155
- };
1054
+ llama_grammar * result = new llama_grammar { grammar.vocab , grammar.rules , grammar.stacks , grammar.partial_utf8 , };
1156
1055
1157
1056
// redirect elements in stacks to point to new rules
1158
1057
for (size_t is = 0 ; is < result->stacks .size (); is++) {
1159
1058
for (size_t ie = 0 ; ie < result->stacks [is].size (); ie++) {
1160
1059
for (size_t ir0 = 0 ; ir0 < grammar.rules .size (); ir0++) {
1161
1060
for (size_t ir1 = 0 ; ir1 < grammar.rules [ir0].size (); ir1++) {
1162
1061
if (grammar.stacks [is][ie] == &grammar.rules [ir0][ir1]) {
1163
- result->stacks [is][ie] = &result->rules [ir0][ir1];
1062
+ result->stacks [is][ie] = &result->rules [ir0][ir1];
1164
1063
}
1165
1064
}
1166
1065
}
@@ -1227,8 +1126,11 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
1227
1126
const auto decoded = decode_utf8 (piece, grammar.partial_utf8 );
1228
1127
const auto & code_points = decoded.first ;
1229
1128
1129
+ llama_grammar_stacks stacks_new;
1130
+
1230
1131
for (auto it = code_points.begin (), end = code_points.end () - 1 ; it != end; ++it) {
1231
- llama_grammar_accept (&grammar, *it);
1132
+ llama_grammar_accept (grammar.rules , grammar.stacks , *it, stacks_new);
1133
+ grammar.stacks = std::move (stacks_new);
1232
1134
}
1233
1135
1234
1136
grammar.partial_utf8 = decoded.second ;
0 commit comments