@@ -999,6 +999,8 @@ struct winogrande_entry {
     size_t i_logits;
     size_t common_prefix;
     size_t required_tokens;
+    size_t n_base1; // number of tokens for context + choice 1
+    size_t n_base2; // number of tokens for context + choice 2
     std::vector<llama_token> seq_tokens[2];
 };
@@ -1038,38 +1040,6 @@ static std::vector<winogrande_entry> load_winogrande_from_csv(const std::string&
         auto choice2 = line.substr(comma_pos[2]+1, comma_pos[3] - comma_pos[2] - 1);
         auto answer = line.substr(comma_pos[3]+1, line.size() - comma_pos[3] - 1);
         auto index = line.substr(0, comma_pos[0]);
-        if ('a' <= sentence[0] && sentence[0] <= 'z') {
-            // make the first letter a capital letter
-            sentence[0] -= 'a' - 'A';
-        }
-        for (int i = 0; i < (int) sentence.size() - 1; ++i) {
-            // trim repeated spaces and spaces before punctuation
-            if (sentence[i] == ' ') {
-                char next = sentence[i+1];
-                if (next == ' ' || next == ',' || next == '.' || next == '\'') {
-                    char r[2] = { next, 0 };
-                    sentence.replace(i, 2, r);
-                    --i; // stay at the same index for repeated spaces
-                }
-            } else if (sentence[i] == ',' || sentence[i] == '.') {
-                if (sentence[i] == sentence[i+1]) {
-                    // trim repeated punctuation (forward to work at the end of sentences)
-                    char r[2] = { sentence[i], 0 };
-                    sentence.replace(i, 2, r);
-                    --i; // same index to then run the other checks on that punctuation
-                } else if (0 < i && sentence[i-1] == sentence[i]) {
-                    // trim repeated punctuation (looks back to work with the space trim)
-                    char r[2] = { sentence[i], 0 };
-                    sentence.replace(i-1, 2, r);
-                    i -= 2; // go back because content was shifted
-                } else if (sentence[i+1] != ' ') {
-                    // add missing space after punctuation
-                    // (since the loop stops before the end, this adds no trailing space)
-                    char r[3] = { sentence[i], ' ', 0 };
-                    sentence.replace(i, 1, r);
-                }
-            }
-        }
         int where = 0;
         for ( ; where < int(sentence.size()); ++where) {
             if (sentence[where] == '_') break;
@@ -1106,6 +1076,8 @@ static std::vector<winogrande_entry> load_winogrande_from_csv(const std::string&
  */
 static void winogrande_score(llama_context * ctx, const gpt_params & params) {
 
+    constexpr int k_min_trailing_ctx = 3;
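+    // both sequences must have more than this many tokens after the
+    // common prefix for the choice words themselves to be skipped below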
+
     auto data = load_winogrande_from_csv(params.prompt);
     if (data.empty()) {
         fprintf(stderr, "%s: no tasks\n", __func__);
@@ -1150,11 +1122,13 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
             task.common_prefix++;
         }
 
+        // TODO: the last token of each of the sequences doesn't need to be evaluated
         task.required_tokens = task.common_prefix +
            task.seq_tokens[0].size() - task.common_prefix +
-           task.seq_tokens[1].size() - task.common_prefix
-           // the last tokens don't need to be evaluated
-           - 2;
+           task.seq_tokens[1].size() - task.common_prefix;
+
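+        // tokenize the context with each choice word filled in, to find
+        // where the tokens after the choice word start (used by skip_choice)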
+        task.n_base1 = ::llama_tokenize(ctx, task.first + task.choices[0], add_bos).size();
+        task.n_base2 = ::llama_tokenize(ctx, task.first + task.choices[1], add_bos).size();
     }
 
     fprintf(stderr, "%s : calculating winogrande score over selected tasks.\n", __func__);
@@ -1201,8 +1175,8 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
             n_logits += 1;
 
             for (int s = 0; s < 2; ++s) {
-                // end before the last token, no need to predict past the end of the sequences
-                for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size() - 1; ++i) {
+                // TODO: end before the last token, no need to predict past the end of the sequences
+                for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) {
                     llama_batch_add(batch, data[i1].seq_tokens[s][i], i, { s0 + s }, true);
                     n_logits += 1;
                 }
@@ -1234,38 +1208,49 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
         for (size_t i = i0; i < i1; ++i) {
             auto & task = data[i];
 
-            // start from the end of the common prefix
-            size_t li = 0;
-            for (size_t j = task.common_prefix-1; j < task.seq_tokens[0].size()-1; ++j) {
+            const bool skip_choice =
+                task.seq_tokens[0].size() - task.common_prefix > k_min_trailing_ctx &&
+                task.seq_tokens[1].size() - task.common_prefix > k_min_trailing_ctx;
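+            // when either ending is too short, fall back to scoring from the
+            // end of the common prefix, with the choice word included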
+
+            const auto & n_base1 = skip_choice ? task.n_base1 : task.common_prefix;
+            const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0;
+            size_t li = n_base1 - task.common_prefix;
+            for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) {
                 eval_pairs.emplace_back(task.i_logits + li++, task.seq_tokens[0][j+1]);
             }
-            // first token of the second choice is predicted by the end of the common prefix
-            eval_pairs.emplace_back(task.i_logits, task.seq_tokens[1][task.common_prefix]);
-            for (size_t j = task.common_prefix; j < task.seq_tokens[1].size()-1; ++j) {
+            const auto & n_base2 = skip_choice ? task.n_base2 : task.common_prefix;
+            const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0;
+            // FIXME: this uses the wrong first logits when not skipping the choice word
+            li = task.seq_tokens[0].size() - task.common_prefix + n_base2 - task.common_prefix;
+            for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) {
                 eval_pairs.emplace_back(task.i_logits + li++, task.seq_tokens[1][j+1]);
             }
-            if (i < i1 - 1) {
-                // make sure all logits have been processed as expected
-                GGML_ASSERT(task.i_logits + li == data[i+1].i_logits);
-            }
         }
         compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results);
 
         size_t ir = 0;
         for (size_t i = i0; i < i1; ++i) {
             auto & task = data[i];
 
+            const bool skip_choice =
+                task.seq_tokens[0].size() - task.common_prefix > k_min_trailing_ctx &&
+                task.seq_tokens[1].size() - task.common_prefix > k_min_trailing_ctx;
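+            // must match the skip_choice decision made when building
+            // eval_pairs above, so eval_results are consumed in order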
+
             float score_1st = 0;
-            for (size_t j = task.common_prefix-1; j < task.seq_tokens[0].size()-1; ++j) {
+            const auto & n_base1 = skip_choice ? task.n_base1 : task.common_prefix;
+            const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0;
+            for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) {
                 score_1st += eval_results[ir++];
             }
-            score_1st /= (task.seq_tokens[0].size() - task.common_prefix);
+            score_1st /= (task.seq_tokens[0].size() - n_base1 - last_1st);
 
             float score_2nd = 0;
-            for (size_t j = task.common_prefix-1; j < task.seq_tokens[1].size()-1; ++j) {
+            const auto & n_base2 = skip_choice ? task.n_base2 : task.common_prefix;
+            const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0;
+            for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) {
                 score_2nd += eval_results[ir++];
             }
-            score_2nd /= (task.seq_tokens[1].size() - task.common_prefix);
+            score_2nd /= (task.seq_tokens[1].size() - n_base2 - last_2nd);
 
             int result = score_1st > score_2nd ? 1 : 2;