@@ -391,6 +391,8 @@ int main(int argc, char ** argv) {
391
391
392
392
prompt_llama = ::replace (prompt_llama, " {4}" , chat_symb);
393
393
394
+ llama_batch batch = llama_batch_init (llama_n_ctx (ctx_llama), 0 , 1 );
395
+
394
396
// init session
395
397
std::string path_session = params.path_session ;
396
398
std::vector<llama_token> session_tokens;
@@ -426,8 +428,21 @@ int main(int argc, char ** argv) {
426
428
printf (" \n " );
427
429
printf (" %s : initializing - please wait ...\n " , __func__);
428
430
429
- if (llama_eval (ctx_llama, embd_inp.data (), embd_inp.size (), 0 )) {
430
- fprintf (stderr, " %s : failed to eval\n " , __func__);
431
+ // prepare batch
432
+ {
433
+ batch.n_tokens = embd_inp.size ();
434
+
435
+ for (int i = 0 ; i < batch.n_tokens ; i++) {
436
+ batch.token [i] = embd_inp[i];
437
+ batch.pos [i] = i;
438
+ batch.n_seq_id [i] = 1 ;
439
+ batch.seq_id [i][0 ] = 0 ;
440
+ batch.logits [i] = i == batch.n_tokens - 1 ;
441
+ }
442
+ }
443
+
444
+ if (llama_decode (ctx_llama, batch)) {
445
+ fprintf (stderr, " %s : failed to decode\n " , __func__);
431
446
return 1 ;
432
447
}
433
448
@@ -647,8 +662,21 @@ int main(int argc, char ** argv) {
647
662
n_session_consumed = session_tokens.size ();
648
663
}
649
664
650
- if (llama_eval (ctx_llama, embd.data (), embd.size (), n_past)) {
651
- fprintf (stderr, " %s : failed to eval\n " , __func__);
665
+ // prepare batch
666
+ {
667
+ batch.n_tokens = embd.size ();
668
+
669
+ for (int i = 0 ; i < batch.n_tokens ; i++) {
670
+ batch.token [i] = embd[i];
671
+ batch.pos [i] = n_past + i;
672
+ batch.n_seq_id [i] = 1 ;
673
+ batch.seq_id [i][0 ] = 0 ;
674
+ batch.logits [i] = i == batch.n_tokens - 1 ;
675
+ }
676
+ }
677
+
678
+ if (llama_decode (ctx_llama, batch)) {
679
+ fprintf (stderr, " %s : failed to decode\n " , __func__);
652
680
return 1 ;
653
681
}
654
682
}
0 commit comments