@@ -323,10 +323,19 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
323
323
324
324
llama_batch batch = llama_batch_get_one (NULL , 0 , 0 , 0 );
325
325
326
- const int32_t n_layers = 32 ; // model layer count
327
- const int test_count = 6 ; // num perplexity chunks to run for each test
328
- const size_t prune_target = 4 ; // prune this many of the worst results each pass
329
- // end tunables
326
+ // model layer count
327
+ const int32_t n_layers = 32 ;
328
+
329
+ // num perplexity chunks to run for each test
330
+ const int test_count = 4 ;
331
+
332
+ // prune this many of the worst results each pass
333
+ const size_t prune_target = 2 ;
334
+
335
+ // start with all but first/last layers disabled and start adding them back
336
+ const bool anti_mode = true ;
337
+
338
+ // **** end tunables ****
330
339
331
340
// 1 = attn, 2 = mlp, 3 = both
332
341
int32_t test_skip_type = 0 ; // but don't mess with this, it's set automatically.
@@ -340,11 +349,19 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
340
349
skip_types.resize (n_layers);
341
350
std::fill (skip_types.begin (), skip_types.end (), 0 );
342
351
std::vector<std::tuple<int32_t , int32_t , double >> pass_results;
343
- std::vector<int32_t > worsts;
344
- worsts.resize (n_layers);
345
- std::fill (worsts.begin (), worsts.end (), 0 );
352
+ std::vector<int32_t > extremes;
353
+ extremes.resize (n_layers);
354
+ std::fill (extremes.begin (), extremes.end (), 0 );
355
+ if (anti_mode) {
356
+ // No point in starting with first/last layer disabled.
357
+ skip_types[0 ] = 15 ;
358
+ skip_types[n_layers - 1 ] = 15 ;
359
+ skips.push_back (0 ); skips.push_back (0 + n_layers);
360
+ skips.push_back (n_layers - 1 ); skips.push_back (n_layers - 1 + n_layers);
361
+ }
346
362
int32_t curr_best_layer = -1 , curr_best_type = 0 ;
347
363
double curr_best_ppl = -1 , ref_ppl = -1 ;
364
+ const int32_t mask = anti_mode ? 3 : 0 ;
348
365
349
366
int count = 0 ;
350
367
double nll = 0.0 ;
@@ -372,35 +389,40 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
372
389
}
373
390
if (skip_layer >= n_layers) {
374
391
if (curr_best_layer == -1 ) break ;
375
- if (pass_results.size () >= prune_target * 2 ) {
392
+ if (prune_target > 0 && pass_results.size () >= prune_target * 2 ) {
376
393
std::sort (pass_results.begin (), pass_results.end (),
377
394
[](const std::tuple<int32_t , int32_t , double > & a, const std::tuple<int32_t , int32_t , double > & b) {
395
+ if (anti_mode) return std::get<2 >(b) > std::get<2 >(a);
378
396
return std::get<2 >(a) > std::get<2 >(b);
379
397
}
380
398
);
381
399
const size_t num_prune = std::min (pass_results.size (), prune_target);
382
- for (size_t temp = 0 ; temp < num_prune ; temp++) {
400
+ for (size_t temp = 0 , pruned = 0 ; temp < pass_results. size () ; temp++) {
383
401
int32_t lidx = std::get<0 >(pass_results[temp]);
384
402
if (lidx == curr_best_layer && std::get<1 >(pass_results[temp]) == curr_best_type) continue ;
385
- worsts[lidx] |= std::get<1 >(pass_results[temp]);
386
- printf (" \n Prune[%zu]: %d (%d) - %.2f\n " , temp, lidx, std::get<1 >(pass_results[temp]), std::get<2 >(pass_results[temp]));
403
+ extremes[lidx] |= std::get<1 >(pass_results[temp]);
404
+ printf (" \n Prune[%zu]: %d (%d) - %.2f\n " , pruned + 1 , lidx,
405
+ std::get<1 >(pass_results[temp]), std::get<2 >(pass_results[temp]));
406
+ if (anti_mode) {
407
+ skip_types[lidx] |= std::get<1 >(pass_results[temp]);
408
+ skips.push_back (std::get<1 >(pass_results[temp]) == 1 ? lidx : -lidx);
409
+ }
410
+ if (++pruned >= num_prune) break ;
387
411
}
388
412
}
389
413
pass_results.clear ();
390
- printf (" \n\n ADD SKIP %c%3d - ppl vs ref %.4f" ,
414
+ printf (" \n\n ADD %c%3d - ppl vs ref %.4f" ,
391
415
int (label[curr_best_type]), curr_best_layer,
392
416
curr_best_ppl - ref_ppl);
393
- if (curr_best_ppl > ref_ppl * 1.75 ) break ;
417
+ if (!anti_mode && curr_best_ppl > ref_ppl * 1.75 ) break ;
394
418
skip_types[curr_best_layer] += curr_best_type;
395
- if (std::find (skips.begin (), skips.end (), curr_best_layer) == skips.end ()) {
396
- skips.push_back (curr_best_layer);
397
- }
419
+ skips.push_back (curr_best_type == 1 ? curr_best_layer : curr_best_layer + n_layers);
398
420
curr_best_layer = -1 ;
399
421
curr_best_ppl = -1 ;
400
422
curr_best_type = 0 ;
401
423
skip_layer = n_layers;
402
424
for (int32_t new_sl = 0 ; new_sl < n_layers; new_sl++) {
403
- skip_types[new_sl] = (skip_types[new_sl] & 3 ) | (worsts [new_sl] << 2 );
425
+ skip_types[new_sl] = (skip_types[new_sl] & 3 ) | (extremes [new_sl] << 2 );
404
426
}
405
427
for (int32_t new_sl = 0 ; new_sl < n_layers; new_sl++) {
406
428
int32_t curr_skipped = (skip_types[new_sl] >> 2 ) | (skip_types[new_sl] & 3 );
@@ -420,16 +442,18 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
420
442
logit_history.clear ();
421
443
prob_history.clear ();
422
444
445
+ int alive = 0 ;
423
446
for (int32_t i = 0 ; i < n_layers; i++) {
424
- layers[i] = (skip_types[i] & 3 ) | (i == skip_layer ? test_skip_type : 0 );
447
+ layers[i] = mask ^ ((skip_types[i] & 3 ) | (i == skip_layer ? test_skip_type : 0 ));
448
+ alive += !(layers[i] & 1 ) + !(layers[i] & 2 );
425
449
}
426
450
layers[n_layers] = -1 ;
427
451
printf (" \n TEST %c%3d + [" , int (label[test_skip_type]), skip_layer);
428
- for (const auto l : skips) {
429
- printf (" %c%d, " , int (label[skip_types[l] & 3 ]), l);
452
+ for (auto l : skips) {
453
+ printf (" %c%d, " , int (label[skip_types[l % n_layers ] & 3 ]), l % n_layers );
430
454
}
431
- printf (" ] - len : %3zu , best:(%c%3d @ %.3f), last took %.2f sec\n " ,
432
- skips. size () + 1 ,
455
+ printf (" ] - live : %3d/%3d , best:(%c%3d @ %.3f), last took %.2f sec\n " ,
456
+ alive, n_layers * 2 ,
433
457
int (label[curr_best_type]), curr_best_layer,
434
458
curr_best_ppl != -1 ? curr_best_ppl - ref_ppl : 0 ,
435
459
test_t_total);
@@ -477,7 +501,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
477
501
478
502
const auto t_end = std::chrono::high_resolution_clock::now ();
479
503
480
- if (i == 0 && skip_layer < 0 && skips. empty () ) {
504
+ if (i == 0 && skip_layer < 0 && ref_ppl < 0 ) {
481
505
const float t_total = std::chrono::duration<float >(t_end - t_start).count ();
482
506
fprintf (stderr, " %s: %.2f seconds per pass - ETA " , __func__, t_total);
483
507
int total_seconds = (int )(t_total * n_chunk);
@@ -516,7 +540,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
516
540
printf (" %8d %.4lf %4lf %4lf\n " , i*n_ctx, std::exp (nll / count), av, av2);
517
541
}
518
542
fflush (stdout);
519
- if (skip_layer >= 0 && (i + 1 == test_count || (i > 1 && ppl > ref_ppl * 3 ))) {
543
+ if (skip_layer >= 0 && (i + 1 == test_count || (i > 1 && ppl > ref_ppl * 30 ))) {
520
544
i = test_count - 1 ;
521
545
skip_types[skip_layer] |= test_skip_type << 2 ;
522
546
if (curr_best_layer == -1 || ppl < curr_best_ppl) {
0 commit comments