@@ -17138,10 +17138,10 @@ static void llama_graph_compute(
17138
17138
//
17139
17139
static int llama_decode_internal(
17140
17140
llama_context & lctx,
17141
- llama_batch batch_all ) { // TODO: rename back to batch
17141
+ llama_batch batch ) {
17142
17142
17143
17143
lctx.is_encoding = false;
17144
- const uint32_t n_tokens_all = batch_all .n_tokens;
17144
+ const uint32_t n_tokens_all = batch .n_tokens;
17145
17145
17146
17146
if (n_tokens_all == 0) {
17147
17147
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
@@ -17152,12 +17152,12 @@ static int llama_decode_internal(
17152
17152
const auto & hparams = model.hparams;
17153
17153
const auto & cparams = lctx.cparams;
17154
17154
17155
- GGML_ASSERT((!batch_all .token && batch_all .embd) || (batch_all .token && !batch_all .embd)); // NOLINT
17155
+ GGML_ASSERT((!batch .token && batch .embd) || (batch .token && !batch .embd)); // NOLINT
17156
17156
17157
- if (batch_all .token) {
17157
+ if (batch .token) {
17158
17158
for (uint32_t i = 0; i < n_tokens_all; ++i) {
17159
- if (batch_all .token[i] < 0 || (uint32_t)batch_all .token[i] >= model.vocab.n_vocab) {
17160
- LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch_all .token[i]);
17159
+ if (batch .token[i] < 0 || (uint32_t)batch .token[i] >= model.vocab.n_vocab) {
17160
+ LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch .token[i]);
17161
17161
return -1;
17162
17162
}
17163
17163
}
@@ -17188,9 +17188,9 @@ static int llama_decode_internal(
17188
17188
lctx.embd_seq.clear();
17189
17189
17190
17190
// count outputs
17191
- if (batch_all .logits && !embd_pooled) {
17191
+ if (batch .logits && !embd_pooled) {
17192
17192
for (uint32_t i = 0; i < n_tokens_all; ++i) {
17193
- n_outputs += batch_all .logits[i] != 0;
17193
+ n_outputs += batch .logits[i] != 0;
17194
17194
}
17195
17195
} else if (lctx.logits_all || embd_pooled) {
17196
17196
n_outputs = n_tokens_all;
@@ -17199,7 +17199,7 @@ static int llama_decode_internal(
17199
17199
n_outputs = 1;
17200
17200
}
17201
17201
17202
- lctx.sbatch.from_batch(batch_all , n_embd,
17202
+ lctx.sbatch.from_batch(batch , n_embd,
17203
17203
/* simple_split */ !kv_self.recurrent,
17204
17204
/* logits_all */ n_outputs == n_tokens_all);
17205
17205
0 commit comments