@@ -27,20 +27,27 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
 
     int count = 0;
     int seq_count = tokens.size() / params.n_ctx;
+    int n_vocab = llama_n_vocab(ctx);
 
     double nll = 0.0;
-
-    fprintf(stderr, "%s : calculating perplexity over %d chunks\n", __func__, seq_count);
+    fprintf(stderr, "%s : calculating perplexity over %d chunks, batch_size=%d\n", __func__, seq_count, params.n_batch);
 
     for (int i = 0; i < seq_count; ++i) {
         int start = i * params.n_ctx;
-        int end = start + params.n_ctx - 1; // TODO: this is not optimal, e.g. it makes the batch 511 instead of 512
-                                            // it is better to always be power of 2 for better performance
-        std::vector<llama_token> embd(tokens.begin() + start, tokens.begin() + end);
+        int end = start + params.n_ctx;
+
+        std::vector<float> logits;
+        int num_batches = (params.n_ctx + params.n_batch - 1) / params.n_batch;
         auto start_t = std::chrono::high_resolution_clock::now();
-        if (llama_eval(ctx, embd.data(), embd.size(), 0, params.n_threads)) {
-            fprintf(stderr, "%s : failed to eval\n", __func__);
-            return;
+        for (int j = 0; j < num_batches; ++j) {
+            int batch_start = start + j * params.n_batch;
+            int batch_size = std::min(end - batch_start, params.n_batch);
+            if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * params.n_batch, params.n_threads)) {
+                fprintf(stderr, "%s : failed to eval\n", __func__);
+                return;
+            }
+            auto batch_logits = llama_get_logits(ctx);
+            logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
        }
         auto end_t = std::chrono::high_resolution_clock::now();
         if (i == 0) {
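
The new inner loop splits each n_ctx-sized chunk into n_batch-sized slices, evaluates them with an increasing n_past so the KV cache keeps the earlier slices of the chunk, and concatenates the per-slice logits into one buffer. A minimal standalone sketch of just that index arithmetic, with hypothetical n_ctx/n_batch/start values and no llama.cpp calls, might look like:

#include <algorithm>
#include <cstdio>

int main() {
    // Hypothetical sizes for illustration only; in the example they come from gpt_params.
    const int n_ctx   = 512; // tokens scored per chunk
    const int n_batch = 200; // tokens passed to llama_eval per call
    const int start   = 0;   // offset of this chunk in the tokenized prompt

    const int end         = start + n_ctx;
    const int num_batches = (n_ctx + n_batch - 1) / n_batch; // ceiling division

    for (int j = 0; j < num_batches; ++j) {
        const int batch_start = start + j * n_batch;
        const int batch_size  = std::min(end - batch_start, n_batch); // last slice may be shorter
        // n_past = j * n_batch keeps the earlier slices of this chunk in the KV cache.
        printf("batch %d: tokens [%d, %d), n_past = %d\n",
               j, batch_start, batch_start + batch_size, j * n_batch);
    }
    return 0;
}
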
@@ -59,15 +66,12 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
         // Example, we have a context window of 512, we will compute perplexity for each of the
         // last 256 tokens.  Then, we split the input up into context window size chunks to
         // process the entire prompt.
-
-        auto logits = llama_get_logits(ctx);
-        for (int j = params.n_ctx / 2; j < params.n_ctx - 1; ++j) {
+        for (int j = std::min(512, params.n_ctx / 2); j < params.n_ctx - 1; ++j) {
             // Calculate probability of next token, given the previous ones.
-            int n_vocab = llama_n_vocab(ctx);
             std::vector<float> tok_logits(
-                logits + j * n_vocab,
-                logits + (j + 1) * n_vocab);
-            const float prob = softmax(tok_logits)[tokens[start + j + 1]];
+                logits.begin() + j * n_vocab,
+                logits.begin() + (j + 1) * n_vocab);
+            float prob = softmax(tok_logits)[tokens[start + j + 1]];
             nll += -std::log(prob);
             ++count;
         }
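
Each scored position takes the n_vocab-wide logits row for position j, turns it into probabilities with softmax, and adds -log p of the actual next token to nll; perplexity over the scored positions is then exp(nll / count). A small self-contained sketch of that step, using a toy 4-token vocabulary and a made-up target index rather than real model output, could be:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Numerically stable softmax over one row of logits (same role as the example's helper).
static std::vector<float> softmax(const std::vector<float> & logits) {
    const float max_logit = *std::max_element(logits.begin(), logits.end());
    std::vector<float> probs(logits.size());
    float sum = 0.0f;
    for (size_t i = 0; i < logits.size(); ++i) {
        probs[i] = std::exp(logits[i] - max_logit);
        sum += probs[i];
    }
    for (float & p : probs) {
        p /= sum;
    }
    return probs;
}

int main() {
    // Toy 4-token vocabulary with made-up logits; real rows are n_vocab floats wide.
    const std::vector<float> tok_logits = { 2.0f, 0.5f, -1.0f, 0.0f };
    const int next_token = 0; // hypothetical ground-truth next token

    double nll   = 0.0;
    int    count = 0;

    const float prob = softmax(tok_logits)[next_token];
    nll += -std::log(prob);
    ++count;

    // Perplexity over the scored positions is exp(nll / count).
    printf("prob = %f, perplexity = %f\n", prob, std::exp(nll / count));
    return 0;
}
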
@@ -82,11 +86,13 @@ int main(int argc, char ** argv) {
     gpt_params params;
     params.model = "models/llama-7B/ggml-model.bin";
 
+    params.n_batch = 512;
     if (gpt_params_parse(argc, argv, params) == false) {
         return 1;
     }
 
     params.perplexity = true;
+    params.n_batch = std::min(params.n_batch, params.n_ctx);
 
     if (params.n_ctx > 2048) {
         fprintf(stderr, "%s: warning: model does not support context sizes greater than 2048 tokens (%d specified);"