Skip to content

Commit 8f70dcb

Browse files
committed
perplexity : make Winogrande work as it does on master
The problems with the Winogrande implementation will need to be fixed in a separate PR to ease review.
1 parent d04cfaf commit 8f70dcb

File tree

1 file changed

+36
-51
lines changed

1 file changed

+36
-51
lines changed

examples/perplexity/perplexity.cpp

+36-51
Original file line number · Diff line number · Diff line change
@@ -999,6 +999,8 @@ struct winogrande_entry {
999999
size_t i_logits;
10001000
size_t common_prefix;
10011001
size_t required_tokens;
1002+
size_t n_base1; // number of tokens for context + choice 1
1003+
size_t n_base2; // number of tokens for context + choice 2
10021004
std::vector<llama_token> seq_tokens[2];
10031005
};
10041006

@@ -1038,38 +1040,6 @@ static std::vector<winogrande_entry> load_winogrande_from_csv(const std::string&
10381040
auto choice2 = line.substr(comma_pos[2]+1, comma_pos[3] - comma_pos[2] - 1);
10391041
auto answer = line.substr(comma_pos[3]+1, line.size() - comma_pos[3] - 1);
10401042
auto index = line.substr(0, comma_pos[0]);
1041-
if ('a' <= sentence[0] && sentence[0] <= 'z') {
1042-
// make the first letter a capital letter
1043-
sentence[0] -= 'a' - 'A';
1044-
}
1045-
for (int i = 0; i < (int) sentence.size() - 1; ++i) {
1046-
// trim repeated spaces and spaces before punctuation
1047-
if (sentence[i] == ' ') {
1048-
char next = sentence[i+1];
1049-
if (next == ' ' || next == ',' || next == '.' || next == '\'') {
1050-
char r[2] = { next, 0 };
1051-
sentence.replace(i, 2, r);
1052-
--i; // stay at the same index for repeated spaces
1053-
}
1054-
} else if (sentence[i] == ',' || sentence[i] == '.') {
1055-
if (sentence[i] == sentence[i+1]) {
1056-
// trim repeated punctuation (forward to work at the end of sentences)
1057-
char r[2] = { sentence[i], 0 };
1058-
sentence.replace(i, 2, r);
1059-
--i; // same index to then run the other checks on that punctuation
1060-
} else if (0 < i && sentence[i-1] == sentence[i]) {
1061-
// trim repeated punctuation (looks back to work with the space trim)
1062-
char r[2] = { sentence[i], 0 };
1063-
sentence.replace(i-1, 2, r);
1064-
i -= 2; // go back because content was shifted
1065-
} else if (sentence[i+1] != ' ') {
1066-
// add missing space after punctuation
1067-
// (since the loop stops before the end, this adds no trailing space)
1068-
char r[3] = { sentence[i], ' ', 0 };
1069-
sentence.replace(i, 1, r);
1070-
}
1071-
}
1072-
}
10731043
int where = 0;
10741044
for ( ; where < int(sentence.size()); ++where) {
10751045
if (sentence[where] == '_') break;
@@ -1106,6 +1076,8 @@ static std::vector<winogrande_entry> load_winogrande_from_csv(const std::string&
11061076
*/
11071077
static void winogrande_score(llama_context * ctx, const gpt_params & params) {
11081078

1079+
constexpr int k_min_trailing_ctx = 3;
1080+
11091081
auto data = load_winogrande_from_csv(params.prompt);
11101082
if (data.empty()) {
11111083
fprintf(stderr, "%s: no tasks\n", __func__);
@@ -1150,11 +1122,13 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
11501122
task.common_prefix++;
11511123
}
11521124

1125+
// TODO: the last token of each of the sequences don't need to be evaluated
11531126
task.required_tokens = task.common_prefix +
11541127
task.seq_tokens[0].size() - task.common_prefix +
1155-
task.seq_tokens[1].size() - task.common_prefix
1156-
// the last tokens don't need to be evaluated
1157-
- 2;
1128+
task.seq_tokens[1].size() - task.common_prefix;
1129+
1130+
task.n_base1 = ::llama_tokenize(ctx, task.first + task.choices[0], add_bos).size();
1131+
task.n_base2 = ::llama_tokenize(ctx, task.first + task.choices[1], add_bos).size();
11581132
}
11591133

11601134
fprintf(stderr, "%s : calculating winogrande score over selected tasks.\n", __func__);
@@ -1201,8 +1175,8 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
12011175
n_logits += 1;
12021176

12031177
for (int s = 0; s < 2; ++s) {
1204-
// end before the last token, no need to predict past the end of the sequences
1205-
for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size() - 1; ++i) {
1178+
// TODO: end before the last token, no need to predict past the end of the sequences
1179+
for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) {
12061180
llama_batch_add(batch, data[i1].seq_tokens[s][i], i, { s0 + s }, true);
12071181
n_logits += 1;
12081182
}
@@ -1234,38 +1208,49 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
12341208
for (size_t i = i0; i < i1; ++i) {
12351209
auto & task = data[i];
12361210

1237-
// start from the end of the common prefix
1238-
size_t li = 0;
1239-
for (size_t j = task.common_prefix-1; j < task.seq_tokens[0].size()-1; ++j) {
1211+
const bool skip_choice =
1212+
task.seq_tokens[0].size() - task.common_prefix > k_min_trailing_ctx &&
1213+
task.seq_tokens[1].size() - task.common_prefix > k_min_trailing_ctx;
1214+
1215+
const auto& n_base1 = skip_choice ? task.n_base1 : task.common_prefix;
1216+
const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0;
1217+
size_t li = n_base1 - task.common_prefix;
1218+
for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) {
12401219
eval_pairs.emplace_back(task.i_logits + li++, task.seq_tokens[0][j+1]);
12411220
}
1242-
// first token of the second choice is predicted by the end of the common prefix
1243-
eval_pairs.emplace_back(task.i_logits, task.seq_tokens[1][task.common_prefix]);
1244-
for (size_t j = task.common_prefix; j < task.seq_tokens[1].size()-1; ++j) {
1221+
const auto& n_base2 = skip_choice ? task.n_base2 : task.common_prefix;
1222+
const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0;
1223+
// FIXME: this uses the wrong first logits when not skipping the choice word
1224+
li = task.seq_tokens[0].size() - task.common_prefix + n_base2 - task.common_prefix;
1225+
for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) {
12451226
eval_pairs.emplace_back(task.i_logits + li++, task.seq_tokens[1][j+1]);
12461227
}
1247-
if (i < i1 - 1) {
1248-
// make sure all logits have been processed as expected
1249-
GGML_ASSERT(task.i_logits + li == data[i+1].i_logits);
1250-
}
12511228
}
12521229
compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results);
12531230

12541231
size_t ir = 0;
12551232
for (size_t i = i0; i < i1; ++i) {
12561233
auto & task = data[i];
12571234

1235+
const bool skip_choice =
1236+
task.seq_tokens[0].size() - task.common_prefix > k_min_trailing_ctx &&
1237+
task.seq_tokens[1].size() - task.common_prefix > k_min_trailing_ctx;
1238+
12581239
float score_1st = 0;
1259-
for (size_t j = task.common_prefix-1; j < task.seq_tokens[0].size()-1; ++j) {
1240+
const auto& n_base1 = skip_choice ? task.n_base1 : task.common_prefix;
1241+
const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0;
1242+
for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) {
12601243
score_1st += eval_results[ir++];
12611244
}
1262-
score_1st /= (task.seq_tokens[0].size() - task.common_prefix);
1245+
score_1st /= (task.seq_tokens[0].size() - n_base1 - last_1st);
12631246

12641247
float score_2nd = 0;
1265-
for (size_t j = task.common_prefix-1; j < task.seq_tokens[1].size()-1; ++j) {
1248+
const auto& n_base2 = skip_choice ? task.n_base2 : task.common_prefix;
1249+
const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0;
1250+
for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) {
12661251
score_2nd += eval_results[ir++];
12671252
}
1268-
score_2nd /= (task.seq_tokens[1].size() - task.common_prefix);
1253+
score_2nd /= (task.seq_tokens[1].size() - n_base2 - last_2nd);
12691254

12701255
int result = score_1st > score_2nd ? 1 : 2;
12711256

0 commit comments

Comments (0)