Skip to content

Commit b07713c

Browse files
author
c0sogi
committed
reset grammar for every generation
1 parent 418aa83 commit b07713c

File tree

2 files changed

+39
-95
lines changed

2 files changed

+39
-95
lines changed

llama_cpp/llama.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ def __init__(
364364
)
365365
if grammar is not None:
366366
self.grammar = LlamaGrammar.from_file(
367-
grammar
367+
grammar, verbose=verbose
368368
) # type: Optional[LlamaGrammar]
369369
else:
370370
self.grammar = None
@@ -723,7 +723,6 @@ def generate(
723723
The generated tokens.
724724
"""
725725
assert self.ctx is not None
726-
727726
if reset and len(self._input_ids) > 0:
728727
longest_prefix = 0
729728
for a, b in zip(self._input_ids, tokens[:-1]):
@@ -741,6 +740,9 @@ def generate(
741740
if reset:
742741
self.reset()
743742

743+
if self.grammar is not None:
744+
self.grammar.reset()
745+
744746
while True:
745747
self.eval(tokens)
746748
token = self.sample(
@@ -1534,9 +1536,6 @@ def __del__(self):
15341536
if self.ctx is not None:
15351537
llama_cpp.llama_free(self.ctx)
15361538
self.ctx = None
1537-
if self.grammar is not None:
1538-
llama_cpp.llama_grammar_free(self.grammar.grammar)
1539-
self.grammar = None
15401539

15411540
def __getstate__(self):
15421541
return dict(

llama_cpp/llama_grammar.py

+35-90
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""C++ implementation of the llama grammar parser."""
22
# flake8: noqa
3-
import argparse
43
from pathlib import Path
54
import sys
65
from ctypes import * # type: ignore
@@ -19,7 +18,7 @@
1918
overload,
2019
)
2120

22-
import llama_cpp
21+
from . import llama_cpp
2322

2423
# Type aliases
2524
llama_grammar_element = llama_cpp.llama_grammar_element
@@ -41,11 +40,19 @@ class Sentinel:
4140
class LlamaGrammar:
4241
"""Keeps reference counts of all the arguments, so that they are not
4342
garbage collected by Python."""
43+
44+
def __del__(self) -> None:
45+
"""Free the grammar pointer when the object is deleted."""
46+
if self.grammar is not None:
47+
llama_cpp.llama_grammar_free(self.grammar)
48+
self.grammar = None
4449

4550
def __init__(
4651
self,
4752
parsed_grammar: "parse_state",
4853
) -> None:
54+
"""Initialize the grammar pointer from the parsed state."""
55+
self.parsed_grammar = parsed_grammar
4956
grammar_rules = (
5057
parsed_grammar.c_rules()
5158
) # type: std.vector[std.vector[llama_grammar_element]]
@@ -69,22 +76,25 @@ def __init__(
6976

7077
self.n_rules = c_size_t(grammar_rules.size())
7178
self.start_rule_index = c_size_t(parsed_grammar.symbol_ids.at("root"))
72-
self.grammar = self.init_grammar()
79+
self._grammar = llama_cpp.llama_grammar_init(
80+
self.rules, self.n_rules, self.start_rule_index
81+
)
7382

7483
@classmethod
75-
def from_string(cls, grammar: str) -> "LlamaGrammar":
84+
def from_string(cls, grammar: str, verbose: bool = True) -> "LlamaGrammar":
7685
parsed_grammar = parse(const_char_p(grammar)) # type: parse_state
7786
if parsed_grammar.rules.empty():
7887
raise ValueError(
7988
f"{cls.from_string.__name__}: error parsing grammar file: parsed_grammar.rules is empty"
8089
)
81-
print(f"{cls.from_string.__name__} grammar:", file=sys.stderr)
82-
print_grammar(sys.stdout, parsed_grammar)
83-
print(file=sys.stderr)
90+
if verbose:
91+
print(f"{cls.from_string.__name__} grammar:", file=sys.stderr)
92+
print_grammar(sys.stdout, parsed_grammar)
93+
print(file=sys.stderr)
8494
return cls(parsed_grammar)
8595

8696
@classmethod
87-
def from_file(cls, file: Union[str, Path]) -> "LlamaGrammar":
97+
def from_file(cls, file: Union[str, Path], verbose: bool = True) -> "LlamaGrammar":
8898
try:
8999
with open(file) as f:
90100
grammar = f.read()
@@ -94,14 +104,27 @@ def from_file(cls, file: Union[str, Path]) -> "LlamaGrammar":
94104
)
95105

96106
if grammar:
97-
return cls.from_string(grammar)
107+
return cls.from_string(grammar, verbose=verbose)
98108

99109
raise ValueError(
100110
f"{cls.from_file.__name__}: error parsing grammar file: params_grammer is empty"
101111
)
102112

103-
def init_grammar(self) -> llama_grammar_p:
104-
return llama_cpp.llama_grammar_init(
113+
@property
114+
def grammar(self) -> llama_grammar_p:
115+
if self._grammar is None:
116+
raise ValueError(
117+
f"{self.__class__.__name__}.grammar: grammar is freed"
118+
)
119+
return self._grammar
120+
121+
@grammar.setter
122+
def grammar(self, value: Optional[llama_grammar_p]) -> None:
123+
self._grammar = value
124+
125+
def reset(self) -> None:
126+
llama_cpp.llama_grammar_free(self.grammar)
127+
self.grammar = llama_cpp.llama_grammar_init(
105128
self.rules, self.n_rules, self.start_rule_index
106129
)
107130

@@ -1216,82 +1239,4 @@ def print_grammar(file: TextIO, state: parse_state) -> None:
12161239
print(
12171240
f"{print_grammar.__name__}: error printing grammar: {err}",
12181241
file=sys.stderr,
1219-
)
1220-
1221-
1222-
# def convert_to_rules(
1223-
# llama_grammar_elements: std.vector[std.vector[llama_grammar_element]],
1224-
# ) -> Array[llama_grammar_element_p]:
1225-
# """Make an Array object that is used for `llama_grammer_init`"""
1226-
1227-
# # Step 1: Convert each list to llama_grammar_element array and get pointer
1228-
# element_arrays = [
1229-
# (llama_grammar_element * len(subvector))(*subvector)
1230-
# for subvector in llama_grammar_elements
1231-
# ] # type: List[Array[llama_grammar_element]]
1232-
1233-
# # Step 2: Get pointer of each array
1234-
# element_array_pointers = [
1235-
# cast(subarray, llama_grammar_element_p) for subarray in element_arrays
1236-
# ] # type: List[llama_grammar_element_p]
1237-
1238-
# # Step 3: Make array of these pointers and get its pointer
1239-
# return (llama_grammar_element_p * len(element_array_pointers))(
1240-
# *element_array_pointers
1241-
# )
1242-
1243-
1244-
if __name__ == "__main__":
1245-
parser = argparse.ArgumentParser(
1246-
description="Generate C++ parser from GBNF grammar"
1247-
)
1248-
parser.add_argument(
1249-
"-g",
1250-
"--grammar",
1251-
type=str,
1252-
default="./vendor/llama.cpp/grammars/json.gbnf",
1253-
help="path to GBNF grammar file",
1254-
)
1255-
1256-
args = parser.parse_args()
1257-
llama_grammar = LlamaGrammar.from_file(Path(args.grammar))
1258-
llama_grammar_ptr = llama_grammar.init_grammar()
1259-
1260-
# ----- USAGE:
1261-
# llama_cpp.llama_sample_grammar(ctx=..., candidates=..., grammar=llama_grammar_p)
1262-
# llama_cpp.llama_grammar_accept_token(ctx=..., grammar=llama_grammar_p, token=...)
1263-
1264-
# ----- SAMPLE OUTPUT:
1265-
# main grammar:
1266-
# root ::= object
1267-
# object ::= [{] ws object_11 [}] ws
1268-
# value ::= object | array | string | number | value_6 ws
1269-
# array ::= [[] ws array_15 []] ws
1270-
# string ::= ["] string_18 ["] ws
1271-
# number ::= number_19 number_25 number_29 ws
1272-
# value_6 ::= [t] [r] [u] [e] | [f] [a] [l] [s] [e] | [n] [u] [l] [l]
1273-
# ws ::= ws_31
1274-
# object_8 ::= string [:] ws value object_10
1275-
# object_9 ::= [,] ws string [:] ws value
1276-
# object_10 ::= object_9 object_10 |
1277-
# object_11 ::= object_8 |
1278-
# array_12 ::= value array_14
1279-
# array_13 ::= [,] ws value
1280-
# array_14 ::= array_13 array_14 |
1281-
# array_15 ::= array_12 |
1282-
# string_16 ::= [^"\] | [\] string_17
1283-
# string_17 ::= ["\/bfnrt] | [u] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]
1284-
# string_18 ::= string_16 string_18 |
1285-
# number_19 ::= number_20 number_21
1286-
# number_20 ::= [-] |
1287-
# number_21 ::= [0-9] | [1-9] number_22
1288-
# number_22 ::= [0-9] number_22 |
1289-
# number_23 ::= [.] number_24
1290-
# number_24 ::= [0-9] number_24 | [0-9]
1291-
# number_25 ::= number_23 |
1292-
# number_26 ::= [eE] number_27 number_28
1293-
# number_27 ::= [-+] |
1294-
# number_28 ::= [0-9] number_28 | [0-9]
1295-
# number_29 ::= number_26 |
1296-
# ws_30 ::= [ <U+0009><U+000A>] ws
1297-
# ws_31 ::= ws_30 |
1242+
)

0 commit comments

Comments
 (0)