@@ -117,6 +117,15 @@ int main(int argc, char ** argv) {
                 params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
     }
 
+
+    // load the token grammar from params.token_grammar_path
+    std::string token_grammar_path = params.token_grammar_path;
+    void * grammar = nullptr;
+    if (!token_grammar_path.empty()) {
+        fprintf(stderr, "%s: attempting to parse token grammar from '%s'\n", __func__, token_grammar_path.c_str());
+        grammar = llama_load_token_grammar_from_path(token_grammar_path.c_str());
+    }
+
     // determine the maximum memory usage needed to do inference for the given n_batch and n_predict parameters
     // uncomment the "used_mem" line in llama.cpp to see the results
     if (params.mem_test) {
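
`llama_load_token_grammar_from_path` is introduced by this patch rather than by the existing llama.cpp API, and its implementation is not part of this hunk. As a rough, self-contained sketch of what the load step could look like — the `toy_grammar` type, the one-token-id-per-line file format, and the `load_token_grammar_from_path` name are all invented here for illustration:

```cpp
#include <cstdio>
#include <fstream>
#include <unordered_set>

// Invented stand-in for the patch's grammar object; the real type behind the
// void * handle is not shown in this diff.
struct toy_grammar {
    std::unordered_set<int> allowed; // token ids the grammar permits
};

// Hypothetical load step: here a "grammar" file is just one token id per line.
void * load_token_grammar_from_path(const char * path) {
    std::ifstream in(path);
    if (!in) {
        fprintf(stderr, "failed to open token grammar '%s'\n", path);
        return nullptr; // failure leaves the caller's grammar == nullptr
    }
    toy_grammar * g = new toy_grammar;
    int id;
    while (in >> id) {
        g->allowed.insert(id);
    }
    return g; // the caller stores it as an opaque void *, as the patch does
}

int main() {
    void * grammar = load_token_grammar_from_path("grammar.txt");
    fprintf(stderr, "grammar %s\n", grammar ? "loaded" : "not loaded");
    delete static_cast<toy_grammar *>(grammar);
}
```

Returning `nullptr` on failure lines up with the hunk above, where an empty `params.token_grammar_path` also leaves `grammar` as `nullptr`, so every later call site has to tolerate a missing grammar.
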
@@ -420,6 +429,7 @@ int main(int argc, char ** argv) {
                     llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
 
                     // Apply penalties
+                    llama_grammar_penalty(ctx, &candidates_p, grammar);
                     float nl_logit = logits[llama_token_nl()];
                     auto last_n_repeat = std::min(std::min((int) last_n_tokens.size(), repeat_last_n), n_ctx);
                     llama_sample_repetition_penalty(ctx, &candidates_p,
@@ -459,6 +469,7 @@ int main(int argc, char ** argv) {
 
                     last_n_tokens.erase(last_n_tokens.begin());
                     last_n_tokens.push_back(id);
+                    llama_grammar_accept_token(ctx, id, grammar);
                 }
 
                 // replace end of text token with newline token when in interactive mode
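
Likewise, `llama_grammar_penalty` and `llama_grammar_accept_token` are this patch's additions, and their bodies are not shown here. A plausible reading of the two call sites — penalize the candidates before the repetition penalties run, accept only after a token has been committed — is that the first masks candidates the grammar currently forbids and the second advances the grammar state. A minimal sketch under that assumption, with invented `token_data`/`toy_grammar` types and a stateless whitelist standing in for a real grammar:

```cpp
#include <cstdio>
#include <unordered_set>
#include <vector>

struct token_data { int id; float logit; };              // stand-in for llama_token_data

struct toy_grammar { std::unordered_set<int> allowed; }; // stateless whitelist "grammar"

// analogous to llama_grammar_penalty: force forbidden candidates to -inf
void grammar_penalty(std::vector<token_data> & candidates, const toy_grammar * g) {
    if (!g) return;                                      // no grammar loaded: leave logits alone
    for (auto & c : candidates) {
        if (!g->allowed.count(c.id)) {
            c.logit = -1e9f;
        }
    }
}

// analogous to llama_grammar_accept_token: advance state after sampling;
// a real grammar would step its parse state and recompute the allowed set
void grammar_accept_token(toy_grammar * /*g*/, int /*id*/) {
    // stateless toy grammar: nothing to update
}

int main() {
    toy_grammar g{{2, 3}};
    std::vector<token_data> candidates = {{1, 0.9f}, {2, 0.1f}, {3, -0.4f}};
    grammar_penalty(candidates, &g);
    // greedy pick as a stand-in for the real sampler chain
    token_data best = candidates[0];
    for (const auto & c : candidates) {
        if (c.logit > best.logit) best = c;
    }
    grammar_accept_token(&g, best.id);
    printf("sampled token %d\n", best.id);               // prints 2: token 1 was masked
}
```

Running the mask before the other penalties means a forbidden token stays at effectively minus infinity however the later samplers rescale the surviving logits, which would explain why the patch places the call at the top of the penalty block.
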