Skip to content

Commit 74eebc6

Browse files
committed
What if we do something crazy like add layers instead of removing them?
1 parent a0c2f5c commit 74eebc6

File tree

2 files changed

+227
-63
lines changed

2 files changed

+227
-63
lines changed

examples/perplexity/perplexity.cpp

+48-24
Original file line numberDiff line numberDiff line change
@@ -323,10 +323,19 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
323323

324324
llama_batch batch = llama_batch_get_one(NULL, 0, 0, 0);
325325

326-
const int32_t n_layers = 32; // model layer count
327-
const int test_count = 6; // num perplexity chunks to run for each test
328-
const size_t prune_target = 4; // prune this many of the worst results each pass
329-
// end tunables
326+
// model layer count
327+
const int32_t n_layers = 32;
328+
329+
// num perplexity chunks to run for each test
330+
const int test_count = 4;
331+
332+
// prune this many of the worst results each pass
333+
const size_t prune_target = 2;
334+
335+
// start with all but first/last layers disabled and start adding them back
336+
const bool anti_mode = true;
337+
338+
// **** end tunables ****
330339

331340
// 1 = attn, 2 = mlp, 3 = both
332341
int32_t test_skip_type = 0; // but don't mess with this, it's set automatically.
@@ -340,11 +349,19 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
340349
skip_types.resize(n_layers);
341350
std::fill(skip_types.begin(), skip_types.end(), 0);
342351
std::vector<std::tuple<int32_t, int32_t, double>> pass_results;
343-
std::vector<int32_t> worsts;
344-
worsts.resize(n_layers);
345-
std::fill(worsts.begin(), worsts.end(), 0);
352+
std::vector<int32_t> extremes;
353+
extremes.resize(n_layers);
354+
std::fill(extremes.begin(), extremes.end(), 0);
355+
if (anti_mode) {
356+
// No point in starting with first/last layer disabled.
357+
skip_types[0] = 15;
358+
skip_types[n_layers - 1] = 15;
359+
skips.push_back(0); skips.push_back(0 + n_layers);
360+
skips.push_back(n_layers - 1); skips.push_back(n_layers - 1 + n_layers);
361+
}
346362
int32_t curr_best_layer = -1, curr_best_type = 0;
347363
double curr_best_ppl = -1, ref_ppl = -1;
364+
const int32_t mask = anti_mode ? 3 : 0;
348365

349366
int count = 0;
350367
double nll = 0.0;
@@ -372,35 +389,40 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
372389
}
373390
if (skip_layer >= n_layers) {
374391
if (curr_best_layer == -1) break;
375-
if (pass_results.size() >= prune_target * 2) {
392+
if (prune_target > 0 && pass_results.size() >= prune_target * 2) {
376393
std::sort(pass_results.begin(), pass_results.end(),
377394
[](const std::tuple<int32_t, int32_t, double> & a, const std::tuple<int32_t, int32_t, double> & b) {
395+
if (anti_mode) return std::get<2>(b) > std::get<2>(a);
378396
return std::get<2>(a) > std::get<2>(b);
379397
}
380398
);
381399
const size_t num_prune = std::min(pass_results.size(), prune_target);
382-
for (size_t temp = 0; temp < num_prune; temp++) {
400+
for (size_t temp = 0, pruned = 0; temp < pass_results.size(); temp++) {
383401
int32_t lidx = std::get<0>(pass_results[temp]);
384402
if (lidx == curr_best_layer && std::get<1>(pass_results[temp]) == curr_best_type) continue;
385-
worsts[lidx] |= std::get<1>(pass_results[temp]);
386-
printf("\nPrune[%zu]: %d (%d) - %.2f\n", temp, lidx, std::get<1>(pass_results[temp]), std::get<2>(pass_results[temp]));
403+
extremes[lidx] |= std::get<1>(pass_results[temp]);
404+
printf("\nPrune[%zu]: %d (%d) - %.2f\n", pruned + 1, lidx,
405+
std::get<1>(pass_results[temp]), std::get<2>(pass_results[temp]));
406+
if (anti_mode) {
407+
skip_types[lidx] |= std::get<1>(pass_results[temp]);
408+
skips.push_back(std::get<1>(pass_results[temp]) == 1 ? lidx : -lidx);
409+
}
410+
if (++pruned >= num_prune) break;
387411
}
388412
}
389413
pass_results.clear();
390-
printf("\n\nADD SKIP %c%3d - ppl vs ref %.4f",
414+
printf("\n\nADD %c%3d - ppl vs ref %.4f",
391415
int(label[curr_best_type]), curr_best_layer,
392416
curr_best_ppl - ref_ppl);
393-
if (curr_best_ppl > ref_ppl * 1.75) break;
417+
if (!anti_mode && curr_best_ppl > ref_ppl * 1.75) break;
394418
skip_types[curr_best_layer] += curr_best_type;
395-
if (std::find(skips.begin(), skips.end(), curr_best_layer) == skips.end()) {
396-
skips.push_back(curr_best_layer);
397-
}
419+
skips.push_back(curr_best_type == 1 ? curr_best_layer : curr_best_layer + n_layers);
398420
curr_best_layer = -1;
399421
curr_best_ppl = -1;
400422
curr_best_type = 0;
401423
skip_layer = n_layers;
402424
for (int32_t new_sl = 0; new_sl < n_layers; new_sl++) {
403-
skip_types[new_sl] = (skip_types[new_sl] & 3) | (worsts[new_sl] << 2);
425+
skip_types[new_sl] = (skip_types[new_sl] & 3) | (extremes[new_sl] << 2);
404426
}
405427
for (int32_t new_sl = 0; new_sl < n_layers; new_sl++) {
406428
int32_t curr_skipped = (skip_types[new_sl] >> 2) | (skip_types[new_sl] & 3);
@@ -420,16 +442,18 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
420442
logit_history.clear();
421443
prob_history.clear();
422444

445+
int alive = 0;
423446
for (int32_t i = 0; i < n_layers; i++) {
424-
layers[i] = (skip_types[i] & 3) | (i == skip_layer ? test_skip_type : 0);
447+
layers[i] = mask ^ ((skip_types[i] & 3) | (i == skip_layer ? test_skip_type : 0));
448+
alive += !(layers[i] & 1) + !(layers[i] & 2);
425449
}
426450
layers[n_layers] = -1;
427451
printf("\nTEST %c%3d + [", int(label[test_skip_type]), skip_layer);
428-
for (const auto l : skips) {
429-
printf("%c%d, ", int(label[skip_types[l] & 3]), l);
452+
for (auto l : skips) {
453+
printf("%c%d, ", int(label[skip_types[l % n_layers] & 3]), l % n_layers);
430454
}
431-
printf("] - len: %3zu, best:(%c%3d @ %.3f), last took %.2f sec\n",
432-
skips.size() + 1,
455+
printf("] - live: %3d/%3d, best:(%c%3d @ %.3f), last took %.2f sec\n",
456+
alive, n_layers * 2,
433457
int(label[curr_best_type]), curr_best_layer,
434458
curr_best_ppl != -1 ? curr_best_ppl - ref_ppl : 0,
435459
test_t_total);
@@ -477,7 +501,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
477501

478502
const auto t_end = std::chrono::high_resolution_clock::now();
479503

480-
if (i == 0 && skip_layer < 0 && skips.empty()) {
504+
if (i == 0 && skip_layer < 0 && ref_ppl < 0) {
481505
const float t_total = std::chrono::duration<float>(t_end - t_start).count();
482506
fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
483507
int total_seconds = (int)(t_total * n_chunk);
@@ -516,7 +540,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
516540
printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
517541
}
518542
fflush(stdout);
519-
if (skip_layer >= 0 && (i + 1 == test_count || (i > 1 && ppl > ref_ppl * 3))) {
543+
if (skip_layer >= 0 && (i + 1 == test_count || (i > 1 && ppl > ref_ppl * 30))) {
520544
i = test_count - 1;
521545
skip_types[skip_layer] |= test_skip_type << 2;
522546
if (curr_best_layer == -1 || ppl < curr_best_ppl) {

0 commit comments

Comments (0)