@@ -13174,28 +13174,68 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates(
13174
13174
return rejects;
13175
13175
}
13176
13176
13177
- //
13178
- // grammar - external
13179
- //
13177
+ static bool llama_grammar_detect_left_recursion(
13178
+ const std::vector<std::vector<llama_grammar_element>> & rules,
13179
+ size_t rule_index,
13180
+ std::vector<bool> * rules_visited,
13181
+ std::vector<bool> * rules_in_progress,
13182
+ std::vector<bool> * rules_may_be_empty);
13180
13183
13181
- enum detect_left_recursion_status {
13182
- // haven't searched this nonterminal
13183
- LLAMA_LEFT_REC_NOT_SEARCHED = 0,
13184
+ static bool llama_grammar_detect_left_recursion(
13185
+ const std::vector<std::vector<llama_grammar_element>> & rules,
13186
+ size_t rule_index,
13187
+ std::vector<bool> * rules_visited,
13188
+ std::vector<bool> * rules_in_progress,
13189
+ std::vector<bool> * rules_may_be_empty) {
13190
+ if ((*rules_in_progress)[rule_index]) {
13191
+ return true;
13192
+ }
13184
13193
13185
- // searching this nonterminal in progress
13186
- LLAMA_LEFT_REC_IN_PROGRESS = 1,
13194
+ (*rules_in_progress)[rule_index] = true;
13187
13195
13188
- // finished searching this nonterminal
13189
- LLAMA_LEFT_REC_FINISHED_SEARCH = 2,
13196
+ const std::vector<llama_grammar_element> & rule = rules[rule_index];
13190
13197
13191
- // detected a cycle
13192
- LLAMA_LEFT_REC_FOUND_CYCLE = 3,
13193
- };
13198
+ // First check if the rule might produce the empty string. This could be done combined with the second
13199
+ // step but it's more readable as two steps.
13200
+ bool at_rule_start = true;
13201
+ for (size_t i = 0; i < rule.size(); i++) {
13202
+ if (llama_grammar_is_end_of_sequence(&rule[i])) {
13203
+ if (at_rule_start) {
13204
+ (*rules_may_be_empty)[rule_index] = true;
13205
+ break;
13206
+ }
13207
+ at_rule_start = true;
13208
+ } else {
13209
+ at_rule_start = false;
13210
+ }
13211
+ }
13194
13212
13195
- static void detect_left_recursion(
13196
- const std::vector<std::vector<llama_grammar_element>> & rules,
13197
- size_t rule_index,
13198
- std::vector<enum detect_left_recursion_status> * rules_visited);
13213
+ // Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may
13214
+ // be empty)
13215
+ bool recurse_into_nonterminal = true;
13216
+ for (size_t i = 0; i < rule.size(); i++) {
13217
+ if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) {
13218
+ if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) {
13219
+ return true;
13220
+ }
13221
+ if (!((*rules_may_be_empty)[(size_t)rule[i].value])) {
13222
+ recurse_into_nonterminal = false;
13223
+ }
13224
+ } else if (llama_grammar_is_end_of_sequence(&rule[i])) {
13225
+ recurse_into_nonterminal = true;
13226
+ } else {
13227
+ recurse_into_nonterminal = false;
13228
+ }
13229
+ }
13230
+
13231
+ (*rules_in_progress)[rule_index] = false;
13232
+ (*rules_visited)[rule_index] = true;
13233
+ return false;
13234
+ }
13235
+
13236
+ //
13237
+ // grammar - external
13238
+ //
13199
13239
13200
13240
struct llama_grammar * llama_grammar_init(
13201
13241
const llama_grammar_element ** rules,
@@ -13213,14 +13253,16 @@ struct llama_grammar * llama_grammar_init(
13213
13253
}
13214
13254
13215
13255
// Check for left recursion
13216
- std::vector<enum detect_left_recursion_status> rules_visited(n_rules);
13256
+ std::vector<bool> rules_visited(n_rules);
13257
+ std::vector<bool> rules_in_progress(n_rules);
13258
+ std::vector<bool> rules_may_be_empty(n_rules);
13217
13259
for (size_t i = 0; i < n_rules; i++) {
13218
- detect_left_recursion(vec_rules, i, & rules_visited);
13219
- }
13220
-
13221
- auto iter = std::find(rules_visited.begin(), rules_visited.end(), LLAMA_LEFT_REC_FOUND_CYCLE);
13222
- if (iter != rules_visited.end()) {
13223
- throw std::runtime_error(format("unsupported grammar, left recursion detected for nonterminal at index %d", (int)(iter - rules_visited.begin())));
13260
+ if ( rules_visited[i]) {
13261
+ continue;
13262
+ }
13263
+ if (llama_grammar_detect_left_recursion(vec_rules, i, & rules_visited, &rules_in_progress, &rules_may_be_empty)) {
13264
+ throw std::runtime_error(format("unsupported grammar, left recursion detected for nonterminal at index %zu", i));
13265
+ }
13224
13266
}
13225
13267
13226
13268
// loop over alternates of start rule to build initial stacks
@@ -13251,45 +13293,6 @@ struct llama_grammar * llama_grammar_init(
13251
13293
return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} };
13252
13294
}
13253
13295
13254
- static void detect_left_recursion(
13255
- const std::vector<std::vector<llama_grammar_element>> & rules,
13256
- size_t rule_index,
13257
- std::vector<enum detect_left_recursion_status> * rules_visited) {
13258
-
13259
- int visit_status = (*rules_visited)[rule_index];
13260
- if (visit_status == LLAMA_LEFT_REC_IN_PROGRESS) {
13261
- // in progress -- we're in a cycle
13262
- (*rules_visited)[rule_index] = LLAMA_LEFT_REC_FOUND_CYCLE;
13263
- return;
13264
- } else if (visit_status == LLAMA_LEFT_REC_NOT_SEARCHED) {
13265
- // haven't visited yet. mark in progress, recurse, then mark complete.
13266
- // mark in progress
13267
- (*rules_visited)[rule_index] = LLAMA_LEFT_REC_IN_PROGRESS;
13268
-
13269
- // recurse
13270
- const std::vector<llama_grammar_element> & rule = rules[rule_index];
13271
- size_t i = 0;
13272
- do {
13273
- if (rule[i].type == LLAMA_GRETYPE_RULE_REF) {
13274
- detect_left_recursion(rules, (size_t)rule[i].value, rules_visited);
13275
- }
13276
- while (!llama_grammar_is_end_of_sequence(&rule[i])) {
13277
- i++;
13278
- }
13279
- i++;
13280
- } while (i < rule.size());
13281
-
13282
- // mark complete, but only if the recursive call didn't mark a cycle.
13283
- // that doesn't mean there's definitely no cycle *for this rule* -- the recursive call
13284
- // might have found a different cycle and stopped early.
13285
- if ((*rules_visited)[rule_index] == LLAMA_LEFT_REC_IN_PROGRESS) {
13286
- (*rules_visited)[rule_index] = LLAMA_LEFT_REC_FINISHED_SEARCH;
13287
- }
13288
- }
13289
-
13290
- return;
13291
- }
13292
-
13293
13296
// Destroys a grammar object and releases its memory; a null pointer is accepted and ignored.
void llama_grammar_free(struct llama_grammar * grammar) {
    delete grammar;
}
0 commit comments