@@ -234,8 +234,11 @@ int main(int argc, char ** argv) {
234
234
235
235
std::vector<llama_token> embd_inp;
236
236
237
- if (params.interactive_first || params.instruct || !params.prompt .empty () || session_tokens.empty ()) {
237
+ if (params.interactive_first || params.instruct || params. chatml || !params.prompt .empty () || session_tokens.empty ()) {
238
238
LOG (" tokenize the prompt\n " );
239
+ if (params.chatml ) {
240
+ params.prompt = " <|im_start|>system\n " + params.prompt + " <|im_end|>" ;
241
+ }
239
242
embd_inp = ::llama_tokenize (ctx, params.prompt , add_bos, true );
240
243
} else {
241
244
LOG (" use session tokens\n " );
@@ -313,7 +316,7 @@ int main(int argc, char ** argv) {
313
316
}
314
317
315
318
// number of tokens to keep when resetting context
316
- if (params.n_keep < 0 || params.n_keep > (int ) embd_inp.size () || params.instruct ) {
319
+ if (params.n_keep < 0 || params.n_keep > (int ) embd_inp.size () || params.instruct || params. chatml ) {
317
320
params.n_keep = (int )embd_inp.size ();
318
321
}
319
322
@@ -324,11 +327,23 @@ int main(int argc, char ** argv) {
324
327
LOG (" inp_pfx: %s\n " , LOG_TOKENS_TOSTR_PRETTY (ctx, inp_pfx).c_str ());
325
328
LOG (" inp_sfx: %s\n " , LOG_TOKENS_TOSTR_PRETTY (ctx, inp_sfx).c_str ());
326
329
330
+ // chatml prefix & suffix
331
+ const auto cml_pfx = ::llama_tokenize (ctx, " \n <|im_start|>user\n " , add_bos, true );
332
+ const auto cml_sfx = ::llama_tokenize (ctx, " <|im_end|>\n <|im_start|>assistant\n " , false , true );
333
+
334
+ LOG (" cml_pfx: %s\n " , LOG_TOKENS_TOSTR_PRETTY (ctx, cml_pfx).c_str ());
335
+ LOG (" cml_sfx: %s\n " , LOG_TOKENS_TOSTR_PRETTY (ctx, cml_sfx).c_str ());
336
+
327
337
// in instruct mode, we inject a prefix and a suffix to each input by the user
328
338
if (params.instruct ) {
329
339
params.interactive_first = true ;
330
340
params.antiprompt .push_back (" ### Instruction:\n\n " );
331
341
}
342
+ // similar for chatml mode
343
+ else if (params.chatml ) {
344
+ params.interactive_first = true ;
345
+ params.antiprompt .push_back (" <|im_start|>user\n " );
346
+ }
332
347
333
348
// enable interactive mode if interactive start is specified
334
349
if (params.interactive_first ) {
@@ -705,15 +720,15 @@ int main(int argc, char ** argv) {
705
720
706
721
is_interacting = true ;
707
722
printf (" \n " );
708
- } else if (params.instruct ) {
723
+ } else if (params.instruct || params. chatml ) {
709
724
is_interacting = true ;
710
725
}
711
726
}
712
727
713
728
if (n_past > 0 && is_interacting) {
714
729
LOG (" waiting for user input\n " );
715
730
716
- if (params.instruct ) {
731
+ if (params.instruct || params. chatml ) {
717
732
printf (" \n > " );
718
733
}
719
734
@@ -760,6 +775,12 @@ int main(int argc, char ** argv) {
760
775
n_consumed = embd_inp.size ();
761
776
embd_inp.insert (embd_inp.end (), inp_pfx.begin (), inp_pfx.end ());
762
777
}
778
+ // chatml mode: insert user chat prefix
779
+ if (params.chatml && !is_antiprompt) {
780
+ LOG (" inserting chatml prefix\n " );
781
+ n_consumed = embd_inp.size ();
782
+ embd_inp.insert (embd_inp.end (), cml_pfx.begin (), cml_pfx.end ());
783
+ }
763
784
if (params.escape ) {
764
785
process_escapes (buffer);
765
786
}
@@ -778,6 +799,11 @@ int main(int argc, char ** argv) {
778
799
LOG (" inserting instruction suffix\n " );
779
800
embd_inp.insert (embd_inp.end (), inp_sfx.begin (), inp_sfx.end ());
780
801
}
802
+ // chatml mode: insert assistant chat suffix
803
+ if (params.chatml ) {
804
+ LOG (" inserting chatml suffix\n " );
805
+ embd_inp.insert (embd_inp.end (), cml_sfx.begin (), cml_sfx.end ());
806
+ }
781
807
782
808
for (size_t i = original_size; i < embd_inp.size (); ++i) {
783
809
const llama_token token = embd_inp[i];
@@ -803,7 +829,7 @@ int main(int argc, char ** argv) {
803
829
}
804
830
805
831
// end of text token
806
- if (!embd.empty () && embd.back () == llama_token_eos (model) && !(params.instruct || params.interactive )) {
832
+ if (!embd.empty () && embd.back () == llama_token_eos (model) && !(params.instruct || params.interactive || params. chatml )) {
807
833
LOG_TEE (" [end of text]\n " );
808
834
break ;
809
835
}
0 commit comments