1
1
"""C++ implementation of the llama grammar parser."""
2
2
# flake8: noqa
3
- import argparse
4
3
from pathlib import Path
5
4
import sys
6
5
from ctypes import * # type: ignore
19
18
overload ,
20
19
)
21
20
22
- import llama_cpp
21
+ from . import llama_cpp
23
22
24
23
# Type aliases
25
24
llama_grammar_element = llama_cpp .llama_grammar_element
@@ -41,11 +40,19 @@ class Sentinel:
41
40
class LlamaGrammar :
42
41
"""Keeps reference counts of all the arguments, so that they are not
43
42
garbage collected by Python."""
43
+
44
+ def __del__ (self ) -> None :
45
+ """Free the grammar pointer when the object is deleted."""
46
+ if self .grammar is not None :
47
+ llama_cpp .llama_grammar_free (self .grammar )
48
+ self .grammar = None
44
49
45
50
def __init__ (
46
51
self ,
47
52
parsed_grammar : "parse_state" ,
48
53
) -> None :
54
+ """Initialize the grammar pointer from the parsed state."""
55
+ self .parsed_grammar = parsed_grammar
49
56
grammar_rules = (
50
57
parsed_grammar .c_rules ()
51
58
) # type: std.vector[std.vector[llama_grammar_element]]
@@ -69,22 +76,25 @@ def __init__(
69
76
70
77
self .n_rules = c_size_t (grammar_rules .size ())
71
78
self .start_rule_index = c_size_t (parsed_grammar .symbol_ids .at ("root" ))
72
- self .grammar = self .init_grammar ()
79
+ self ._grammar = llama_cpp .llama_grammar_init (
80
+ self .rules , self .n_rules , self .start_rule_index
81
+ )
73
82
74
83
@classmethod
75
- def from_string (cls , grammar : str ) -> "LlamaGrammar" :
84
+ def from_string (cls , grammar : str , verbose : bool = True ) -> "LlamaGrammar" :
76
85
parsed_grammar = parse (const_char_p (grammar )) # type: parse_state
77
86
if parsed_grammar .rules .empty ():
78
87
raise ValueError (
79
88
f"{ cls .from_string .__name__ } : error parsing grammar file: parsed_grammar.rules is empty"
80
89
)
81
- print (f"{ cls .from_string .__name__ } grammar:" , file = sys .stderr )
82
- print_grammar (sys .stdout , parsed_grammar )
83
- print (file = sys .stderr )
90
+ if verbose :
91
+ print (f"{ cls .from_string .__name__ } grammar:" , file = sys .stderr )
92
+ print_grammar (sys .stdout , parsed_grammar )
93
+ print (file = sys .stderr )
84
94
return cls (parsed_grammar )
85
95
86
96
@classmethod
87
- def from_file (cls , file : Union [str , Path ]) -> "LlamaGrammar" :
97
+ def from_file (cls , file : Union [str , Path ], verbose : bool = True ) -> "LlamaGrammar" :
88
98
try :
89
99
with open (file ) as f :
90
100
grammar = f .read ()
@@ -94,14 +104,27 @@ def from_file(cls, file: Union[str, Path]) -> "LlamaGrammar":
94
104
)
95
105
96
106
if grammar :
97
- return cls .from_string (grammar )
107
+ return cls .from_string (grammar , verbose = verbose )
98
108
99
109
raise ValueError (
100
110
f"{ cls .from_file .__name__ } : error parsing grammar file: params_grammer is empty"
101
111
)
102
112
103
- def init_grammar (self ) -> llama_grammar_p :
104
- return llama_cpp .llama_grammar_init (
113
+ @property
114
+ def grammar (self ) -> llama_grammar_p :
115
+ if self ._grammar is None :
116
+ raise ValueError (
117
+ f"{ self .__class__ .__name__ } .grammar: grammar is freed"
118
+ )
119
+ return self ._grammar
120
+
121
+ @grammar .setter
122
+ def grammar (self , value : Optional [llama_grammar_p ]) -> None :
123
+ self ._grammar = value
124
+
125
+ def reset (self ) -> None :
126
+ llama_cpp .llama_grammar_free (self .grammar )
127
+ self .grammar = llama_cpp .llama_grammar_init (
105
128
self .rules , self .n_rules , self .start_rule_index
106
129
)
107
130
@@ -1216,82 +1239,4 @@ def print_grammar(file: TextIO, state: parse_state) -> None:
1216
1239
print (
1217
1240
f"{ print_grammar .__name__ } : error printing grammar: { err } " ,
1218
1241
file = sys .stderr ,
1219
- )
1220
-
1221
-
1222
- # def convert_to_rules(
1223
- # llama_grammar_elements: std.vector[std.vector[llama_grammar_element]],
1224
- # ) -> Array[llama_grammar_element_p]:
1225
- # """Make an Array object that is used for `llama_grammer_init`"""
1226
-
1227
- # # Step 1: Convert each list to llama_grammar_element array and get pointer
1228
- # element_arrays = [
1229
- # (llama_grammar_element * len(subvector))(*subvector)
1230
- # for subvector in llama_grammar_elements
1231
- # ] # type: List[Array[llama_grammar_element]]
1232
-
1233
- # # Step 2: Get pointer of each array
1234
- # element_array_pointers = [
1235
- # cast(subarray, llama_grammar_element_p) for subarray in element_arrays
1236
- # ] # type: List[llama_grammar_element_p]
1237
-
1238
- # # Step 3: Make array of these pointers and get its pointer
1239
- # return (llama_grammar_element_p * len(element_array_pointers))(
1240
- # *element_array_pointers
1241
- # )
1242
-
1243
-
1244
- if __name__ == "__main__" :
1245
- parser = argparse .ArgumentParser (
1246
- description = "Generate C++ parser from GBNF grammar"
1247
- )
1248
- parser .add_argument (
1249
- "-g" ,
1250
- "--grammar" ,
1251
- type = str ,
1252
- default = "./vendor/llama.cpp/grammars/json.gbnf" ,
1253
- help = "path to GBNF grammar file" ,
1254
- )
1255
-
1256
- args = parser .parse_args ()
1257
- llama_grammar = LlamaGrammar .from_file (Path (args .grammar ))
1258
- llama_grammar_ptr = llama_grammar .init_grammar ()
1259
-
1260
- # ----- USAGE:
1261
- # llama_cpp.llama_sample_grammar(ctx=..., candidates=..., grammar=llama_grammar_p)
1262
- # llama_cpp.llama_grammar_accept_token(ctx=..., grammar=llama_grammar_p, token=...)
1263
-
1264
- # ----- SAMPLE OUTPUT:
1265
- # main grammar:
1266
- # root ::= object
1267
- # object ::= [{] ws object_11 [}] ws
1268
- # value ::= object | array | string | number | value_6 ws
1269
- # array ::= [[] ws array_15 []] ws
1270
- # string ::= ["] string_18 ["] ws
1271
- # number ::= number_19 number_25 number_29 ws
1272
- # value_6 ::= [t] [r] [u] [e] | [f] [a] [l] [s] [e] | [n] [u] [l] [l]
1273
- # ws ::= ws_31
1274
- # object_8 ::= string [:] ws value object_10
1275
- # object_9 ::= [,] ws string [:] ws value
1276
- # object_10 ::= object_9 object_10 |
1277
- # object_11 ::= object_8 |
1278
- # array_12 ::= value array_14
1279
- # array_13 ::= [,] ws value
1280
- # array_14 ::= array_13 array_14 |
1281
- # array_15 ::= array_12 |
1282
- # string_16 ::= [^"\] | [\] string_17
1283
- # string_17 ::= ["\/bfnrt] | [u] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]
1284
- # string_18 ::= string_16 string_18 |
1285
- # number_19 ::= number_20 number_21
1286
- # number_20 ::= [-] |
1287
- # number_21 ::= [0-9] | [1-9] number_22
1288
- # number_22 ::= [0-9] number_22 |
1289
- # number_23 ::= [.] number_24
1290
- # number_24 ::= [0-9] number_24 | [0-9]
1291
- # number_25 ::= number_23 |
1292
- # number_26 ::= [eE] number_27 number_28
1293
- # number_27 ::= [-+] |
1294
- # number_28 ::= [0-9] number_28 | [0-9]
1295
- # number_29 ::= number_26 |
1296
- # ws_30 ::= [ <U+0009><U+000A>] ws
1297
- # ws_31 ::= ws_30 |
1242
+ )
0 commit comments