6
6
#include < ctime>
7
7
#include < sstream>
8
8
#include < cstring>
9
+ #include < thread>
10
+ #include < mutex>
9
11
10
12
#if defined(_MSC_VER)
11
13
#pragma warning(disable: 4244 4267) // possible loss of data
@@ -27,6 +29,40 @@ std::vector<float> softmax(const std::vector<float>& logits) {
27
29
return probs;
28
30
}
29
31
32
// Returns log(softmax(logits)[tok]) for a single row of `n_vocab` logits,
// i.e. the log-probability assigned to token `tok`.
//
// n_vocab : number of entries readable through `logits`; must be >= 1
//           (logits[0] is read unconditionally).
// logits  : pointer to the first of `n_vocab` raw logits.
// tok     : index of the token whose log-probability is wanted; must satisfy
//           0 <= tok < n_vocab. Neither precondition is checked here.
//
// The log-probability is evaluated directly as
//     logits[tok] - max - log(sum_i exp(logits[i] - max))
// instead of exponentiating, normalising and then taking the log.
float log_softmax(int n_vocab, const float * logits, int tok) {
    // Subtracting the largest logit keeps every exponent <= 0, so expf()
    // cannot overflow regardless of the magnitude of the inputs.
    float max_logit = logits[0];
    for (int i = 1; i < n_vocab; ++i) {
        max_logit = std::max(max_logit, logits[i]);
    }

    // Accumulate in double: the sum runs over the entire vocabulary.
    double sum_exp = 0.0;
    for (int i = 0; i < n_vocab; ++i) {
        sum_exp += expf(logits[i] - max_logit);
    }

    // The expression is evaluated in double; narrow explicitly to the
    // declared float return type.
    return static_cast<float>(logits[tok] - max_logit - std::log(sum_exp));
}
39
+
40
+ void process_logits (int n_vocab, const float * logits, const int * tokens, int n_token, std::vector<std::thread>& workers,
41
+ double & nll, double & nll2) {
42
+
43
+ std::mutex mutex;
44
+ int counter = 0 ;
45
+ auto compute = [&mutex, &counter, &nll, &nll2, n_vocab, logits, tokens, n_token] () {
46
+ double local_nll = 0 , local_nll2 = 0 ;
47
+ while (true ) {
48
+ std::unique_lock<std::mutex> lock (mutex);
49
+ int i = counter++;
50
+ if (i >= n_token) {
51
+ nll += local_nll; nll2 += local_nll2;
52
+ break ;
53
+ }
54
+ lock.unlock ();
55
+ double v = -log_softmax (n_vocab, logits + i*n_vocab, tokens[i+1 ]);
56
+ local_nll += v;
57
+ local_nll2 += v*v;
58
+ }
59
+ };
60
+ for (auto & w : workers) w = std::thread (compute);
61
+ compute ();
62
+ for (auto & w : workers) w.join ();
63
+
64
+ }
65
+
30
66
void perplexity_v2 (llama_context * ctx, const gpt_params & params) {
31
67
// Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
32
68
// Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
@@ -166,9 +202,12 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
166
202
167
203
int count = 0 ;
168
204
double nll = 0.0 ;
205
+ double nll2 = 0.0 ;
169
206
170
207
fprintf (stderr, " %s: calculating perplexity over %d chunks, batch_size=%d\n " , __func__, n_chunk, n_batch);
171
208
209
+ std::vector<std::thread> workers (std::thread::hardware_concurrency () - 1 );
210
+
172
211
for (int i = 0 ; i < n_chunk; ++i) {
173
212
const int start = i * params.n_ctx ;
174
213
const int end = start + params.n_ctx ;
@@ -228,26 +267,32 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
228
267
// Example, we have a context window of 512, we will compute perplexity for each of the
229
268
// last 256 tokens. Then, we split the input up into context window size chunks to
230
269
// process the entire prompt.
231
- for (int j = std::min (512 , params.n_ctx / 2 ); j < params.n_ctx - 1 ; ++j) {
232
- // Calculate probability of next token, given the previous ones.
233
- const std::vector<float > tok_logits (
234
- logits.begin () + (j + 0 ) * n_vocab,
235
- logits.begin () + (j + 1 ) * n_vocab);
236
-
237
- const float prob = softmax (tok_logits)[tokens[start + j + 1 ]];
270
+ const int first = std::min (512 , params.n_ctx /2 );
271
+ process_logits (n_vocab, logits.data () + first*n_vocab, tokens.data () + start + first, params.n_ctx - 1 - first, workers, nll, nll2);
272
+ count += params.n_ctx - first - 1 ;
238
273
239
- nll += -std::log (prob);
240
- ++count;
241
- }
242
274
// perplexity is e^(average negative log-likelihood)
243
275
if (params.ppl_output_type == 0 ) {
244
276
printf (" [%d]%.4lf," , i + 1 , std::exp (nll / count));
245
277
} else {
246
- printf (" %8d %.4lf\n " , i*params.n_ctx , std::exp (nll / count));
278
+ double av = nll/count;
279
+ double av2 = nll2/count - av*av;
280
+ if (av2 > 0 ) av2 = sqrt (av2/(count-1 ));
281
+ printf (" %8d %.4lf %4lf %4lf\n " , i*params.n_ctx , std::exp (nll / count), av, av2);
247
282
}
248
283
fflush (stdout);
249
284
}
250
285
printf (" \n " );
286
+ nll2 /= count;
287
+ nll /= count;
288
+ nll2 -= nll * nll;
289
+ if (nll2 > 0 ) {
290
+ nll2 = sqrt (nll2/(count-1 ));
291
+ double ppl = exp (nll);
292
+ printf (" Final estimate: PPL = %.4lf +/- %.5lf\n " , ppl, nll2*ppl);
293
+ } else {
294
+ printf (" Unexpected negative standard deviation of log(prob)\n " );
295
+ }
251
296
}
252
297
253
298
std::vector<float > hellaswag_evaluate_tokens (llama_context * ctx, const std::vector<int >& tokens, int n_past, int n_batch,
0 commit comments