
Commit 62959e7

ikawrakow and Kawrakow authored
Strided perplexity (#2714)
* Implementing strided computation of perplexity

* Alternative way to output PPL results

---------

Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 7f7ddd5 commit 62959e7

File tree

3 files changed: +141 -1 lines changed

  common/common.cpp
  common/common.h
  examples/perplexity/perplexity.cpp


common/common.cpp

Lines changed: 12 additions & 0 deletions
@@ -417,6 +417,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.antiprompt.push_back(argv[i]);
         } else if (arg == "--perplexity") {
             params.perplexity = true;
+        } else if (arg == "--ppl-stride") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.ppl_stride = std::stoi(argv[i]);
+        } else if (arg == "--ppl-output-type") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.ppl_output_type = std::stoi(argv[i]);
         } else if (arg == "--hellaswag") {
             params.hellaswag = true;
         } else if (arg == "--hellaswag-tasks") {
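The two new flags feed the `ppl_stride` and `ppl_output_type` fields added to `gpt_params` below. A possible invocation, reusing the model and dataset from the comment in perplexity.cpp (the stride value here is only illustrative): `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw --ppl-stride 32 --ppl-output-type 1`. Leaving `--ppl-stride` at its default of 0 keeps the pre-existing calculation; `--ppl-output-type 1` prints one `num_tokens ppl` pair per line, which is more convenient for plotting.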

common/common.h

Lines changed: 4 additions & 0 deletions
@@ -64,6 +64,10 @@ struct gpt_params {
     std::string lora_adapter = "";  // lora adapter path
     std::string lora_base    = "";  // base model path for the lora adapter
 
+    int ppl_stride      = 0;        // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
+    int ppl_output_type = 0;        // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line
+                                    //        (which is more convenient to use for plotting)
+    //
     bool   hellaswag       = false; // compute HellaSwag score over random tasks from datafile supplied in prompt
     size_t hellaswag_tasks = 400;   // number of tasks to use when computing the HellaSwag score

examples/perplexity/perplexity.cpp

Lines changed: 125 additions & 1 deletion
@@ -27,7 +27,121 @@ std::vector<float> softmax(const std::vector<float>& logits) {
     return probs;
 }
 
+void perplexity_v2(llama_context * ctx, const gpt_params & params) {
+
+    // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
+    // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
+    // Output: `perplexity: 13.5106 [114/114]`
+    // BOS tokens will be added for each chunk before eval
+
+    if (params.ppl_stride <= 0) {
+        fprintf(stderr, "%s: stride is %d but must be greater than zero!\n", __func__, params.ppl_stride);
+        return;
+    }
+    auto tokens = ::llama_tokenize(ctx, params.prompt, true);
+
+    const int calc_chunk = params.n_ctx;
+
+    fprintf(stderr, "%s: have %zu tokens. Calculation chunk = %d\n", __func__, tokens.size(), calc_chunk);
+
+    if (int(tokens.size()) <= calc_chunk) {
+        fprintf(stderr, "%s: there are only %zu tokens, this is not enough for a context size of %d and stride %d\n", __func__,
+                tokens.size(), params.n_ctx, params.ppl_stride);
+        return;
+    }
+
+    const int n_chunk_max = (tokens.size() - calc_chunk + params.ppl_stride - 1) / params.ppl_stride;
+
+    const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
+    const int n_vocab = llama_n_vocab(ctx);
+    const int n_batch = params.n_batch;
+
+    int count = 0;
+    double nll = 0.0;
+
+    fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
+
+    for (int i = 0; i < n_chunk; ++i) {
+        const int start = i * params.ppl_stride;
+        const int end   = start + calc_chunk;
+
+        const int num_batches = (calc_chunk + n_batch - 1) / n_batch;
+        //fprintf(stderr, "%s: evaluating %d...%d using %d batches\n", __func__, start, end, num_batches);
+
+        std::vector<float> logits;
+
+        const auto t_start = std::chrono::high_resolution_clock::now();
+
+        for (int j = 0; j < num_batches; ++j) {
+            const int batch_start = start + j * n_batch;
+            const int batch_size  = std::min(end - batch_start, n_batch);
+
+            //fprintf(stderr, "    Batch %d: starts at %d, size is %d, n_past is %d\n", j, batch_start, batch_size, j * n_batch);
+            if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) {
+                //fprintf(stderr, "%s : failed to eval\n", __func__);
+                return;
+            }
+
+            // save original token and restore it after eval
+            const auto token_org = tokens[batch_start];
+
+            // add BOS token for the first batch of each chunk
+            if (j == 0) {
+                tokens[batch_start] = llama_token_bos(ctx);
+            }
+
+            const auto batch_logits = llama_get_logits(ctx);
+            logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+
+            if (j == 0) {
+                tokens[batch_start] = token_org;
+            }
+        }
+
+        const auto t_end = std::chrono::high_resolution_clock::now();
+
+        if (i == 0) {
+            const float t_total = std::chrono::duration<float>(t_end - t_start).count();
+            fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
+            int total_seconds = (int)(t_total * n_chunk);
+            if (total_seconds >= 60*60) {
+                fprintf(stderr, "%d hours ", total_seconds / (60*60));
+                total_seconds = total_seconds % (60*60);
+            }
+            fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
+        }
+
+        //fprintf(stderr, "%s: using tokens %d...%d\n", __func__, params.n_ctx - params.ppl_stride + start, params.n_ctx + start);
+        for (int j = params.n_ctx - params.ppl_stride - 1; j < params.n_ctx - 1; ++j) {
+
+            // Calculate probability of next token, given the previous ones.
+            const std::vector<float> tok_logits(
+                logits.begin() + (j + 0) * n_vocab,
+                logits.begin() + (j + 1) * n_vocab);
+
+            const float prob = softmax(tok_logits)[tokens[start + j + 1]];
+
+            nll += -std::log(prob);
+            ++count;
+        }
+        // perplexity is e^(average negative log-likelihood)
+        if (params.ppl_output_type == 0) {
+            printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
+        } else {
+            printf("%8d %.4lf\n", i*params.ppl_stride, std::exp(nll / count));
+        }
+        fflush(stdout);
+    }
+    printf("\n");
+}
+
 void perplexity(llama_context * ctx, const gpt_params & params) {
+
+    if (params.ppl_stride > 0) {
+        perplexity_v2(ctx, params);
+        return;
+    }
+
     // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
     // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
     // Output: `perplexity: 13.5106 [114/114]`
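The hunk above adds perplexity_v2: each chunk evaluates a full window of n_ctx tokens, the window advances by ppl_stride tokens per chunk, and only the predictions for the last ppl_stride tokens of a window contribute to the running negative log-likelihood. A minimal standalone sketch of that indexing follows; the concrete values and the driver program are hypothetical, only the arithmetic mirrors the code above.

    // Hypothetical illustration of the chunk/stride indexing in perplexity_v2.
    // The concrete numbers are made up; the formulas follow the commit.
    #include <cstdio>

    int main() {
        const int n_ctx      = 528;   // e.g. a requested 512 after main() adds ppl_stride/2 = 16
        const int ppl_stride = 32;    // hypothetical --ppl-stride value
        const int n_tokens   = 2048;  // hypothetical tokenized prompt length

        const int calc_chunk  = n_ctx;
        const int n_chunk_max = (n_tokens - calc_chunk + ppl_stride - 1) / ppl_stride;

        for (int i = 0; i < n_chunk_max; ++i) {
            const int start = i * ppl_stride;      // first token fed to the model for this chunk
            const int end   = start + calc_chunk;  // one past the last token fed
            // logits at window positions n_ctx-ppl_stride-1 .. n_ctx-2 are scored,
            // i.e. they predict the last ppl_stride tokens of the window:
            const int first_scored = start + n_ctx - ppl_stride;
            const int last_scored  = start + n_ctx - 1;
            printf("chunk %3d: eval tokens [%4d, %4d), score tokens [%4d, %4d]\n",
                   i, start, end, first_scored, last_scored);
        }
        return 0;
    }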
@@ -116,7 +230,11 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
             ++count;
         }
         // perplexity is e^(average negative log-likelihood)
-        printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
+        if (params.ppl_output_type == 0) {
+            printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
+        } else {
+            printf("%8d %.4lf\n", i*params.n_ctx, std::exp(nll / count));
+        }
         fflush(stdout);
     }
     printf("\n");
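Both output branches print the same quantity described by the comment above, the exponential of the average negative log-likelihood over the N = count tokens scored so far:

    \mathrm{PPL} = \exp\Bigl(-\tfrac{1}{N} \sum_{i=1}^{N} \log p(t_i \mid t_{<i})\Bigr)

The only difference is the label: output type 0 prints the chunk index, output type 1 prints the number of tokens processed (i*n_ctx here, i*ppl_stride in the strided path).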
@@ -369,6 +487,12 @@ int main(int argc, char ** argv) {
     params.perplexity = true;
     params.n_batch = std::min(params.n_batch, params.n_ctx);
 
+    if (params.ppl_stride > 0) {
+        fprintf(stderr, "Will perform strided perplexity calculation -> adjusting context size from %d to %d\n",
+                params.n_ctx, params.n_ctx + params.ppl_stride/2);
+        params.n_ctx += params.ppl_stride/2;
+    }
+
     if (params.n_ctx > 2048) {
         fprintf(stderr, "%s: warning: model might not support context sizes greater than 2048 tokens (%d specified);"
                 "expect poor results\n", __func__, params.n_ctx);
