Commit 91d71fe

More accurate perplexity calculation - over all logits in the context window (so 512x more tokens!)
1 parent e94bd9c commit 91d71fe

File tree: 1 file changed, 35 additions and 9 deletions


main.cpp

Lines changed: 35 additions & 9 deletions
@@ -527,7 +527,8 @@ bool llama_eval(
         const int n_past,
         const std::vector<gpt_vocab::id> & embd_inp,
         std::vector<float> & embd_w,
-        size_t & mem_per_token) {
+        size_t & mem_per_token,
+        bool return_all_logits = false) {
     const int N = embd_inp.size();
 
     const auto & hparams = model.hparams;
@@ -733,9 +734,14 @@ bool llama_eval(
     //embd_w.resize(n_vocab*N);
     //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
 
-    // return result for just the last token
-    embd_w.resize(n_vocab);
-    memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+    if (return_all_logits) {
+        embd_w.resize(n_vocab * N);
+        memcpy(embd_w.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N);
+    } else {
+        // return result for just the last token
+        embd_w.resize(n_vocab);
+        memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+    }
 
     if (mem_per_token == 0) {
         mem_per_token = ggml_used_mem(ctx0)/N;
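
With the change above, a caller that passes return_all_logits = true gets back a flat buffer of N * n_vocab floats in embd_w: N rows of n_vocab scores, where row j is the model's prediction for the token that follows position j. The standalone sketch below (not part of the commit; sizes and values are invented for illustration) shows how one row is sliced out of such a buffer, which is what the perplexity loop in the next hunk does with tok_logits.

// Standalone sketch, not part of the commit: slicing per-position logits out of
// a flat [N x n_vocab] buffer like the one llama_eval now fills when
// return_all_logits is true. Sizes and values here are made up.
#include <cstdio>
#include <vector>

int main() {
    const int N       = 4;  // positions in the batch (assumed)
    const int n_vocab = 8;  // tiny vocabulary, for the example only

    // logits[j * n_vocab + k] = score of vocab entry k as the successor of position j
    std::vector<float> logits(N * n_vocab, 0.0f);
    logits[2 * n_vocab + 5] = 1.0f;  // position 2 favours token 5

    const int j = 2;
    std::vector<float> tok_logits(logits.begin() + j * n_vocab,
                                  logits.begin() + (j + 1) * n_vocab);
    printf("score of token 5 after position %d: %.1f\n", j, tok_logits[5]);
    return 0;
}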
@@ -769,22 +775,42 @@ void perplexity(const gpt_vocab &vocab, const llama_model &model, const gpt_para
     // Output: `perplexity: 13.5106 [114/114]`
     std::vector<gpt_vocab::id> tokens = ::llama_tokenize(vocab, params.prompt, true);
 
+    int count = 0;
     double nll = 0.0;
     int seq_count = tokens.size() / params.n_ctx;
     for (int i = 0; i < seq_count; ++i) {
         int start = i * params.n_ctx;
         int end = start + params.n_ctx - 1;
         std::vector<gpt_vocab::id> embd(tokens.begin() + start, tokens.begin() + end);
         std::vector<float> logits;
-        if (!llama_eval(model, params.n_threads, 0, embd, logits, mem_per_token)) {
+        if (!llama_eval(model, params.n_threads, 0, embd, logits, mem_per_token, true)) {
             fprintf(stderr, "Failed to predict\n");
             return;
         }
-        // Calculate probability of next token, given the previous ones.
-        double prob = softmax(logits)[tokens[end]];
-        nll += -std::log(prob);
+        // We get the logits for all the tokens in the context window (params.n_ctx)
+        // from llama_eval above. Now, based on https://huggingface.co/docs/transformers/perplexity,
+        // calculate the perplexity over the last half of the window (so the model always has
+        // some context to predict the token).
+        //
+        // We rely on the fact that attention in the forward pass only looks at previous
+        // tokens here, so the logits returned for each token are an accurate representation
+        // of what the model would have predicted at that point.
+        //
+        // Example: with a context window of 512, we compute perplexity for each of the
+        // last 256 tokens. Then we split the input up into context-window-sized chunks to
+        // process the entire prompt.
+        for (int j = params.n_ctx / 2; j < params.n_ctx - 1; ++j) {
+            // Calculate probability of the next token, given the previous ones.
+            int n_vocab = model.hparams.n_vocab;
+            std::vector<float> tok_logits(
+                logits.begin() + j * n_vocab,
+                logits.begin() + (j + 1) * n_vocab);
+            double prob = softmax(tok_logits)[tokens[start + j + 1]];
+            nll += -std::log(prob);
+            ++count;
+        }
         // perplexity is e^(average negative log-likelihood)
-        printf("perplexity: %.4lf [%d/%d] \r", std::exp(nll / (i + 1)), i + 1, seq_count);
+        printf("perplexity: %.4lf [%d/%d] \r", std::exp(nll / count), i + 1, seq_count);
         fflush(stdout);
     }
     printf("\n");
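
For reference, the quantity the new loop accumulates matches the definition in the linked Hugging Face article: perplexity is the exponential of the average negative log-likelihood of each scored token given its preceding context. A self-contained sketch of just that arithmetic follows, with made-up probabilities standing in for softmax(tok_logits)[tokens[start + j + 1]]; it is illustrative only, not the repo's code.

// Standalone sketch: perplexity as exp of the mean negative log-likelihood.
// The probabilities are invented; in main.cpp each one comes from softmax()
// applied to the logits slice for a single position.
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    // p[t] is the model's probability for the token that actually came next.
    std::vector<double> p = {0.25, 0.10, 0.50, 0.05};

    double nll = 0.0;
    for (double prob : p) {
        nll += -std::log(prob);
    }
    double ppl = std::exp(nll / p.size());
    printf("perplexity over %zu tokens: %.4f\n", p.size(), ppl);
    return 0;
}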
