Skip to content

winogrande: evaluate log-probs in parallel #5036

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 19, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 32 additions & 37 deletions examples/perplexity/perplexity.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
return true;
}

static void hellaswag_compute_logprobs(const float * batch_logits, int n_vocab, std::vector<std::thread>& workers,
static void compute_logprobs(const float * batch_logits, int n_vocab, std::vector<std::thread>& workers,
const std::vector<std::pair<size_t, llama_token>>& eval_pairs, std::vector<float>& eval_results) {
constexpr int k_token_chunk = 4;
if (eval_results.size() != eval_pairs.size()) {
Expand Down Expand Up @@ -692,7 +692,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
}
}
// Then we do the actual calculation
hellaswag_compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results);
compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results);

size_t ir = 0;

Expand Down Expand Up @@ -898,6 +898,10 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
std::vector<float> tok_logits(n_vocab);
std::vector<float> batch_logits(n_vocab*n_ctx);

std::vector<std::pair<size_t, llama_token>> eval_pairs;
std::vector<float> eval_results;
std::vector<std::thread> workers(std::thread::hardware_concurrency());

int n_correct = 0;
int n_done = 0;

Expand Down Expand Up @@ -948,61 +952,52 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
return;
}

eval_pairs.clear();
for (size_t i = i0; i < i1; ++i) {
auto & task = data[i];

const bool skip_choice =
task.seq_tokens[0].size() - task.common_prefix > k_min_trailing_ctx &&
task.seq_tokens[1].size() - task.common_prefix > k_min_trailing_ctx;

float score_1st = 0;
bool is_nan_1st = false;
const auto& n_base1 = skip_choice ? task.n_base1 : task.common_prefix;
const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0;
size_t li = n_base1 - 1;
for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) {
std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(task.i_batch + li++), n_vocab*sizeof(float));
const float prob = softmax(tok_logits)[task.seq_tokens[0][j+1]];
if (std::isnan(prob) || !prob) {
fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n", __func__,
prob, j, (task.first + task.choices[0] + task.second).c_str(), n_base1);
is_nan_1st = true;
break;
}
score_1st += std::log(prob);
eval_pairs.push_back(std::make_pair(task.i_batch + li++, task.seq_tokens[0][j+1]));
}
score_1st /= (task.seq_tokens[0].size() - n_base1 - last_1st);

float score_2nd = 0;
bool is_nan_2nd = false;
const auto& n_base2 = skip_choice ? task.n_base2 : task.common_prefix;
const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0;
li = task.seq_tokens[0].size() - task.common_prefix + n_base2 - 1;
for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) {
std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(task.i_batch + li++), n_vocab*sizeof(float));
const float prob = softmax(tok_logits)[task.seq_tokens[1][j+1]];
if (std::isnan(prob) || !prob) {
fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n", __func__,
prob, j, (task.first + task.choices[1] + task.second).c_str(), n_base2);
is_nan_2nd = true;
break;
}
score_2nd += std::log(prob);
eval_pairs.push_back(std::make_pair(task.i_batch + li++, task.seq_tokens[1][j+1]));
}
score_2nd /= (task.seq_tokens[1].size() - n_base2 - last_2nd);
}
compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results);

size_t ir = 0;
for (size_t i = i0; i < i1; ++i) {
auto & task = data[i];

const bool skip_choice =
task.seq_tokens[0].size() - task.common_prefix > k_min_trailing_ctx &&
task.seq_tokens[1].size() - task.common_prefix > k_min_trailing_ctx;

if (is_nan_1st || is_nan_2nd) {
continue;
float score_1st = 0;
const auto& n_base1 = skip_choice ? task.n_base1 : task.common_prefix;
const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0;
for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) {
score_1st += eval_results[ir++];
}
score_1st /= (task.seq_tokens[0].size() - n_base1 - last_1st);

if (std::isnan(score_1st) || std::isnan(score_2nd)) {
printf("================== NaN score %g, %g) for:\n", score_1st, score_2nd);
printf("Q1: <%s> - %zu tokens\n", (task.first + task.choices[0] + task.second).c_str(), task.seq_tokens[0].size());
printf("Q2: <%s> - %zu tokens\n", (task.first + task.choices[1] + task.second).c_str(), task.seq_tokens[1].size());
printf("B : <%s> - %zu tokens\n", task.first.c_str(), task.common_prefix);
printf("base_1 has %zu tokens, base_2 has %zu tokens, skip_choice = %d\n", n_base1, n_base2, skip_choice);
continue;
float score_2nd = 0;
const auto& n_base2 = skip_choice ? task.n_base2 : task.common_prefix;
const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0;
for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) {
score_2nd += eval_results[ir++];
}
score_2nd /= (task.seq_tokens[1].size() - n_base2 - last_2nd);

int result = score_1st > score_2nd ? 1 : 2;

Expand All @@ -1011,7 +1006,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
}
++n_done;

// Print the accumulated accuracy mean x 100
// print the accumulated accuracy mean x 100
printf("%zu\t%.4lf\t%10.6f %10.6f %d %d\n", i+1, 100.0 * n_correct/n_done, score_1st, score_2nd, result, task.answer);
fflush(stdout);
}
Expand Down