
Commit 73904f6

Sebby37 authored and olexiyb committed
main : Add ChatML functionality to main example (ggml-org#4046)
Co-authored-by: Sebastian Cramond <[email protected]>
1 parent 9195f0c commit 73904f6

File tree

4 files changed (+42 lines, -5 lines)

common/common.cpp

+3
@@ -491,6 +491,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
             params.interactive_first = true;
         } else if (arg == "-ins" || arg == "--instruct") {
             params.instruct = true;
+        } else if (arg == "-cml" || arg == "--chatml") {
+            params.chatml = true;
         } else if (arg == "--infill") {
             params.infill = true;
         } else if (arg == "--multiline-input") {
@@ -730,6 +732,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  -i, --interactive     run in interactive mode\n");
     printf("  --interactive-first   run in interactive mode and wait for input right away\n");
     printf("  -ins, --instruct      run in instruction mode (use with Alpaca models)\n");
+    printf("  -cml, --chatml        run in chatml mode (use with ChatML-compatible models)\n");
     printf("  --multiline-input     allows you to write or paste multiple lines without ending each in '\\'\n");
     printf("  -r PROMPT, --reverse-prompt PROMPT\n");
     printf("                        halt generation at PROMPT, return control in interactive mode\n");

common/common.h

+1
@@ -102,6 +102,7 @@ struct gpt_params {
     bool random_prompt    = false; // do not randomize prompt if none provided
     bool use_color        = false; // use color to distinguish generations and inputs
     bool interactive      = false; // interactive mode
+    bool chatml           = false; // chatml mode (used for models trained on chatml syntax)
     bool prompt_cache_all = false; // save user input and generations to prompt cache
     bool prompt_cache_ro  = false; // open the prompt cache read-only and do not update it

examples/infill/infill.cpp

+7
@@ -146,6 +146,13 @@ int main(int argc, char ** argv) {
 
         return 0;
     }
+    if (params.chatml) {
+        printf("\n************\n");
+        printf("%s: please use the 'main' tool for chatml mode\n", __func__);
+        printf("************\n\n");
+
+        return 0;
+    }
     if (!params.antiprompt.empty()) {
         printf("\n************\n");
         printf("%s: please use the 'main' tool for antiprompt mode\n", __func__);

examples/main/main.cpp

+31 -5
@@ -234,8 +234,11 @@ int main(int argc, char ** argv) {
 
     std::vector<llama_token> embd_inp;
 
-    if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
+    if (params.interactive_first || params.instruct || params.chatml || !params.prompt.empty() || session_tokens.empty()) {
         LOG("tokenize the prompt\n");
+        if (params.chatml) {
+            params.prompt = "<|im_start|>system\n" + params.prompt + "<|im_end|>";
+        }
         embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
     } else {
         LOG("use session tokens\n");
@@ -313,7 +316,7 @@ int main(int argc, char ** argv) {
     }
 
     // number of tokens to keep when resetting context
-    if (params.n_keep < 0 || params.n_keep > (int) embd_inp.size() || params.instruct) {
+    if (params.n_keep < 0 || params.n_keep > (int) embd_inp.size() || params.instruct || params.chatml) {
         params.n_keep = (int)embd_inp.size();
     }
 
@@ -324,11 +327,23 @@ int main(int argc, char ** argv) {
     LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx).c_str());
     LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx).c_str());
 
+    // chatml prefix & suffix
+    const auto cml_pfx = ::llama_tokenize(ctx, "\n<|im_start|>user\n", add_bos, true);
+    const auto cml_sfx = ::llama_tokenize(ctx, "<|im_end|>\n<|im_start|>assistant\n", false, true);
+
+    LOG("cml_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, cml_pfx).c_str());
+    LOG("cml_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, cml_sfx).c_str());
+
     // in instruct mode, we inject a prefix and a suffix to each input by the user
     if (params.instruct) {
         params.interactive_first = true;
         params.antiprompt.push_back("### Instruction:\n\n");
     }
+    // similar for chatml mode
+    else if (params.chatml) {
+        params.interactive_first = true;
+        params.antiprompt.push_back("<|im_start|>user\n");
+    }
 
     // enable interactive mode if interactive start is specified
     if (params.interactive_first) {
@@ -705,15 +720,15 @@ int main(int argc, char ** argv) {
 
                     is_interacting = true;
                     printf("\n");
-                } else if (params.instruct) {
+                } else if (params.instruct || params.chatml) {
                     is_interacting = true;
                 }
             }
 
             if (n_past > 0 && is_interacting) {
                 LOG("waiting for user input\n");
 
-                if (params.instruct) {
+                if (params.instruct || params.chatml) {
                     printf("\n> ");
                 }
 
@@ -760,6 +775,12 @@ int main(int argc, char ** argv) {
                         n_consumed = embd_inp.size();
                         embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());
                     }
+                    // chatml mode: insert user chat prefix
+                    if (params.chatml && !is_antiprompt) {
+                        LOG("inserting chatml prefix\n");
+                        n_consumed = embd_inp.size();
+                        embd_inp.insert(embd_inp.end(), cml_pfx.begin(), cml_pfx.end());
+                    }
                     if (params.escape) {
                         process_escapes(buffer);
                     }
@@ -778,6 +799,11 @@ int main(int argc, char ** argv) {
                         LOG("inserting instruction suffix\n");
                         embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
                     }
+                    // chatml mode: insert assistant chat suffix
+                    if (params.chatml) {
+                        LOG("inserting chatml suffix\n");
+                        embd_inp.insert(embd_inp.end(), cml_sfx.begin(), cml_sfx.end());
+                    }
 
                     for (size_t i = original_size; i < embd_inp.size(); ++i) {
                         const llama_token token = embd_inp[i];
@@ -803,7 +829,7 @@ int main(int argc, char ** argv) {
         }
 
         // end of text token
-        if (!embd.empty() && embd.back() == llama_token_eos(model) && !(params.instruct || params.interactive)) {
+        if (!embd.empty() && embd.back() == llama_token_eos(model) && !(params.instruct || params.interactive || params.chatml)) {
            LOG_TEE(" [end of text]\n");
            break;
        }
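
Taken together, the pieces above implement the standard ChatML turn structure: the initial prompt becomes the system turn, cml_pfx opens each user turn, and cml_sfx closes it and starts the assistant turn, with "<|im_start|>user\n" doubling as the antiprompt that hands control back to the user. A minimal standalone sketch, not part of the commit, using plain string concatenation in place of tokenization to show the prompt main assembles for the first exchange:

#include <iostream>
#include <string>

int main() {
    // Hypothetical stand-ins for values main receives at runtime.
    const std::string system_prompt = "You are a helpful assistant."; // from -p/--prompt
    const std::string user_input    = "Hello!";                       // typed interactively

    // main.cpp wraps the initial prompt as the ChatML system turn ...
    std::string prompt = "<|im_start|>system\n" + system_prompt + "<|im_end|>";

    // ... inserts cml_pfx before the user's input ...
    prompt += "\n<|im_start|>user\n" + user_input;

    // ... and appends cml_sfx so generation continues as the assistant turn.
    prompt += "<|im_end|>\n<|im_start|>assistant\n";

    std::cout << prompt << std::endl;
    return 0;
}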
