Skip to content

Commit 2db0ca3

Browse files
committed
calc ppl on sakurallm prompt format correctly
1 parent 7c777fc commit 2db0ca3

File tree

2 files changed

+107
-241
lines changed

2 files changed

+107
-241
lines changed

examples/imatrix/imatrix.cpp

+56-108
Original file line numberDiff line numberDiff line change
@@ -323,139 +323,87 @@ static void process_logits(
323323
}
324324

325325
static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool compute_ppl, int from_chunk) {
326+
(void)from_chunk;
326327

327-
const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
328-
const int n_ctx = llama_n_ctx(ctx);
329-
330-
auto tim1 = std::chrono::high_resolution_clock::now();
331-
fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
332-
333-
std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
334-
335-
auto tim2 = std::chrono::high_resolution_clock::now();
336-
fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
337-
338-
if (from_chunk > 0) {
339-
if (size_t((from_chunk + 2)*n_ctx) >= tokens.size()) {
340-
fprintf(stderr, "%s: there will be not enough tokens left after removing %d chunks\n", __func__, from_chunk);
341-
return false;
342-
}
343-
fprintf(stderr, "%s: removing initial %d chunks (%d tokens)\n", __func__, from_chunk, from_chunk*n_ctx);
344-
tokens.erase(tokens.begin(), tokens.begin() + from_chunk*n_ctx);
345-
}
346-
347-
if (int(tokens.size()) < 2*n_ctx) {
348-
fprintf(stderr, "%s: you need at least %d tokens for a context of %d tokens\n",__func__,2*n_ctx,
349-
n_ctx);
350-
fprintf(stderr, "%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size());
351-
return false;
352-
}
353-
354-
std::vector<float> logit_history;
355-
std::vector<float> prob_history;
356-
357-
if (compute_ppl) {
358-
logit_history.resize(tokens.size());
359-
prob_history.resize(tokens.size());
360-
}
361-
362-
const int n_chunk_max = tokens.size() / n_ctx;
328+
std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
363329

364-
const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
365-
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
366330
const int n_batch = params.n_batch;
367331

368332
int count = 0;
369333
double nll = 0.0;
370334
double nll2 = 0.0;
371335

372-
fprintf(stderr, "%s: computing over %d chunks with batch_size %d\n", __func__, n_chunk, n_batch);
373-
374-
std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
375-
376-
const int num_batches = (n_ctx + n_batch - 1) / n_batch;
377-
378-
std::vector<float> logits;
379-
if (compute_ppl && num_batches > 1) {
380-
logits.reserve((size_t)n_ctx * n_vocab);
381-
}
382-
383-
for (int i = 0; i < n_chunk; ++i) {
384-
const int start = i * n_ctx;
385-
const int end = start + n_ctx;
336+
std::vector<llama_token> tokens;
337+
std::vector<float> logit_history;
338+
std::vector<float> prob_history;
386339

387-
std::vector<float> logits;
340+
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
388341

389-
const auto t_start = std::chrono::high_resolution_clock::now();
342+
size_t c_begin = 0;
343+
while (true) {
344+
const char* s_begin = "<|im_start|>system\n";
345+
const char* s_assistant = "<|im_start|>assistant\n";
346+
c_begin = params.prompt.find(s_begin, c_begin);
347+
if (c_begin == std::string::npos) {
348+
break;
349+
}
350+
size_t c_assistant = params.prompt.find(s_assistant, c_begin);
351+
if (c_assistant == std::string::npos) {
352+
break;
353+
}
354+
c_assistant += strlen(s_assistant);
355+
size_t next_c_begin = params.prompt.find(s_begin, c_assistant);
356+
auto s_prompt = params.prompt.substr(c_begin, c_assistant - c_begin);
357+
auto s_response = params.prompt.substr(c_assistant, next_c_begin - c_assistant);
358+
c_begin += 1;
390359

391-
// clear the KV cache
392360
llama_kv_cache_clear(ctx);
393361

394-
for (int j = 0; j < num_batches; ++j) {
395-
const int batch_start = start + j * n_batch;
396-
const int batch_size = std::min(end - batch_start, n_batch);
397-
398-
// save original token and restore it after eval
399-
const auto token_org = tokens[batch_start];
362+
std::vector<llama_token> s_tokens_prompt = ::llama_tokenize(ctx, s_prompt, false);
363+
std::vector<llama_token> s_tokens_response = ::llama_tokenize(ctx, s_response, false);
364+
std::vector<llama_token> s_tokens = s_tokens_prompt;
365+
s_tokens.insert(s_tokens.end(), s_tokens_response.begin(), s_tokens_response.end());
366+
std::vector<float> s_logits;
367+
std::vector<float> s_logit_history(s_tokens.size(), 0);
368+
std::vector<float> s_prob_history(s_tokens.size(), 0);
400369

401-
// add BOS token for the first batch of each chunk
402-
if (add_bos && j == 0) {
403-
tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
404-
}
370+
for (int j = 0; j < (int(s_tokens.size()) + n_batch - 1) / n_batch; ++j) {
371+
const int batch_start = j * n_batch;
372+
const int batch_size = std::min((int)s_tokens.size() - batch_start, n_batch);
405373

406-
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
374+
if (llama_decode(ctx, llama_batch_get_one(s_tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
407375
fprintf(stderr, "%s : failed to eval\n", __func__);
408376
return false;
409377
}
410378

411-
// restore the original token in case it was set to BOS
412-
tokens[batch_start] = token_org;
413-
414-
if (compute_ppl && num_batches > 1) {
415-
const auto * batch_logits = llama_get_logits(ctx);
416-
logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
417-
}
379+
const auto* batch_logits = llama_get_logits(ctx);
380+
s_logits.insert(s_logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
418381
}
419382

420-
const auto t_end = std::chrono::high_resolution_clock::now();
421-
422-
if (i == 0) {
423-
const float t_total = std::chrono::duration<float>(t_end - t_start).count();
424-
fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
425-
int total_seconds = (int)(t_total * n_chunk);
426-
if (total_seconds >= 60*60) {
427-
fprintf(stderr, "%d hours ", total_seconds / (60*60));
428-
total_seconds = total_seconds % (60*60);
429-
}
430-
fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
431-
}
383+
const int first = s_tokens_prompt.size();
384+
const float* all_logits = s_logits.data();
385+
process_logits(n_vocab, all_logits + first * n_vocab, s_tokens.data() + first, s_tokens_response.size() - 1,
386+
workers, nll, nll2, s_logit_history.data() + first, s_prob_history.data() + first);
387+
count += s_tokens_response.size() - 1;
432388

433-
if (compute_ppl) {
434-
const int first = n_ctx/2;
435-
const auto all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
436-
process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
437-
workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
438-
count += n_ctx - first - 1;
389+
printf(" %.4lf,", std::exp(nll / count));
390+
fflush(stdout);
439391

440-
printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
441-
fflush(stdout);
392+
tokens.insert(tokens.end(), s_tokens.begin(), s_tokens.end());
393+
logit_history.insert(logit_history.end(), s_logit_history.begin(), s_logit_history.end());
394+
prob_history.insert(prob_history.end(), s_prob_history.begin(), s_prob_history.end());
395+
}
442396

443-
logits.clear();
444-
}
397+
nll2 /= count;
398+
nll /= count;
399+
const double ppl = exp(nll);
400+
nll2 -= nll * nll;
401+
if (nll2 > 0) {
402+
nll2 = sqrt(nll2 / (count - 1));
403+
printf("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2 * ppl);
445404
}
446-
printf("\n");
447-
448-
if (compute_ppl) {
449-
nll2 /= count;
450-
nll /= count;
451-
const double ppl = exp(nll);
452-
nll2 -= nll * nll;
453-
if (nll2 > 0) {
454-
nll2 = sqrt(nll2/(count-1));
455-
printf("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl);
456-
} else {
457-
printf("Unexpected negative standard deviation of log(prob)\n");
458-
}
405+
else {
406+
printf("Unexpected negative standard deviation of log(prob)\n");
459407
}
460408

461409
return true;

0 commit comments

Comments (0)