Commit 007e26a

Added context free grammar constraints
1 parent 08737ef commit 007e26a

File tree

6 files changed: +1042 -4 lines

examples/common.cpp

+7 -1
@@ -231,7 +231,13 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.mirostat_tau = std::stof(argv[i]);
-        } else if (arg == "-b" || arg == "--batch-size") {
+        } else if (arg == "--grammar") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.token_grammar_path = argv[i];
+        } else if (arg == "-b" || arg == "--batch_size") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
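
With this change, a serialized token grammar can be passed via the new --grammar flag. A hypothetical invocation (the grammar file name is a placeholder; its on-disk format is defined in the parts of the commit not shown here):

./main -m models/7B/ggml-model.bin --grammar my_grammar.bin -p "Your prompt"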

examples/common.h

+3 -2
@@ -48,8 +48,9 @@ struct gpt_params {
     std::string model = "models/lamma-7B/ggml-model.bin"; // model path
     std::string prompt = "";
     std::string path_prompt_cache = "";  // path to file for saving/loading prompt eval state
-    std::string input_prefix = "";       // string to prefix user inputs with
-    std::string input_suffix = "";       // string to suffix user inputs with
+    std::string token_grammar_path = ""; // path to file containing serialized token validator
+    std::string input_prefix = "";       // string to prefix user inputs with
+    std::string input_suffix = "";       // string to suffix user inputs with
     std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
 
     std::string lora_adapter = ""; // lora adapter path

examples/main/main.cpp

+11 -0
@@ -117,6 +117,15 @@ int main(int argc, char ** argv) {
                 params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
     }
 
+
+    // load the token grammar from params.token_grammar_path, if one was given
+    std::string token_grammar_path = params.token_grammar_path;
+    void * grammar = nullptr;
+    if (!token_grammar_path.empty()) {
+        fprintf(stderr, "%s: attempting to parse token grammar from '%s'\n", __func__, token_grammar_path.c_str());
+        grammar = llama_load_token_grammar_from_path(token_grammar_path.c_str());
+    }
+
     // determine the maximum memory usage needed to do inference for the given n_batch and n_predict parameters
     // uncomment the "used_mem" line in llama.cpp to see the results
     if (params.mem_test) {
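
main.cpp only calls into the new grammar API; its declarations live in the parts of the commit not shown in this view (6 files changed, 3 shown). Inferred from the call sites, they would look roughly like the sketch below; the exact signatures are an assumption.

// Sketch only: signatures inferred from the call sites in main.cpp,
// not copied from the actual (unshown) header changes.
void * llama_load_token_grammar_from_path(const char * path);
void   llama_grammar_penalty(struct llama_context * ctx, llama_token_data_array * candidates, void * grammar);
void   llama_grammar_accept_token(struct llama_context * ctx, llama_token token, void * grammar);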
@@ -420,6 +429,7 @@ int main(int argc, char ** argv) {
                 llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
 
                 // Apply penalties
+                llama_grammar_penalty(ctx, &candidates_p, grammar);
                 float nl_logit = logits[llama_token_nl()];
                 auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
                 llama_sample_repetition_penalty(ctx, &candidates_p,
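
The grammar penalty runs before the repetition penalties, so tokens the grammar cannot accept are pruned from the candidate set up front. A minimal sketch of the usual masking approach, assuming a hypothetical predicate grammar_allows() over the grammar's current parse state (the commit's real implementation is in the unshown files):

#include <cmath>   // INFINITY
#include "llama.h" // llama_token, llama_token_data_array

// Hypothetical: does the grammar's current parse state accept this token?
bool grammar_allows(const void * grammar, llama_token token);

// Mask candidates the grammar rejects so they can never be sampled.
static void grammar_penalty_sketch(llama_token_data_array * candidates, const void * grammar) {
    for (size_t i = 0; i < candidates->size; i++) {
        if (!grammar_allows(grammar, candidates->data[i].id)) {
            candidates->data[i].logit = -INFINITY;
        }
    }
}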
@@ -459,6 +469,7 @@ int main(int argc, char ** argv) {
 
                 last_n_tokens.erase(last_n_tokens.begin());
                 last_n_tokens.push_back(id);
+                llama_grammar_accept_token(ctx, id, grammar);
             }
 
             // replace end of text token with newline token when in interactive mode
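
Together the two hooks form the standard constrained-decoding loop: mask invalid tokens, sample, then advance the grammar state with the sampled token so the next step's mask reflects the new position. Schematically (a sketch reusing names from this diff, not a verbatim excerpt):

llama_grammar_penalty(ctx, &candidates_p, grammar);      // 1. mask tokens the grammar rejects
llama_token id = llama_sample_token(ctx, &candidates_p); // 2. sample from what remains
llama_grammar_accept_token(ctx, id, grammar);            // 3. advance the grammar's parse state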
