@@ -458,23 +458,24 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
458
458
return true ;
459
459
}
460
460
461
+ #define K_TOKEN_CHUNK 4
462
+
461
463
static void compute_logprobs (const float * batch_logits, int n_vocab, std::vector<std::thread>& workers,
462
464
const std::vector<std::pair<size_t , llama_token>>& eval_pairs, std::vector<float >& eval_results) {
463
- constexpr int k_token_chunk = 4 ;
464
465
if (eval_results.size () != eval_pairs.size ()) {
465
466
eval_results.resize (eval_pairs.size ());
466
467
}
467
468
if (eval_pairs.empty ()) return ;
468
469
469
- size_t max_threads = std::min ((eval_pairs.size () + k_token_chunk - 1 )/k_token_chunk , workers.size ());
470
+ size_t max_threads = std::min ((eval_pairs.size () + K_TOKEN_CHUNK - 1 )/K_TOKEN_CHUNK , workers.size ());
470
471
471
472
std::atomic<int > counter (0 );
472
473
auto compute = [&counter, &eval_pairs, &eval_results, batch_logits, n_vocab] () {
473
- float local_logprobs[k_token_chunk ];
474
+ float local_logprobs[K_TOKEN_CHUNK ];
474
475
while (true ) {
475
- size_t first = counter.fetch_add (k_token_chunk , std::memory_order_relaxed);
476
+ size_t first = counter.fetch_add (K_TOKEN_CHUNK , std::memory_order_relaxed);
476
477
if (first >= eval_results.size ()) break ;
477
- size_t last = std::min (first + k_token_chunk , eval_results.size ());
478
+ size_t last = std::min (first + K_TOKEN_CHUNK , eval_results.size ());
478
479
for (size_t i = first; i < last; ++i) {
479
480
auto logits = batch_logits + eval_pairs[i].first * n_vocab;
480
481
float max_logit = logits[0 ];
@@ -497,7 +498,6 @@ static void compute_logprobs(const float * batch_logits, int n_vocab, std::vecto
497
498
for (size_t it = 0 ; it < max_threads; ++it) {
498
499
workers[it].join ();
499
500
}
500
-
501
501
}
502
502
503
503
static void hellaswag_score (llama_context * ctx, const gpt_params & params) {
0 commit comments