Skip to content

Commit 84e09a7

Browse files
ejones, SlyEcho, and ggerganov
authored
llama : add grammar-based sampling (#1773)
* llama, main : constrain sampling to grammar
* allow loading grammar from file
* fix whitespace errors
* handle & print parser errors
* add comments to grammar syntax and allow newlines where unambiguous
* add missing include
* support alternates in root rule
* fix bugs with empty token and EOS
* adjust JSON grammar
* remove swp file
* rewrite ternary expressions
  Co-authored-by: Henri Vasserman <[email protected]>
* use struct for grammar elements and add Unicode support
* add unicode escapes
* add inverse char ranges
* only sample full tokens (no peeking or truncation)
* llama : minor style changes blindly applied in online editor - hopefully I didn't break something
* update help text
* add warning message if EOS is disabled
---------
Co-authored-by: Henri Vasserman <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 2f9cf97 commit 84e09a7

14 files changed: +977 -1 lines changed

Diff for: Makefile

+4-1
Original file line number | Diff line number | Diff line change
@@ -323,6 +323,9 @@ llama.o: llama.cpp ggml.h ggml-cuda.h ggml-metal.h llama.h llama-util.h
323323
common.o: examples/common.cpp examples/common.h
324324
$(CXX) $(CXXFLAGS) -c $< -o $@
325325

326+
grammar-parser.o: examples/grammar-parser.cpp examples/grammar-parser.h
327+
$(CXX) $(CXXFLAGS) -c $< -o $@
328+
326329
libllama.so: llama.o ggml.o $(OBJS)
327330
$(CXX) $(CXXFLAGS) -shared -fPIC -o $@ $^ $(LDFLAGS)
328331

@@ -333,7 +336,7 @@ clean:
333336
# Examples
334337
#
335338

336-
main: examples/main/main.cpp build-info.h ggml.o llama.o common.o $(OBJS)
339+
main: examples/main/main.cpp build-info.h ggml.o llama.o common.o grammar-parser.o $(OBJS)
337340
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
338341
@echo
339342
@echo '==== Run ./main -h for help. ===='

Diff for: examples/CMakeLists.txt

+2
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,8 @@ set(TARGET common)
1313
add_library(${TARGET} OBJECT
1414
common.h
1515
common.cpp
16+
grammar-parser.h
17+
grammar-parser.cpp
1618
)
1719

1820
if (BUILD_SHARED_LIBS)

Diff for: examples/common.cpp

+24
Original file line number | Diff line number | Diff line change
@@ -438,6 +438,28 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
438438
break;
439439
}
440440
params.input_suffix = argv[i];
441+
} else if (arg == "--grammar") {
442+
if (++i >= argc) {
443+
invalid_param = true;
444+
break;
445+
}
446+
params.grammar = argv[i];
447+
} else if (arg == "--grammar-file") {
448+
if (++i >= argc) {
449+
invalid_param = true;
450+
break;
451+
}
452+
std::ifstream file(argv[i]);
453+
if (!file) {
454+
fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
455+
invalid_param = true;
456+
break;
457+
}
458+
std::copy(
459+
std::istreambuf_iterator<char>(file),
460+
std::istreambuf_iterator<char>(),
461+
std::back_inserter(params.grammar)
462+
);
441463
} else {
442464
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
443465
gpt_print_usage(argc, argv, default_params);
@@ -514,6 +536,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
514536
fprintf(stdout, " modifies the likelihood of token appearing in the completion,\n");
515537
fprintf(stdout, " i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n");
516538
fprintf(stdout, " or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n");
539+
fprintf(stdout, " --grammar GRAMMAR BNF-like grammar to constrain generations (see samples in grammars/ dir)\n");
540+
fprintf(stdout, " --grammar-file FNAME file to read grammar from\n");
517541
fprintf(stdout, " --cfg-negative-prompt PROMPT \n");
518542
fprintf(stdout, " negative prompt to use for guidance. (default: empty)\n");
519543
fprintf(stdout, " --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale);

Diff for: examples/common.h

+1
Original file line number | Diff line number | Diff line change
@@ -63,6 +63,7 @@ struct gpt_params {
6363
std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
6464
std::string input_prefix = ""; // string to prefix user inputs with
6565
std::string input_suffix = ""; // string to suffix user inputs with
66+
std::string grammar = ""; // optional BNF-like grammar to constrain sampling
6667
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
6768

6869
std::string lora_adapter = ""; // lora adapter path

0 commit comments

Comments
 (0)