@@ -527,7 +527,8 @@ bool llama_eval(
         const int n_past,
         const std::vector<gpt_vocab::id> & embd_inp,
               std::vector<float>         & embd_w,
-              size_t                     & mem_per_token) {
+              size_t                     & mem_per_token,
+              bool                         return_all_logits = false) {
     const int N = embd_inp.size();
 
     const auto & hparams = model.hparams;
@@ -733,9 +734,14 @@ bool llama_eval(
     // embd_w.resize(n_vocab*N);
     // memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
 
-    // return result for just the last token
-    embd_w.resize(n_vocab);
-    memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+    if (return_all_logits) {
+        embd_w.resize(n_vocab * N);
+        memcpy(embd_w.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N);
+    } else {
+        // return result for just the last token
+        embd_w.resize(n_vocab);
+        memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+    }
 
     if (mem_per_token == 0) {
         mem_per_token = ggml_used_mem(ctx0)/N;
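
Note: with return_all_logits set, embd_w holds the logits for every input position, laid out as N contiguous rows of n_vocab floats, where row j is the model's prediction after seeing token j. A minimal caller-side sketch (hypothetical helper, not part of this commit) that indexes into that flat buffer:

    // Hypothetical helper: assumes embd_w was filled by llama_eval(..., /*return_all_logits=*/true)
    // for N input tokens, so the buffer is laid out as [N][n_vocab] in row-major order.
    static const float * logits_at(const std::vector<float> & embd_w, int n_vocab, int j) {
        // pointer to the n_vocab logits produced at position j
        return embd_w.data() + (size_t) j * n_vocab;
    }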
@@ -769,22 +775,42 @@ void perplexity(const gpt_vocab &vocab, const llama_model &model, const gpt_para
     // Output: `perplexity: 13.5106 [114/114]`
     std::vector<gpt_vocab::id> tokens = ::llama_tokenize(vocab, params.prompt, true);
 
+    int count = 0;
     double nll = 0.0;
     int seq_count = tokens.size() / params.n_ctx;
     for (int i = 0; i < seq_count; ++i) {
         int start = i * params.n_ctx;
         int end = start + params.n_ctx - 1;
         std::vector<gpt_vocab::id> embd(tokens.begin() + start, tokens.begin() + end);
         std::vector<float> logits;
-        if (!llama_eval(model, params.n_threads, 0, embd, logits, mem_per_token)) {
+        if (!llama_eval(model, params.n_threads, 0, embd, logits, mem_per_token, true)) {
             fprintf(stderr, "Failed to predict\n");
             return;
         }
-        // Calculate probability of next token, given the previous ones.
-        double prob = softmax(logits)[tokens[end]];
-        nll += -std::log(prob);
+        // We get the logits for all the tokens in the context window (params.n_ctx)
+        // from llama_eval above. Now, based on https://huggingface.co/docs/transformers/perplexity,
+        // calculate the perplexity over the last half of the window (so the model always has
+        // some context to predict the token).
+        //
+        // We rely on the fact that attention in the forward pass only looks at previous
+        // tokens here, so the logits returned for each token are an accurate representation
+        // of what the model would have predicted at that point.
+        //
+        // Example: with a context window of 512, we compute perplexity for each of the
+        // last 256 tokens. Then we split the input up into context-window-sized chunks to
+        // process the entire prompt.
+        for (int j = params.n_ctx / 2; j < params.n_ctx - 1; ++j) {
+            // Calculate probability of next token, given the previous ones.
+            int n_vocab = model.hparams.n_vocab;
+            std::vector<float> tok_logits(
+                logits.begin() + j * n_vocab,
+                logits.begin() + (j + 1) * n_vocab);
+            double prob = softmax(tok_logits)[tokens[start + j + 1]];
+            nll += -std::log(prob);
+            ++count;
+        }
         // perplexity is e^(average negative log-likelihood)
-        printf("perplexity: %.4lf [%d/%d]\r", std::exp(nll / (i + 1)), i + 1, seq_count);
+        printf("perplexity: %.4lf [%d/%d]\r", std::exp(nll / count), i + 1, seq_count);
         fflush(stdout);
     }
     printf("\n");
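
Side note: the value printed above follows the standard definition of perplexity, the exponential of the average negative log-likelihood over the scored tokens. A standalone sketch of the same arithmetic (hypothetical helper, not part of this commit):

    #include <cmath>
    #include <vector>

    // perplexity = exp( (1/count) * sum_i -log p(token_i | previous tokens) )
    static double perplexity_from_probs(const std::vector<double> & probs) {
        double nll = 0.0;
        for (double p : probs) {
            nll += -std::log(p);            // accumulate negative log-likelihood
        }
        return std::exp(nll / probs.size()); // average it, then exponentiate
    }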