@@ -8432,74 +8432,33 @@ static enum ggml_status llama_graph_compute(
     return status;
 }
 
-// decode a batch of tokens by evaluating the transformer
-// in case of unsuccessful decoding (error or warning),
-// the kv_cache state will be returned to its original state
-// (for non-recurrent models) or cleaned (for recurrent models)
-//
-//   - lctx:      llama context
-//   - batch:     batch to evaluate
-//
-//   return 0 on success
-//   return positive int on warning
-//   return negative int on error
-//
-static int llama_decode_impl(
-         llama_context & lctx,
-           llama_batch   inp_batch) {
-
-    lctx.is_encoding = false;
-
-    if (inp_batch.n_tokens == 0) {
-        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
-        return -1;
-    }
-
-    // temporary allocate memory for the input batch if needed
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);
-
-    const llama_batch & batch = batch_allocr.batch;
-    const uint32_t n_tokens_all = batch.n_tokens;
-
+static int llama_prepare_sbatch(
+        llama_context & lctx,
+    const llama_batch & batch,
+             uint32_t & n_outputs) {
     const auto & model   = lctx.model;
-    const auto & vocab   = model.vocab;
     const auto & hparams = model.hparams;
     const auto & cparams = lctx.cparams;
 
-    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
+    const uint32_t n_tokens_all = batch.n_tokens;
+    const  int64_t n_embd       = hparams.n_embd;
+
+    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
 
+    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
     if (batch.token) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
+            if (batch.token[i] < 0 || uint32_t(batch.token[i]) >= model.vocab.n_tokens()) {
                 LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
                 return -1;
             }
         }
     }
-
     GGML_ASSERT(n_tokens_all <= cparams.n_batch);
-
     GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
 
-    if (lctx.t_compute_start_us == 0) {
-        lctx.t_compute_start_us = ggml_time_us();
-    }
     lctx.n_queued_tokens += n_tokens_all;
-
-    auto & kv_self = lctx.kv_self;
-    llama_kv_slot_restorer kv_slot_restorer(kv_self);
-
-    const int64_t n_embd  = hparams.n_embd;
-    const int64_t n_vocab = vocab.n_tokens();
-
-    uint32_t n_outputs      = 0;
-    uint32_t n_outputs_prev = 0;
-
-    const auto n_ubatch = cparams.n_ubatch;
-
-    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
-    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
     lctx.embd_seq.clear();
 
     // count outputs
@@ -8515,7 +8474,7 @@ static int llama_decode_impl(
     }
 
     lctx.sbatch.from_batch(batch, n_embd,
-        /* simple_split */ !kv_self.recurrent,
+        /* simple_split */ !lctx.kv_self.recurrent,
         /* logits_all   */ n_outputs == n_tokens_all);
 
     // reserve output buffer
@@ -8524,70 +8483,148 @@ static int llama_decode_impl(
         return -2;
     };
 
-    while (lctx.sbatch.n_tokens > 0) {
-        llama_ubatch ubatch;
-        if (kv_self.recurrent) {
-            if (embd_pooled) {
-                // Pooled embeddings cannot be split across ubatches (yet)
-                ubatch = lctx.sbatch.split_seq(n_ubatch);
-            } else {
-                // recurrent model architectures are easier to implement
-                // with equal-length sequences
-                ubatch = lctx.sbatch.split_equal(n_ubatch);
-            }
+    return 0;
+}
+
+static int llama_prepare_ubatch(
+                 llama_context & lctx,
+        llama_kv_slot_restorer & kv_slot_restorer,
+                  llama_ubatch & ubatch,
+                const uint32_t   n_outputs,
+                const uint32_t   n_tokens_all) {
+    GGML_ASSERT(lctx.sbatch.n_tokens > 0);
+
+    auto & kv_self = lctx.kv_self;
+    const auto & cparams = lctx.cparams;
+    const auto & hparams = lctx.model.hparams;
+
+    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+    if (lctx.kv_self.recurrent) {
+        if (embd_pooled) {
+            // Pooled embeddings cannot be split across ubatches (yet)
+            ubatch = lctx.sbatch.split_seq(cparams.n_ubatch);
         } else {
-            ubatch = lctx.sbatch.split_simple(n_ubatch);
+            // recurrent model architectures are easier to implement
+            // with equal-length sequences
+            ubatch = lctx.sbatch.split_equal(cparams.n_ubatch);
         }
-        const uint32_t n_tokens = ubatch.n_tokens;
+    } else {
+        ubatch = lctx.sbatch.split_simple(cparams.n_ubatch);
+    }
 
-        // count the outputs in this u_batch
-        {
-            int32_t n_outputs_new = 0;
+    // count the outputs in this u_batch
+    {
+        int32_t n_outputs_new = 0;
 
-            if (n_outputs == n_tokens_all) {
-                n_outputs_new = n_tokens;
-            } else {
-                GGML_ASSERT(ubatch.output);
-                for (uint32_t i = 0; i < n_tokens; i++) {
-                    n_outputs_new += (int32_t) (ubatch.output[i] != 0);
-                }
+        if (n_outputs == n_tokens_all) {
+            n_outputs_new = ubatch.n_tokens;
+        } else {
+            GGML_ASSERT(ubatch.output);
+            for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
+                n_outputs_new += int32_t(ubatch.output[i] != 0);
             }
+        }
+
+        // needs to happen before the graph is built
+        lctx.n_outputs = n_outputs_new;
+    }
+
+    // non-causal masks do not use the KV cache
+    if (hparams.causal_attn) {
+        llama_kv_cache_update(&lctx);
 
-            // needs to happen before the graph is built
-            lctx.n_outputs = n_outputs_new;
+        // if we have enough unused cells before the current head ->
+        //   better to start searching from the beginning of the cache, hoping to fill it
+        if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) {
+            kv_self.head = 0;
         }
 
-        int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
-        ggml_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
+        const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
+        if (!slot) {
+            return 1;
+        }
+        kv_slot_restorer.save(slot);
 
-        GGML_ASSERT(n_threads > 0);
+        if (!kv_self.recurrent) {
+            // a heuristic, to avoid attending the full cache if it is not yet utilized
+            // after enough generations, the benefit from this heuristic disappears
+            // if we start defragmenting the cache, the benefit from this will be more important
+            const uint32_t pad = llama_kv_cache_get_padding(cparams);
+            kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
+            //kv_self.n = llama_kv_cache_cell_max(kv_self);
+        }
+    }
 
-        // non-causal masks do not use the KV cache
-        if (hparams.causal_attn) {
-            llama_kv_cache_update(&lctx);
+    return 0;
+}
 
-            // if we have enough unused cells before the current head ->
-            //   better to start searching from the beginning of the cache, hoping to fill it
-            if (kv_self.head > kv_self.used + 2*n_tokens) {
-                kv_self.head = 0;
-            }
+// decode a batch of tokens by evaluating the transformer
+// in case of unsuccessful decoding (error or warning),
+// the kv_cache state will be returned to its original state
+// (for non-recurrent models) or cleaned (for recurrent models)
+//
+//   - lctx:      llama context
+//   - inp_batch: batch to evaluate
+//
+//   return 0 on success
+//   return positive int on warning
+//   return negative int on error
+//
+static int llama_decode_impl(
+         llama_context & lctx,
+           llama_batch   inp_batch) {
 
-            const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
-            if (!slot) {
-                return 1;
-            }
-            kv_slot_restorer.save(slot);
+    lctx.is_encoding = false;
 
-            if (!kv_self.recurrent) {
-                // a heuristic, to avoid attending the full cache if it is not yet utilized
-                // after enough generations, the benefit from this heuristic disappears
-                // if we start defragmenting the cache, the benefit from this will be more important
-                const uint32_t pad = llama_kv_cache_get_padding(cparams);
-                kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
-                //kv_self.n = llama_kv_cache_cell_max(kv_self);
+    if (inp_batch.n_tokens == 0) {
+        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
+        return -1;
+    }
+
+    // temporarily allocate memory for the input batch if needed
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);
+    const llama_batch & batch = batch_allocr.batch;
+
+    const auto & model   = lctx.model;
+    const auto & vocab   = model.vocab;
+    const auto & hparams = model.hparams;
+    const auto & cparams = lctx.cparams;
+
+    if (lctx.t_compute_start_us == 0) {
+        lctx.t_compute_start_us = ggml_time_us();
+    }
+    auto & kv_self = lctx.kv_self;
+    llama_kv_slot_restorer kv_slot_restorer(kv_self);
+
+    const int64_t n_embd  = hparams.n_embd;
+    const int64_t n_vocab = vocab.n_tokens();
+
+    uint32_t n_outputs      = 0;
+    uint32_t n_outputs_prev = 0;
+
+    {
+        const int ret = llama_prepare_sbatch(lctx, batch, n_outputs);
+        if (ret != 0) {
+            return ret;
+        }
+    }
+
+    while (lctx.sbatch.n_tokens > 0) {
+        llama_ubatch ubatch;
+        {
+            const int ret = llama_prepare_ubatch(lctx, kv_slot_restorer, ubatch, n_outputs, batch.n_tokens);
+            if (ret != 0) {
+                return ret;
             }
         }
 
+        const int n_threads = ubatch.n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
+        ggml_threadpool_t threadpool = ubatch.n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
+
+        GGML_ASSERT(n_threads > 0);
+
         //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
 
         ggml_backend_sched_reset(lctx.sched.get());
@@ -8640,7 +8677,7 @@ static int llama_decode_impl(
 
         // update the kv ring buffer
         {
-            kv_self.head += n_tokens;
+            kv_self.head += ubatch.n_tokens;
 
             // Ensure kv cache head points to a valid index.
             if (kv_self.head >= kv_self.size) {
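
Note on the return contract documented in the comment block above llama_decode_impl: 0 means success, a positive value is a warning (here, llama_prepare_ubatch returns 1 when llama_kv_cache_find_slot finds no free slot, and per the comment the cache state is restored or cleared), and a negative value is an error. The sketch below shows one way a caller of the public llama_decode entry point might act on that contract; it assumes llama_decode forwards this return value unchanged, and the helper decode_or_report is hypothetical, not part of llama.cpp.

#include <cstdio>
#include "llama.h"

// Hypothetical helper: interpret the sign of llama_decode's return value
// according to the contract above (0 success, > 0 warning, < 0 error).
static bool decode_or_report(llama_context * ctx, const llama_batch & batch) {
    const int32_t ret = llama_decode(ctx, batch);
    if (ret == 0) {
        return true; // batch was evaluated, outputs are available
    }
    if (ret > 0) {
        // warning: e.g. no KV-cache slot could be found for the batch;
        // the cache state was restored/cleared, so retrying with a smaller
        // batch or a larger context is reasonable
        fprintf(stderr, "llama_decode: warning %d, consider a smaller batch\n", ret);
        return false;
    }
    // error: invalid batch, failed allocation, aborted compute, ...
    fprintf(stderr, "llama_decode: error %d\n", ret);
    return false;
}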