Skip to content

Commit 7051aac

Browse files
ikawrakowKawrakow
andauthored
winogrande: evaluate log-probs in parallel (#5036)
This is a relatively minor performance tweak resulting in ~10% speedup on my system. Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 2b3b999 commit 7051aac

File tree

1 file changed

+32
-37
lines changed

1 file changed

+32
-37
lines changed

examples/perplexity/perplexity.cpp

Lines changed: 32 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,7 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
458458
return true;
459459
}
460460

461-
static void hellaswag_compute_logprobs(const float * batch_logits, int n_vocab, std::vector<std::thread>& workers,
461+
static void compute_logprobs(const float * batch_logits, int n_vocab, std::vector<std::thread>& workers,
462462
const std::vector<std::pair<size_t, llama_token>>& eval_pairs, std::vector<float>& eval_results) {
463463
constexpr int k_token_chunk = 4;
464464
if (eval_results.size() != eval_pairs.size()) {
@@ -700,7 +700,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
700700
}
701701
}
702702
// Then we do the actual calculation
703-
hellaswag_compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results);
703+
compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results);
704704

705705
size_t ir = 0;
706706

@@ -906,6 +906,10 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
906906
std::vector<float> tok_logits(n_vocab);
907907
std::vector<float> batch_logits(n_vocab*n_ctx);
908908

909+
std::vector<std::pair<size_t, llama_token>> eval_pairs;
910+
std::vector<float> eval_results;
911+
std::vector<std::thread> workers(std::thread::hardware_concurrency());
912+
909913
int n_correct = 0;
910914
int n_done = 0;
911915

@@ -956,61 +960,52 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
956960
return;
957961
}
958962

963+
eval_pairs.clear();
959964
for (size_t i = i0; i < i1; ++i) {
960965
auto & task = data[i];
961966

962967
const bool skip_choice =
963968
task.seq_tokens[0].size() - task.common_prefix > k_min_trailing_ctx &&
964969
task.seq_tokens[1].size() - task.common_prefix > k_min_trailing_ctx;
965970

966-
float score_1st = 0;
967-
bool is_nan_1st = false;
968971
const auto& n_base1 = skip_choice ? task.n_base1 : task.common_prefix;
969972
const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0;
970973
size_t li = n_base1 - 1;
971974
for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) {
972-
std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(task.i_batch + li++), n_vocab*sizeof(float));
973-
const float prob = softmax(tok_logits)[task.seq_tokens[0][j+1]];
974-
if (std::isnan(prob) || !prob) {
975-
fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n", __func__,
976-
prob, j, (task.first + task.choices[0] + task.second).c_str(), n_base1);
977-
is_nan_1st = true;
978-
break;
979-
}
980-
score_1st += std::log(prob);
975+
eval_pairs.push_back(std::make_pair(task.i_batch + li++, task.seq_tokens[0][j+1]));
981976
}
982-
score_1st /= (task.seq_tokens[0].size() - n_base1 - last_1st);
983-
984-
float score_2nd = 0;
985-
bool is_nan_2nd = false;
986977
const auto& n_base2 = skip_choice ? task.n_base2 : task.common_prefix;
987978
const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0;
988979
li = task.seq_tokens[0].size() - task.common_prefix + n_base2 - 1;
989980
for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) {
990-
std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(task.i_batch + li++), n_vocab*sizeof(float));
991-
const float prob = softmax(tok_logits)[task.seq_tokens[1][j+1]];
992-
if (std::isnan(prob) || !prob) {
993-
fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n", __func__,
994-
prob, j, (task.first + task.choices[1] + task.second).c_str(), n_base2);
995-
is_nan_2nd = true;
996-
break;
997-
}
998-
score_2nd += std::log(prob);
981+
eval_pairs.push_back(std::make_pair(task.i_batch + li++, task.seq_tokens[1][j+1]));
999982
}
1000-
score_2nd /= (task.seq_tokens[1].size() - n_base2 - last_2nd);
983+
}
984+
compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results);
985+
986+
size_t ir = 0;
987+
for (size_t i = i0; i < i1; ++i) {
988+
auto & task = data[i];
989+
990+
const bool skip_choice =
991+
task.seq_tokens[0].size() - task.common_prefix > k_min_trailing_ctx &&
992+
task.seq_tokens[1].size() - task.common_prefix > k_min_trailing_ctx;
1001993

1002-
if (is_nan_1st || is_nan_2nd) {
1003-
continue;
994+
float score_1st = 0;
995+
const auto& n_base1 = skip_choice ? task.n_base1 : task.common_prefix;
996+
const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0;
997+
for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) {
998+
score_1st += eval_results[ir++];
1004999
}
1000+
score_1st /= (task.seq_tokens[0].size() - n_base1 - last_1st);
10051001

1006-
if (std::isnan(score_1st) || std::isnan(score_2nd)) {
1007-
printf("================== NaN score %g, %g) for:\n", score_1st, score_2nd);
1008-
printf("Q1: <%s> - %zu tokens\n", (task.first + task.choices[0] + task.second).c_str(), task.seq_tokens[0].size());
1009-
printf("Q2: <%s> - %zu tokens\n", (task.first + task.choices[1] + task.second).c_str(), task.seq_tokens[1].size());
1010-
printf("B : <%s> - %zu tokens\n", task.first.c_str(), task.common_prefix);
1011-
printf("base_1 has %zu tokens, base_2 has %zu tokens, skip_choice = %d\n", n_base1, n_base2, skip_choice);
1012-
continue;
1002+
float score_2nd = 0;
1003+
const auto& n_base2 = skip_choice ? task.n_base2 : task.common_prefix;
1004+
const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0;
1005+
for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) {
1006+
score_2nd += eval_results[ir++];
10131007
}
1008+
score_2nd /= (task.seq_tokens[1].size() - n_base2 - last_2nd);
10141009

10151010
int result = score_1st > score_2nd ? 1 : 2;
10161011

@@ -1019,7 +1014,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
10191014
}
10201015
++n_done;
10211016

1022-
// Print the accumulated accuracy mean x 100
1017+
// print the accumulated accuracy mean x 100
10231018
printf("%zu\t%.4lf\t%10.6f %10.6f %d %d\n", i+1, 100.0 * n_correct/n_done, score_1st, score_2nd, result, task.answer);
10241019
fflush(stdout);
10251020
}

0 commit comments

Comments
 (0)