@@ -27,7 +27,121 @@ std::vector<float> softmax(const std::vector<float>& logits) {
     return probs;
 }
 
+void perplexity_v2(llama_context * ctx, const gpt_params & params) {
+
+    // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
+    // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
+    // Output: `perplexity: 13.5106 [114/114]`
+    // BOS tokens will be added for each chunk before eval
+
+    if (params.ppl_stride <= 0) {
+        fprintf(stderr, "%s: stride is %d but must be greater than zero!\n", __func__, params.ppl_stride);
+        return;
+    }
+    auto tokens = ::llama_tokenize(ctx, params.prompt, true);
+
+    const int calc_chunk = params.n_ctx;
+
+    fprintf(stderr, "%s: have %zu tokens. Calculation chunk = %d\n", __func__, tokens.size(), calc_chunk);
+
+    if (int(tokens.size()) <= calc_chunk) {
+        fprintf(stderr, "%s: there are only %zu tokens, this is not enough for a context size of %d and stride %d\n", __func__,
+                tokens.size(), params.n_ctx, params.ppl_stride);
+        return;
+    }
+
+    const int n_chunk_max = (tokens.size() - calc_chunk + params.ppl_stride - 1) / params.ppl_stride;
+
+    const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
+    const int n_vocab = llama_n_vocab(ctx);
+    const int n_batch = params.n_batch;
+
+    int count = 0;
+    double nll = 0.0;
+
+    fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
+
+    for (int i = 0; i < n_chunk; ++i) {
+        const int start = i * params.ppl_stride;
+        const int end   = start + calc_chunk;
+
+        const int num_batches = (calc_chunk + n_batch - 1) / n_batch;
+        //fprintf(stderr, "%s: evaluating %d...%d using %d batches\n", __func__, start, end, num_batches);
+
+        std::vector<float> logits;
+
+        const auto t_start = std::chrono::high_resolution_clock::now();
+
+        for (int j = 0; j < num_batches; ++j) {
+            const int batch_start = start + j * n_batch;
+            const int batch_size  = std::min(end - batch_start, n_batch);
+
+            //fprintf(stderr, " Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
+            if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) {
+                //fprintf(stderr, "%s : failed to eval\n", __func__);
+                return;
+            }
+
+            // save original token and restore it after eval
+            const auto token_org = tokens[batch_start];
+
+            // add BOS token for the first batch of each chunk
+            if (j == 0) {
+                tokens[batch_start] = llama_token_bos(ctx);
+            }
+
+            const auto batch_logits = llama_get_logits(ctx);
+            logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+
+            if (j == 0) {
+                tokens[batch_start] = token_org;
+            }
+        }
+
+        const auto t_end = std::chrono::high_resolution_clock::now();
+
+        if (i == 0) {
+            const float t_total = std::chrono::duration<float>(t_end - t_start).count();
+            fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
+            int total_seconds = (int)(t_total * n_chunk);
+            if (total_seconds >= 60*60) {
+                fprintf(stderr, "%d hours ", total_seconds / (60*60));
+                total_seconds = total_seconds % (60*60);
+            }
+            fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
+        }
+
+        //fprintf(stderr, "%s: using tokens %d...%d\n",__func__,params.n_ctx - params.ppl_stride + start, params.n_ctx + start);
+        for (int j = params.n_ctx - params.ppl_stride - 1; j < params.n_ctx - 1; ++j) {
+
+            // Calculate probability of next token, given the previous ones.
+            const std::vector<float> tok_logits(
+                logits.begin() + (j + 0) * n_vocab,
+                logits.begin() + (j + 1) * n_vocab);
+
+            const float prob = softmax(tok_logits)[tokens[start + j + 1]];
+
+            nll += -std::log(prob);
+            ++count;
+        }
+        // perplexity is e^(average negative log-likelihood)
+        if (params.ppl_output_type == 0) {
+            printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
+        } else {
+            printf("%8d %.4lf\n", i*params.ppl_stride, std::exp(nll / count));
+        }
+        fflush(stdout);
+    }
+    printf("\n");
+}
+
 void perplexity(llama_context * ctx, const gpt_params & params) {
+
+    if (params.ppl_stride > 0) {
+        perplexity_v2(ctx, params);
+        return;
+    }
+
     // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
     // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
     // Output: `perplexity: 13.5106 [114/114]`
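For orientation, the scoring loop in `perplexity_v2` above reduces to exactly what the in-code comment states: perplexity is e raised to the average negative log-likelihood of the scored tokens. Below is a minimal standalone sketch of that reduction; the helper name, the example probabilities, and the `main` driver are illustrative only and are not part of this patch.

#include <cmath>
#include <cstdio>
#include <vector>

// perplexity = exp( (1/N) * sum_i -log p(token_i | preceding tokens) )
static double perplexity_from_probs(const std::vector<double> & probs) {
    double nll = 0.0;
    for (const double p : probs) {
        nll += -std::log(p);              // accumulate negative log-likelihood per token
    }
    return std::exp(nll / probs.size());  // e^(average negative log-likelihood)
}

int main() {
    // Hypothetical probabilities a model assigned to the correct next tokens.
    const std::vector<double> probs = {0.25, 0.10, 0.50, 0.05};
    printf("perplexity = %.4f\n", perplexity_from_probs(probs));
    return 0;
}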
@@ -116,7 +230,11 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
             ++count;
         }
         // perplexity is e^(average negative log-likelihood)
-        printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
+        if (params.ppl_output_type == 0) {
+            printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
+        } else {
+            printf("%8d %.4lf\n", i*params.n_ctx, std::exp(nll / count));
+        }
         fflush(stdout);
     }
     printf("\n");
@@ -369,6 +487,12 @@ int main(int argc, char ** argv) {
     params.perplexity = true;
     params.n_batch = std::min(params.n_batch, params.n_ctx);
 
+    if (params.ppl_stride > 0) {
+        fprintf(stderr, "Will perform strided perplexity calculation -> adjusting context size from %d to %d\n",
+                params.n_ctx, params.n_ctx + params.ppl_stride/2);
+        params.n_ctx += params.ppl_stride/2;
+    }
+
     if (params.n_ctx > 2048) {
         fprintf(stderr, "%s: warning: model might not support context sizes greater than 2048 tokens (%d specified);"
                 " expect poor results\n", __func__, params.n_ctx);