Skip to content

Commit 5c5911d

Browse files
committed
Remove custom enum, rename left recursion check and move to "grammar internal" section, add handling for edge case where a leftmost nonterminal may be empty
1 parent 65176e7 commit 5c5911d

File tree

2 files changed

+77
-65
lines changed

2 files changed

+77
-65
lines changed

llama.cpp

+66-63
Original file line numberDiff line numberDiff line change
@@ -13174,28 +13174,68 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates(
1317413174
return rejects;
1317513175
}
1317613176

13177-
//
13178-
// grammar - external
13179-
//
13177+
static bool llama_grammar_detect_left_recursion(
13178+
const std::vector<std::vector<llama_grammar_element>> & rules,
13179+
size_t rule_index,
13180+
std::vector<bool> * rules_visited,
13181+
std::vector<bool> * rules_in_progress,
13182+
std::vector<bool> * rules_may_be_empty);
1318013183

13181-
enum detect_left_recursion_status {
13182-
// haven't searched this nonterminal
13183-
LLAMA_LEFT_REC_NOT_SEARCHED = 0,
13184+
static bool llama_grammar_detect_left_recursion(
13185+
const std::vector<std::vector<llama_grammar_element>> & rules,
13186+
size_t rule_index,
13187+
std::vector<bool> * rules_visited,
13188+
std::vector<bool> * rules_in_progress,
13189+
std::vector<bool> * rules_may_be_empty) {
13190+
if ((*rules_in_progress)[rule_index]) {
13191+
return true;
13192+
}
1318413193

13185-
// searching this nonterminal in progress
13186-
LLAMA_LEFT_REC_IN_PROGRESS = 1,
13194+
(*rules_in_progress)[rule_index] = true;
1318713195

13188-
// finished searching this nonterminal
13189-
LLAMA_LEFT_REC_FINISHED_SEARCH = 2,
13196+
const std::vector<llama_grammar_element> & rule = rules[rule_index];
1319013197

13191-
// detected a cycle
13192-
LLAMA_LEFT_REC_FOUND_CYCLE = 3,
13193-
};
13198+
// First check if the rule might produce the empty string. This could be done combined with the second
13199+
// step but it's more readable as two steps.
13200+
bool at_rule_start = true;
13201+
for (size_t i = 0; i < rule.size(); i++) {
13202+
if (llama_grammar_is_end_of_sequence(&rule[i])) {
13203+
if (at_rule_start) {
13204+
(*rules_may_be_empty)[rule_index] = true;
13205+
break;
13206+
}
13207+
at_rule_start = true;
13208+
} else {
13209+
at_rule_start = false;
13210+
}
13211+
}
1319413212

13195-
static void detect_left_recursion(
13196-
const std::vector<std::vector<llama_grammar_element>> & rules,
13197-
size_t rule_index,
13198-
std::vector<enum detect_left_recursion_status> * rules_visited);
13213+
// Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may
13214+
// be empty)
13215+
bool recurse_into_nonterminal = true;
13216+
for (size_t i = 0; i < rule.size(); i++) {
13217+
if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) {
13218+
if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) {
13219+
return true;
13220+
}
13221+
if (!((*rules_may_be_empty)[(size_t)rule[i].value])) {
13222+
recurse_into_nonterminal = false;
13223+
}
13224+
} else if (llama_grammar_is_end_of_sequence(&rule[i])) {
13225+
recurse_into_nonterminal = true;
13226+
} else {
13227+
recurse_into_nonterminal = false;
13228+
}
13229+
}
13230+
13231+
(*rules_in_progress)[rule_index] = false;
13232+
(*rules_visited)[rule_index] = true;
13233+
return false;
13234+
}
13235+
13236+
//
13237+
// grammar - external
13238+
//
1319913239

1320013240
struct llama_grammar * llama_grammar_init(
1320113241
const llama_grammar_element ** rules,
@@ -13213,14 +13253,16 @@ struct llama_grammar * llama_grammar_init(
1321313253
}
1321413254

1321513255
// Check for left recursion
13216-
std::vector<enum detect_left_recursion_status> rules_visited(n_rules);
13256+
std::vector<bool> rules_visited(n_rules);
13257+
std::vector<bool> rules_in_progress(n_rules);
13258+
std::vector<bool> rules_may_be_empty(n_rules);
1321713259
for (size_t i = 0; i < n_rules; i++) {
13218-
detect_left_recursion(vec_rules, i, &rules_visited);
13219-
}
13220-
13221-
auto iter = std::find(rules_visited.begin(), rules_visited.end(), LLAMA_LEFT_REC_FOUND_CYCLE);
13222-
if (iter != rules_visited.end()) {
13223-
throw std::runtime_error(format("unsupported grammar, left recursion detected for nonterminal at index %d", (int)(iter - rules_visited.begin())));
13260+
if (rules_visited[i]) {
13261+
continue;
13262+
}
13263+
if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) {
13264+
throw std::runtime_error(format("unsupported grammar, left recursion detected for nonterminal at index %zu", i));
13265+
}
1322413266
}
1322513267

1322613268
// loop over alternates of start rule to build initial stacks
@@ -13251,45 +13293,6 @@ struct llama_grammar * llama_grammar_init(
1325113293
return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} };
1325213294
}
1325313295

13254-
static void detect_left_recursion(
13255-
const std::vector<std::vector<llama_grammar_element>> & rules,
13256-
size_t rule_index,
13257-
std::vector<enum detect_left_recursion_status> * rules_visited) {
13258-
13259-
int visit_status = (*rules_visited)[rule_index];
13260-
if (visit_status == LLAMA_LEFT_REC_IN_PROGRESS) {
13261-
// in progress -- we're in a cycle
13262-
(*rules_visited)[rule_index] = LLAMA_LEFT_REC_FOUND_CYCLE;
13263-
return;
13264-
} else if (visit_status == LLAMA_LEFT_REC_NOT_SEARCHED) {
13265-
// haven't visited yet. mark in progress, recurse, then mark complete.
13266-
// mark in progress
13267-
(*rules_visited)[rule_index] = LLAMA_LEFT_REC_IN_PROGRESS;
13268-
13269-
// recurse
13270-
const std::vector<llama_grammar_element> & rule = rules[rule_index];
13271-
size_t i = 0;
13272-
do {
13273-
if (rule[i].type == LLAMA_GRETYPE_RULE_REF) {
13274-
detect_left_recursion(rules, (size_t)rule[i].value, rules_visited);
13275-
}
13276-
while (!llama_grammar_is_end_of_sequence(&rule[i])) {
13277-
i++;
13278-
}
13279-
i++;
13280-
} while (i < rule.size());
13281-
13282-
// mark complete, but only if the recursive call didn't mark a cycle.
13283-
// that doesn't mean there's definitely no cycle *for this rule* -- the recursive call
13284-
// might have found a different cycle and stopped early.
13285-
if ((*rules_visited)[rule_index] == LLAMA_LEFT_REC_IN_PROGRESS) {
13286-
(*rules_visited)[rule_index] = LLAMA_LEFT_REC_FINISHED_SEARCH;
13287-
}
13288-
}
13289-
13290-
return;
13291-
}
13292-
1329313296
void llama_grammar_free(struct llama_grammar * grammar) {
1329413297
delete grammar;
1329513298
}

tests/test-grammar-integration.cpp

+11-2
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,14 @@ static llama_grammar* build_grammar(const std::string & grammar_str) {
2929
}
3030

3131
static bool test_build_grammar_fails(const std::string & grammar_str) {
32+
fprintf(stderr, "⚫ Testing failure for grammar: %s\n", grammar_str.c_str());
3233
bool grammar_fails = false;
3334
try {
3435
build_grammar(grammar_str);
35-
fprintf(stderr, "❌ Expected build failure, but succeeded: %s\n", grammar_str.c_str());
36+
fprintf(stderr, " ❌ Expected build failure, but succeeded\n");
3637
} catch (const std::exception & err) {
3738
grammar_fails = true;
38-
fprintf(stdout, "✅︎\n");
39+
fprintf(stdout, " ✅︎\n");
3940
}
4041
return grammar_fails;
4142
}
@@ -353,6 +354,14 @@ asdf ::= "a" | foo "b"
353354
foo ::= "c" | asdf "d" | "e")""";
354355
assert(test_build_grammar_fails(hard_str));
355356

357+
// Test yet even more complicated left recursion detection
358+
const std::string hardest_str = R"""(
359+
root ::= asdf
360+
asdf ::= "a" | foo "b"
361+
foo ::= "c" | empty asdf "d" | "e"
362+
empty ::= "blah" | )""";
363+
assert(test_build_grammar_fails(hardest_str));
364+
356365
fprintf(stderr, " ✅︎ Passed\n");
357366
}
358367

0 commit comments

Comments
 (0)