
Commit e94bd9c

Compute perplexity over prompt

1 parent d3f202d

File tree

3 files changed: +67 lines, -15 lines

main.cpp

Lines changed: 56 additions & 8 deletions
@@ -547,7 +547,7 @@ bool llama_eval(
     static void * buf = malloc(buf_size);

     if (mem_per_token > 0 && mem_per_token*N > buf_size) {
-        const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
+        const size_t buf_size_new = 1.3*(mem_per_token*N); // add 30% to account for ggml object overhead
         //fprintf(stderr, "\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);

         // reallocate
@@ -747,6 +747,49 @@ bool llama_eval(
     return true;
 }

+std::vector<double> softmax(const std::vector<float>& logits) {
+    std::vector<double> probs(logits.size());
+    float max_logit = logits[0];
+    for (float v : logits) max_logit = std::max(max_logit, v);
+    double sum_exp = 0.0;
+    for (size_t i = 0; i < logits.size(); i++) {
+        // Subtract the maximum logit value from the current logit value for numerical stability
+        float logit = logits[i] - max_logit;
+        double exp_logit = std::exp(logit);
+        sum_exp += exp_logit;
+        probs[i] = exp_logit;
+    }
+    for (size_t i = 0; i < probs.size(); i++) probs[i] /= sum_exp;
+    return probs;
+}
+
+void perplexity(const gpt_vocab &vocab, const llama_model &model, const gpt_params &params, size_t mem_per_token) {
+    // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
+    // Run `./main --perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
+    // Output: `perplexity: 13.5106 [114/114]`
+    std::vector<gpt_vocab::id> tokens = ::llama_tokenize(vocab, params.prompt, true);
+
+    double nll = 0.0;
+    int seq_count = tokens.size() / params.n_ctx;
+    for (int i = 0; i < seq_count; ++i) {
+        int start = i * params.n_ctx;
+        int end = start + params.n_ctx - 1;
+        std::vector<gpt_vocab::id> embd(tokens.begin() + start, tokens.begin() + end);
+        std::vector<float> logits;
+        if (!llama_eval(model, params.n_threads, 0, embd, logits, mem_per_token)) {
+            fprintf(stderr, "Failed to predict\n");
+            return;
+        }
+        // Calculate probability of next token, given the previous ones.
+        double prob = softmax(logits)[tokens[end]];
+        nll += -std::log(prob);
+        // perplexity is e^(average negative log-likelihood)
+        printf("perplexity: %.4lf [%d/%d] \r", std::exp(nll / (i + 1)), i + 1, seq_count);
+        fflush(stdout);
+    }
+    printf("\n");
+}
+
 static bool is_interacting = false;

 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
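For reference, the quantity printed by the new perplexity() loop is the running value of the usual formula (this restatement is an annotation on this page, not part of the diff):

PPL = \exp\left( -\frac{1}{T} \sum_{t=1}^{T} \log p(x_t \mid x_{<t}) \right)

Here each term corresponds to one n_ctx-sized chunk of the prompt: the chunk's first n_ctx - 1 tokens are fed to llama_eval, p is the softmax probability the model assigns to the chunk's last token, nll accumulates -log p, and exp(nll / (i + 1)) is printed after each chunk.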
@@ -815,7 +858,7 @@ int main(int argc, char ** argv) {
     // load the model
     {
         const int64_t t_start_us = ggml_time_us();
-        if (!llama_model_load(params.model, model, vocab, params.n_ctx)) {
+        if (!llama_model_load(params.model, model, vocab, params.n_ctx)) {
             fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
             return 1;
         }
@@ -830,13 +873,22 @@ int main(int argc, char ** argv) {
             params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
     }

+    std::vector<float> logits;
+
+    // determine the required inference memory per token:
+    size_t mem_per_token = 0;
+    llama_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
+
+    if (params.perplexity) {
+        perplexity(vocab, model, params, mem_per_token);
+        exit(0);
+    }
+
     int n_past = 0;

     int64_t t_sample_us = 0;
     int64_t t_predict_us = 0;

-    std::vector<float> logits;
-
     // Add a space in front of the first character to match OG llama tokenizer behavior
     params.prompt.insert(0, 1, ' ');
     // tokenize the prompt
@@ -881,10 +933,6 @@ int main(int argc, char ** argv) {

     std::vector<gpt_vocab::id> embd;

-    // determine the required inference memory per token:
-    size_t mem_per_token = 0;
-    llama_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
-
     int last_n_size = params.repeat_last_n;
     std::vector<gpt_vocab::id> last_n_tokens(last_n_size);
     std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
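The two building blocks added above, the numerically stable softmax and the running exp(average negative log-likelihood), can be exercised on their own. The sketch below is not part of the commit; the helper name softmax_sketch and every numeric value in it are made up for illustration, and it only mirrors the arithmetic of softmax() and perplexity() without loading a model.

// Standalone sketch (not part of this commit): toy check of the arithmetic above.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Same max-subtraction trick as softmax() in main.cpp: shifting every logit by the
// maximum leaves the probabilities unchanged but keeps exp() from overflowing.
static std::vector<double> softmax_sketch(const std::vector<float> & logits) {
    std::vector<double> probs(logits.size());
    const float max_logit = *std::max_element(logits.begin(), logits.end());
    double sum_exp = 0.0;
    for (size_t i = 0; i < logits.size(); i++) {
        probs[i] = std::exp(logits[i] - max_logit); // exponent is <= 0, so no overflow
        sum_exp += probs[i];
    }
    for (double & p : probs) p /= sum_exp;
    return probs;
}

int main() {
    // Toy "next-token" probabilities for three chunks, standing in for
    // softmax(logits)[tokens[end]] in the perplexity() loop above.
    const std::vector<double> chunk_probs = {
        softmax_sketch({ 2.0f, 1.0f,  0.1f })[0],
        softmax_sketch({ 0.5f, 0.5f,  0.5f })[1],
        softmax_sketch({ 3.0f, 0.0f, -1.0f })[0],
    };

    double nll = 0.0;
    for (size_t i = 0; i < chunk_probs.size(); i++) {
        nll += -std::log(chunk_probs[i]);
        // perplexity is e^(average negative log-likelihood), as printed by perplexity()
        printf("running perplexity: %.4lf [%zu/%zu]\n", std::exp(nll / (i + 1)), i + 1, chunk_probs.size());
    }
    return 0;
}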

utils.cpp

Lines changed: 9 additions & 7 deletions
@@ -44,7 +44,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             std::copy(std::istreambuf_iterator<char>(file),
                       std::istreambuf_iterator<char>(),
                       back_inserter(params.prompt));
-
         } else if (arg == "-n" || arg == "--n_predict") {
             params.n_predict = std::stoi(argv[++i]);
         } else if (arg == "--top_k") {
@@ -72,6 +71,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.use_color = true;
         } else if (arg == "-r" || arg == "--reverse-prompt") {
             params.antiprompt = argv[++i];
+        } else if (arg == "--perplexity") {
+            params.perplexity = true;
         } else if (arg == "-h" || arg == "--help") {
             gpt_print_usage(argc, argv, params);
             exit(0);
@@ -109,6 +110,7 @@ void gpt_print_usage(int argc, char ** argv, const gpt_params & params) {
     fprintf(stderr, "  -c N, --ctx_size N    size of the prompt context (default: %d)\n", params.n_ctx);
     fprintf(stderr, "  --temp N              temperature (default: %.1f)\n", params.temp);
     fprintf(stderr, "  -b N, --batch_size N  batch size for prompt processing (default: %d)\n", params.n_batch);
+    fprintf(stderr, "  --perplexity          compute perplexity over the prompt\n");
     fprintf(stderr, "  -m FNAME, --model FNAME\n");
     fprintf(stderr, "                        model path (default: %s)\n", params.model.c_str());
     fprintf(stderr, "\n");
@@ -322,9 +324,9 @@ std::vector<gpt_vocab::id> llama_tokenize(const gpt_vocab & vocab, const std::st
     while (i > 0) {
         gpt_vocab::id token_id = prev[i];
         if (token_id == 0) {
-            // TODO: Return error or something more meaningful
-            printf("failed to tokenize string!\n");
-            break;
+            // TODO: Return error or something more meaningful
+            printf("failed to tokenize string at %d!\n", i);
+            break;
         }
         res.push_back(token_id);
         auto token = (*vocab.id_to_token.find(token_id)).second;
@@ -398,7 +400,7 @@ gpt_vocab::id llama_sample_top_p_top_k(
                 logits_id.push_back(std::make_pair(logits[i]*scale*repeat_penalty, i));
             } else {
                 logits_id.push_back(std::make_pair(logits[i]*scale/repeat_penalty, i));
-            }
+            }
         } else {
             logits_id.push_back(std::make_pair(logits[i]*scale, i));
         }
@@ -527,7 +529,7 @@ size_t ggml_quantize_q4_1(float * src, void * dst, int n, int k, int qk, int64_t

     char * pdst = (char *) dst;

-    for (int j = 0; j < n; j += k) {
+    for (int j = 0; j < n; j += k) {
         uint8_t * pd = (uint8_t *) (pdst + (j/k)*row_size + 0*bs);
         uint8_t * pm = (uint8_t *) (pdst + (j/k)*row_size + 0*bs + sizeof(float));
         uint8_t * pb = (uint8_t *) (pdst + (j/k)*row_size + 0*bs + 2*sizeof(float));
@@ -550,7 +552,7 @@ size_t ggml_quantize_q4_1(float * src, void * dst, int n, int k, int qk, int64_t

         *(float *) pd = d;
         *(float *) pm = min;
-        pd += bs;
+        pd += bs;
         pm += bs;

         for (int l = 0; l < qk; l += 2) {

utils.h

Lines changed: 2 additions & 0 deletions
@@ -35,6 +35,8 @@ struct gpt_params {
     bool interactive = false; // interactive mode
     bool interactive_start = false; // reverse prompt immediately
     std::string antiprompt = ""; // string upon seeing which more user input is prompted
+
+    bool perplexity = false;
 };

 bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
