Skip to content

Commit 2f5a5a6

Browse files
committed
talk-llama : use llama_decode instead of llama_eval
1 parent 8e409d1 commit 2f5a5a6

File tree

1 file changed

+32
-4
lines changed

1 file changed

+32
-4
lines changed

examples/talk-llama/talk-llama.cpp

+32-4
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,8 @@ int main(int argc, char ** argv) {
391391

392392
prompt_llama = ::replace(prompt_llama, "{4}", chat_symb);
393393

394+
llama_batch batch = llama_batch_init(llama_n_ctx(ctx_llama), 0, 1);
395+
394396
// init session
395397
std::string path_session = params.path_session;
396398
std::vector<llama_token> session_tokens;
@@ -426,8 +428,21 @@ int main(int argc, char ** argv) {
426428
printf("\n");
427429
printf("%s : initializing - please wait ...\n", __func__);
428430

429-
if (llama_eval(ctx_llama, embd_inp.data(), embd_inp.size(), 0)) {
430-
fprintf(stderr, "%s : failed to eval\n", __func__);
431+
// prepare batch
432+
{
433+
batch.n_tokens = embd_inp.size();
434+
435+
for (int i = 0; i < batch.n_tokens; i++) {
436+
batch.token[i] = embd_inp[i];
437+
batch.pos[i] = i;
438+
batch.n_seq_id[i] = 1;
439+
batch.seq_id[i][0] = 0;
440+
batch.logits[i] = i == batch.n_tokens - 1;
441+
}
442+
}
443+
444+
if (llama_decode(ctx_llama, batch)) {
445+
fprintf(stderr, "%s : failed to decode\n", __func__);
431446
return 1;
432447
}
433448

@@ -647,8 +662,21 @@ int main(int argc, char ** argv) {
647662
n_session_consumed = session_tokens.size();
648663
}
649664

650-
if (llama_eval(ctx_llama, embd.data(), embd.size(), n_past)) {
651-
fprintf(stderr, "%s : failed to eval\n", __func__);
665+
// prepare batch
666+
{
667+
batch.n_tokens = embd.size();
668+
669+
for (int i = 0; i < batch.n_tokens; i++) {
670+
batch.token[i] = embd[i];
671+
batch.pos[i] = n_past + i;
672+
batch.n_seq_id[i] = 1;
673+
batch.seq_id[i][0] = 0;
674+
batch.logits[i] = i == batch.n_tokens - 1;
675+
}
676+
}
677+
678+
if (llama_decode(ctx_llama, batch)) {
679+
fprintf(stderr, "%s : failed to decode\n", __func__);
652680
return 1;
653681
}
654682
}

0 commit comments

Comments
 (0)