@@ -323,139 +323,87 @@ static void process_logits(
323
323
}
324
324
325
325
static bool compute_imatrix (llama_context * ctx, const gpt_params & params, bool compute_ppl, int from_chunk) {
326
+ (void )from_chunk;
326
327
327
- const bool add_bos = llama_should_add_bos_token (llama_get_model (ctx));
328
- const int n_ctx = llama_n_ctx (ctx);
329
-
330
- auto tim1 = std::chrono::high_resolution_clock::now ();
331
- fprintf (stderr, " %s: tokenizing the input ..\n " , __func__);
332
-
333
- std::vector<llama_token> tokens = ::llama_tokenize (ctx, params.prompt , add_bos);
334
-
335
- auto tim2 = std::chrono::high_resolution_clock::now ();
336
- fprintf (stderr, " %s: tokenization took %g ms\n " ,__func__,1e-3 *std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count ());
337
-
338
- if (from_chunk > 0 ) {
339
- if (size_t ((from_chunk + 2 )*n_ctx) >= tokens.size ()) {
340
- fprintf (stderr, " %s: there will be not enough tokens left after removing %d chunks\n " , __func__, from_chunk);
341
- return false ;
342
- }
343
- fprintf (stderr, " %s: removing initial %d chunks (%d tokens)\n " , __func__, from_chunk, from_chunk*n_ctx);
344
- tokens.erase (tokens.begin (), tokens.begin () + from_chunk*n_ctx);
345
- }
346
-
347
- if (int (tokens.size ()) < 2 *n_ctx) {
348
- fprintf (stderr, " %s: you need at least %d tokens for a context of %d tokens\n " ,__func__,2 *n_ctx,
349
- n_ctx);
350
- fprintf (stderr, " %s: the data file you provided tokenizes to only %zu tokens\n " ,__func__,tokens.size ());
351
- return false ;
352
- }
353
-
354
- std::vector<float > logit_history;
355
- std::vector<float > prob_history;
356
-
357
- if (compute_ppl) {
358
- logit_history.resize (tokens.size ());
359
- prob_history.resize (tokens.size ());
360
- }
361
-
362
- const int n_chunk_max = tokens.size () / n_ctx;
328
+ std::vector<std::thread> workers (std::thread::hardware_concurrency () - 1 );
363
329
364
- const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min (params.n_chunks , n_chunk_max);
365
- const int n_vocab = llama_n_vocab (llama_get_model (ctx));
366
330
const int n_batch = params.n_batch ;
367
331
368
332
int count = 0 ;
369
333
double nll = 0.0 ;
370
334
double nll2 = 0.0 ;
371
335
372
- fprintf (stderr, " %s: computing over %d chunks with batch_size %d\n " , __func__, n_chunk, n_batch);
373
-
374
- std::vector<std::thread> workers (std::thread::hardware_concurrency () - 1 );
375
-
376
- const int num_batches = (n_ctx + n_batch - 1 ) / n_batch;
377
-
378
- std::vector<float > logits;
379
- if (compute_ppl && num_batches > 1 ) {
380
- logits.reserve ((size_t )n_ctx * n_vocab);
381
- }
382
-
383
- for (int i = 0 ; i < n_chunk; ++i) {
384
- const int start = i * n_ctx;
385
- const int end = start + n_ctx;
336
+ std::vector<llama_token> tokens;
337
+ std::vector<float > logit_history;
338
+ std::vector<float > prob_history;
386
339
387
- std::vector< float > logits ;
340
+ const int n_vocab = llama_n_vocab ( llama_get_model (ctx)) ;
388
341
389
- const auto t_start = std::chrono::high_resolution_clock::now ();
342
+ size_t c_begin = 0 ;
343
+ while (true ) {
344
+ const char * s_begin = " <|im_start|>system\n " ;
345
+ const char * s_assistant = " <|im_start|>assistant\n " ;
346
+ c_begin = params.prompt .find (s_begin, c_begin);
347
+ if (c_begin == std::string::npos) {
348
+ break ;
349
+ }
350
+ size_t c_assistant = params.prompt .find (s_assistant, c_begin);
351
+ if (c_assistant == std::string::npos) {
352
+ break ;
353
+ }
354
+ c_assistant += strlen (s_assistant);
355
+ size_t next_c_begin = params.prompt .find (s_begin, c_assistant);
356
+ auto s_prompt = params.prompt .substr (c_begin, c_assistant - c_begin);
357
+ auto s_response = params.prompt .substr (c_assistant, next_c_begin - c_assistant);
358
+ c_begin += 1 ;
390
359
391
- // clear the KV cache
392
360
llama_kv_cache_clear (ctx);
393
361
394
- for (int j = 0 ; j < num_batches; ++j) {
395
- const int batch_start = start + j * n_batch;
396
- const int batch_size = std::min (end - batch_start, n_batch);
397
-
398
- // save original token and restore it after eval
399
- const auto token_org = tokens[batch_start];
362
+ std::vector<llama_token> s_tokens_prompt = ::llama_tokenize (ctx, s_prompt, false );
363
+ std::vector<llama_token> s_tokens_response = ::llama_tokenize (ctx, s_response, false );
364
+ std::vector<llama_token> s_tokens = s_tokens_prompt;
365
+ s_tokens.insert (s_tokens.end (), s_tokens_response.begin (), s_tokens_response.end ());
366
+ std::vector<float > s_logits;
367
+ std::vector<float > s_logit_history (s_tokens.size (), 0 );
368
+ std::vector<float > s_prob_history (s_tokens.size (), 0 );
400
369
401
- // add BOS token for the first batch of each chunk
402
- if (add_bos && j == 0 ) {
403
- tokens[batch_start] = llama_token_bos (llama_get_model (ctx));
404
- }
370
+ for (int j = 0 ; j < (int (s_tokens.size ()) + n_batch - 1 ) / n_batch; ++j) {
371
+ const int batch_start = j * n_batch;
372
+ const int batch_size = std::min ((int )s_tokens.size () - batch_start, n_batch);
405
373
406
- if (llama_decode (ctx, llama_batch_get_one (tokens .data () + batch_start, batch_size, j * n_batch, 0 ))) {
374
+ if (llama_decode (ctx, llama_batch_get_one (s_tokens .data () + batch_start, batch_size, j * n_batch, 0 ))) {
407
375
fprintf (stderr, " %s : failed to eval\n " , __func__);
408
376
return false ;
409
377
}
410
378
411
- // restore the original token in case it was set to BOS
412
- tokens[batch_start] = token_org;
413
-
414
- if (compute_ppl && num_batches > 1 ) {
415
- const auto * batch_logits = llama_get_logits (ctx);
416
- logits.insert (logits.end (), batch_logits, batch_logits + batch_size * n_vocab);
417
- }
379
+ const auto * batch_logits = llama_get_logits (ctx);
380
+ s_logits.insert (s_logits.end (), batch_logits, batch_logits + batch_size * n_vocab);
418
381
}
419
382
420
- const auto t_end = std::chrono::high_resolution_clock::now ();
421
-
422
- if (i == 0 ) {
423
- const float t_total = std::chrono::duration<float >(t_end - t_start).count ();
424
- fprintf (stderr, " %s: %.2f seconds per pass - ETA " , __func__, t_total);
425
- int total_seconds = (int )(t_total * n_chunk);
426
- if (total_seconds >= 60 *60 ) {
427
- fprintf (stderr, " %d hours " , total_seconds / (60 *60 ));
428
- total_seconds = total_seconds % (60 *60 );
429
- }
430
- fprintf (stderr, " %.2f minutes\n " , total_seconds / 60.0 );
431
- }
383
+ const int first = s_tokens_prompt.size ();
384
+ const float * all_logits = s_logits.data ();
385
+ process_logits (n_vocab, all_logits + first * n_vocab, s_tokens.data () + first, s_tokens_response.size () - 1 ,
386
+ workers, nll, nll2, s_logit_history.data () + first, s_prob_history.data () + first);
387
+ count += s_tokens_response.size () - 1 ;
432
388
433
- if (compute_ppl) {
434
- const int first = n_ctx/2 ;
435
- const auto all_logits = num_batches > 1 ? logits.data () : llama_get_logits (ctx);
436
- process_logits (n_vocab, all_logits + first*n_vocab, tokens.data () + start + first, n_ctx - 1 - first,
437
- workers, nll, nll2, logit_history.data () + start + first, prob_history.data () + start + first);
438
- count += n_ctx - first - 1 ;
389
+ printf (" %.4lf," , std::exp (nll / count));
390
+ fflush (stdout);
439
391
440
- printf (" [%d]%.4lf," , i + 1 , std::exp (nll / count));
441
- fflush (stdout);
392
+ tokens.insert (tokens.end (), s_tokens.begin (), s_tokens.end ());
393
+ logit_history.insert (logit_history.end (), s_logit_history.begin (), s_logit_history.end ());
394
+ prob_history.insert (prob_history.end (), s_prob_history.begin (), s_prob_history.end ());
395
+ }
442
396
443
- logits.clear ();
444
- }
397
+ nll2 /= count;
398
+ nll /= count;
399
+ const double ppl = exp (nll);
400
+ nll2 -= nll * nll;
401
+ if (nll2 > 0 ) {
402
+ nll2 = sqrt (nll2 / (count - 1 ));
403
+ printf (" Final estimate: PPL = %.4lf +/- %.5lf\n " , ppl, nll2 * ppl);
445
404
}
446
- printf (" \n " );
447
-
448
- if (compute_ppl) {
449
- nll2 /= count;
450
- nll /= count;
451
- const double ppl = exp (nll);
452
- nll2 -= nll * nll;
453
- if (nll2 > 0 ) {
454
- nll2 = sqrt (nll2/(count-1 ));
455
- printf (" Final estimate: PPL = %.4lf +/- %.5lf\n " , ppl, nll2*ppl);
456
- } else {
457
- printf (" Unexpected negative standard deviation of log(prob)\n " );
458
- }
405
+ else {
406
+ printf (" Unexpected negative standard deviation of log(prob)\n " );
459
407
}
460
408
461
409
return true ;
0 commit comments