Add grammar-based sampling #572
After further review, I noticed that grammar sampling doesn't work after the first completion.

    if (n_past > 0) {
        if (is_interacting) {
            // reset grammar state if we're restarting generation
            if (grammar != NULL) {
                llama_grammar_free(grammar);
                std::vector<const llama_grammar_element *> grammar_rules(
                    parsed_grammar.c_rules());
                grammar = llama_grammar_init(
                    grammar_rules.data(), grammar_rules.size(),
                    parsed_grammar.symbol_ids.at("root"));
            }
        }
        is_interacting = false;
    }
}

class LlamaGrammar:
    ...
    def reset(self) -> None:
        llama_cpp.llama_grammar_free(self.grammar)
        self.grammar = llama_cpp.llama_grammar_init(
            self.rules, self.n_rules, self.start_rule_index
        )

class Llama:
    def generate(...):
        ...
        if reset:
            self.reset()
        if self.grammar is not None:
            self.grammar.reset()
        while True:
            self.eval(tokens)
            ...
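To illustrate why the reset matters, here is a minimal sketch with a toy stateful grammar. `ToyGrammar` and this standalone `generate` are hypothetical stand-ins and not the llama_cpp API; the only assumption carried over from the snippets above is that the grammar keeps parse state that is exhausted after one completion.

```python
class ToyGrammar:
    """Hypothetical stand-in for a stateful grammar; not the llama_cpp API."""

    def __init__(self, expected):
        self.expected = list(expected)  # tokens the grammar accepts, in order
        self.pos = 0                    # parse position, advanced on each accept

    def allowed(self):
        # Once every expected token is consumed, nothing further is allowed.
        return self.expected[self.pos:self.pos + 1]

    def accept(self, token):
        if token not in self.allowed():
            raise ValueError(f"token {token!r} rejected at position {self.pos}")
        self.pos += 1

    def reset(self):
        # Mirrors the free + re-init pattern: discard state, start from the root.
        self.pos = 0


def generate(grammar, reset=True):
    if reset:
        grammar.reset()
    out = []
    while grammar.allowed():
        token = grammar.allowed()[0]  # a real sampler would choose among candidates
        grammar.accept(token)
        out.append(token)
    return out


grammar = ToyGrammar(["{", '"ok"', "}"])
first = generate(grammar)
stale = generate(grammar, reset=False)  # state is exhausted, so nothing is produced
second = generate(grammar)              # resetting restores the full output
```

Without the reset, the second call starts from the exhausted state, which is one way the "works only once" symptom could arise.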
@c0sogi amazing work. I was planning on pulling in grammar-based sampling but ran into an issue with the parser. I'll merge this in, but I'll likely move the grammar directly to the generate and completion calls so that a different grammar can be used with the same Llama model.
@abetlen My mistake. These four enum-related parts

    if rule.empty() or rule.back().type != llama_gretype.LLAMA_GRETYPE_END.value:
        raise RuntimeError(
            "malformed rule, does not end with LLAMA_GRETYPE_END: " + str(rule_id)
        )
    ...
    if case is llama_gretype.LLAMA_GRETYPE_END.value:
        raise RuntimeError("unexpected end of rule: " + str(rule_id) + "," + str(i))
    ...
    def is_char_element(elem: LlamaGrammarElement) -> bool:
        return elem.type in (
            llama_gretype.LLAMA_GRETYPE_CHAR.value,
            llama_gretype.LLAMA_GRETYPE_CHAR_NOT.value,
            llama_gretype.LLAMA_GRETYPE_CHAR_ALT.value,
            llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER.value,
        )
    ...
    if rule[i + 1].type in (
        llama_gretype.LLAMA_GRETYPE_CHAR_ALT.value,
        llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER.value,
    ):

should be these:

    if rule.empty() or rule.back().type is not llama_gretype.LLAMA_GRETYPE_END:
        raise RuntimeError(
            "malformed rule, does not end with LLAMA_GRETYPE_END: " + str(rule_id)
        )
    ...
    if case is llama_gretype.LLAMA_GRETYPE_END:
        raise RuntimeError("unexpected end of rule: " + str(rule_id) + "," + str(i))
    ...
    def is_char_element(elem: LlamaGrammarElement) -> bool:
        return elem.type in (
            llama_gretype.LLAMA_GRETYPE_CHAR,
            llama_gretype.LLAMA_GRETYPE_CHAR_NOT,
            llama_gretype.LLAMA_GRETYPE_CHAR_ALT,
            llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER,
        )
    ...
    if rule[i + 1].type in (
        llama_gretype.LLAMA_GRETYPE_CHAR_ALT,
        llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER,
    ):

But I don't know why the previous one works. Haha 😅
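One possible explanation for why the `.value` version worked, sketched with the standard `enum` module. The classes `PlainGreType` and `IntGreType` are hypothetical, and the assumption that `elem.type` holds a raw integer (as a C `int` field read through ctypes typically would) is not confirmed anywhere in this thread.

```python
from enum import Enum, IntEnum


class PlainGreType(Enum):      # hypothetical plain Enum
    END = 0
    CHAR = 3


class IntGreType(IntEnum):     # hypothetical IntEnum with the same values
    END = 0
    CHAR = 3


raw_type = 3  # assumed: the element type arrives as a plain Python int

# Comparing against .value is int-to-int, so it matches.
plain_value_match = raw_type == PlainGreType.CHAR.value

# A plain Enum member never compares equal to a bare int.
plain_member_match = raw_type == PlainGreType.CHAR

# An IntEnum member is an int subclass, so == and `in` both match.
int_member_match = raw_type == IntGreType.CHAR
int_member_in = raw_type in (IntGreType.CHAR, IntGreType.END)

# Identity (`is`) does not match a bare int, even with IntEnum.
identity_match = raw_type is IntGreType.CHAR
```

Under these assumptions, comparing with `.value` matches raw integers, while `is` against an enum member only matches when the stored type is itself that enum member.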
I am trying to make sure that my output follows a JSON format every time. I stumbled upon jsonformer, and from there upon grammar-based sampling. I used json-schema-to-grammar.py to convert my JSON schema. I want to know whether grammar-based sampling is meant for this specific purpose and, if so, how to use it.
JSON schema
Llama grammar
Here is my code
This is the error I am getting
I think your schema is correct, but I think the problem is caused by the typos in llama_grammar.py.
@c0sogi I tested it out and I am getting the same result.
Recently, grammar-based sampling was merged into llama.cpp.
However, there is currently no explicit parser API we can use from Python. Therefore, I translated grammar-parser.cpp into llama_grammar.py.
I've tested it using vendor/llama.cpp/grammars/json.gbnf, and the parsed grammar output was exactly the same as that of the compiled version. See the example below. I hope this will help with implementing function calling someday!
Test code:
Output:
Once this is merged, I will continue with a PR for a function call feature example. This will parse a real Python function into a grammar string. See the test result below.
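As a rough sketch of what turning a Python function into a grammar string could look like (this is not the PR author's implementation): the helper `function_to_gbnf`, the `PRIMITIVES` table, and the sample `get_weather` function are all hypothetical, and the emitted GBNF-style rules have not been checked against the llama.cpp parser.

```python
import inspect

# Hypothetical mapping from Python annotations to GBNF-style primitive rules.
PRIMITIVES = {
    int: ("integer", 'integer ::= "-"? [0-9]+'),
    float: ("number", 'number ::= "-"? [0-9]+ ("." [0-9]+)?'),
    bool: ("boolean", 'boolean ::= "true" | "false"'),
    str: ("string", r'string ::= "\"" [^"\\]* "\""'),
}


def function_to_gbnf(func):
    """Build a grammar string describing a JSON object of func's arguments."""
    fields, rules = [], {}
    for name, param in inspect.signature(func).parameters.items():
        if param.annotation not in PRIMITIVES:
            raise TypeError(f"unsupported annotation for parameter {name!r}")
        rule_name, rule = PRIMITIVES[param.annotation]
        rules[rule_name] = rule  # keep each primitive rule only once
        # Each field is the quoted key plus a colon, followed by the value rule.
        fields.append(f'"\\"{name}\\":" {rule_name}')
    body = ' "," '.join(fields)
    root = f'root ::= "{{" {body} "}}"'
    return "\n".join([root, *rules.values()])


def get_weather(city: str, days: int):
    """Sample function used only for illustration."""


grammar_text = function_to_gbnf(get_weather)
print(grammar_text)
```

The sketch handles only four primitive annotations and ignores whitespace, defaults, and nested types, all of which a real converter would need to cover.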