@@ -121,8 +121,23 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
     printf("\n");
 }
 
-void perplexity_lines(llama_context * ctx, const gpt_params & params) {
-    // Calculates perplexity over each line of the prompt
+void hellaswag_score(llama_context * ctx, const gpt_params & params) {
+    // Calculates hellaswag score (acc_norm) from prompt
+    //
+    // Data extracted from the HellaSwag validation dataset (MIT license) https://github.com/rowanz/hellaswag/blob/master/data/hellaswag_val.jsonl
+    // All used data fields are preprocessed as in https://github.com/EleutherAI/lm-evaluation-harness/blob/df3da98c5405deafd519c2ddca52bb7c3fe36bef/lm_eval/tasks/hellaswag.py#L62-L68
+    //
+    // All 10042 tasks should be extracted to keep the results standardized like other implementations.
+    //
+    // Datafile layout:
+    //   ['??'] denotes json fields
+    //   6 lines per task:
+    //   ['activity_label'] + ": " + ['ctx'] - The first part of the query, the context
+    //   ['label']                           - The index of the best common sense ending, aka the gold ending
+    //   ['endings'][0]                      - Endings added to the first part of the query
+    //   ['endings'][1]
+    //   ['endings'][2]
+    //   ['endings'][3]
 
     std::vector<std::string> prompt_lines;
     std::istringstream strstream(params.prompt);
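For reference, each task occupies six consecutive lines of the prompt file, in the order listed in the comment above. A hypothetical task (illustrative only, not quoted verbatim from the dataset) could look like:

    Roof shingle removal: A man is sitting on a roof. He
    3
    is using wrap to wrap a pair of skis.
    is ripping level tiles off.
    is holding a rubik's cube.
    starts pulling up roofing on a roof.

Here "3" marks the fourth ending as the gold one; as the loading loop below shows, a space is prepended to each ending before it is appended to the context.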
@@ -132,63 +147,149 @@ void perplexity_lines(llama_context * ctx, const gpt_params & params) {
         prompt_lines.push_back(line);
     }
 
-    const int n_vocab = llama_n_vocab(ctx);
+    if (prompt_lines.size() % 6 != 0) {
+        fprintf(stderr, "%s : number of lines in prompt not a multiple of 6.\n", __func__);
+        return;
+    }
 
-    int counttotal = 0;
-    size_t n_lines = prompt_lines.size();
+    size_t hs_task_count = prompt_lines.size()/6;
+    fprintf(stderr, "%s : loaded %lu tasks from prompt.\n", __func__, hs_task_count);
 
-    double nll = 0.0;
+    // This is needed as usual for LLaMA models
+    bool prepend_bos = true;
+
+    // Number of tasks to use when computing the score
+    if (params.hellaswag_tasks < hs_task_count) {
+        hs_task_count = params.hellaswag_tasks;
+    }
 
-    fprintf(stderr, "%s: calculating perplexity over %lu lines\n", __func__, n_lines);
+    // The tasks should be randomized so the score stabilizes quickly.
+    bool randomize_tasks = true;
 
-    printf("\nLine\tPPL line\tPPL cumulative\n");
+    // The random seed should not impact the final result if the computation is done over enough tasks, so it is kept hardcoded for now
+    std::mt19937 rng(1);
 
-    for (size_t i = 0; i < n_lines; ++i) {
+    // Data holder for hellaswag tasks
+    struct hs_data_t {
+        std::string context;
+        size_t gold_ending_idx;
+        std::string ending[4];
+        size_t ending_logprob_count[4];
+        double ending_logprob[4];
+    };
 
-        // Tokenize and insert BOS at start
-        std::vector<int> batch_embd = ::llama_tokenize(ctx, prompt_lines[i], true);
+    fprintf(stderr, "%s : selecting %lu %s tasks.\n", __func__, hs_task_count, (randomize_tasks ? "randomized" : "the first"));
 
-        size_t batch_size = batch_embd.size();
+    // Select and read data from prompt lines
+    hs_data_t *hs_data = new hs_data_t[hs_task_count];
+    for (size_t i = 0; i < hs_task_count; i++) {
+        size_t idx = i;
 
-        // Stop if line is too long
-        if (batch_size > (size_t)params.n_ctx) {
-            fprintf(stderr, "%s : tokens in line %lu > n_ctxl\n", __func__, i);
-            return;
+        // Select a random example of those left in the prompt
+        if (randomize_tasks) {
+            std::uniform_int_distribution<size_t> dist(0, prompt_lines.size()/6 - 1);
+            idx = dist(rng);
         }
 
-        if (llama_eval(ctx, batch_embd.data(), batch_size, 0, params.n_threads)) {
-            fprintf(stderr, "%s : failed to eval\n", __func__);
-            return;
+        hs_data[i].context = prompt_lines[idx*6];
+        hs_data[i].gold_ending_idx = std::stoi( prompt_lines[idx*6+1] );
+        for (size_t j = 0; j < 4; j++) {
+            hs_data[i].ending[j] = " " + prompt_lines[idx*6+2+j];
         }
 
-        const auto batch_logits = llama_get_logits(ctx);
-        std::vector<float> logits;
-        logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+        // Delete the selected random example from the prompt
+        if (randomize_tasks) {
+            prompt_lines.erase( std::next(prompt_lines.begin(), idx*6), std::next(prompt_lines.begin(), idx*6+6) );
+        }
+    }
 
-        double nllline = 0.0;
-        int countline = 0;
+    fprintf(stderr, "%s : calculating hellaswag score over selected tasks.\n", __func__);
+    printf("\ntask\tacc_norm\n");
 
-        // Perplexity over second half of the line
-        for (size_t j = batch_size/2; j < batch_size - 1; ++j) {
-            // Calculate probability of next token, given the previous ones.
-            const std::vector<float> tok_logits(
-                logits.begin() + (j + 0) * n_vocab,
-                logits.begin() + (j + 1) * n_vocab);
+    double acc = 0.0f;
+    const int n_vocab = llama_n_vocab(ctx);
+
+    for (size_t task_idx = 0; task_idx < hs_task_count; task_idx++) {
+
+        // Tokenize the context to count tokens
+        std::vector<int> context_embd = ::llama_tokenize(ctx, hs_data[task_idx].context, prepend_bos);
+        size_t context_size = context_embd.size();
+
+        for (size_t ending_idx = 0; ending_idx < 4; ending_idx++) {
+
+            // Tokenize the query
+            std::vector<int> query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[ending_idx], prepend_bos);
+            size_t query_size = query_embd.size();
+
+            // Stop if the query won't fit in the ctx window
+            if (query_size > (size_t)params.n_ctx) {
+                fprintf(stderr, "%s : number of tokens in query %lu > n_ctx\n", __func__, query_size);
+                return;
+            }
 
-            const float prob = softmax(tok_logits)[batch_embd[j + 1]];
+            // Speed up small evaluations by evaluating at least 32 tokens
+            if (query_size < 32) {
+                query_embd.resize(32);
+            }
+
+            // Evaluate the query
+            if (llama_eval(ctx, query_embd.data(), query_embd.size(), 0, params.n_threads)) {
+                fprintf(stderr, "%s : failed to eval\n", __func__);
+                return;
+            }
+
+            const auto query_logits = llama_get_logits(ctx);
+            std::vector<float> logits;
+            logits.insert(logits.end(), query_logits, query_logits + query_size * n_vocab);
+
+            hs_data[task_idx].ending_logprob_count[ending_idx] = 0;
+            hs_data[task_idx].ending_logprob[ending_idx] = 0.0f;
+
+            // Calculate the logprobs over the ending
+            for (size_t j = context_size - 1; j < query_size - 1; j++) {
+                // Calculate probability of next token, given the previous ones.
+                const std::vector<float> tok_logits(
+                    logits.begin() + (j + 0) * n_vocab,
+                    logits.begin() + (j + 1) * n_vocab);
+
+                const float prob = softmax(tok_logits)[query_embd[j + 1]];
+
+                hs_data[task_idx].ending_logprob[ending_idx] += std::log(prob);
+                hs_data[task_idx].ending_logprob_count[ending_idx]++;
+            }
+
+            // Calculate the mean token logprob for acc_norm
+            hs_data[task_idx].ending_logprob[ending_idx] /= hs_data[task_idx].ending_logprob_count[ending_idx];
+
+            // printf("task %lu, ending %lu, whole_len %lu, context_len %lu, ending_logprob_count %lu, ending_logprob %.4f\n",
+            //        task_idx, ending_idx, whole_size, context_size, hs_data[task_idx].ending_logprob_count[ending_idx], hs_data[task_idx].ending_logprob[ending_idx]);
+        }
 
-            nllline += -std::log(prob);
-            ++countline;
+        // Find the ending with maximum logprob
+        size_t ending_logprob_max_idx = -1;
+        double ending_logprob_max_val = -INFINITY;
+        for (size_t j = 0; j < 4; j++) {
+            if (hs_data[task_idx].ending_logprob[j] > ending_logprob_max_val) {
+                ending_logprob_max_idx = j;
+                ending_logprob_max_val = hs_data[task_idx].ending_logprob[j];
+            }
         }
 
-        nll += nllline;
-        counttotal += countline;
+        // printf("max logprob ending idx %lu, gold ending idx %lu\n", ending_logprob_max_idx, hs_data[task_idx].gold_ending_idx);
 
-        // perplexity is e^(average negative log-likelihood)
-        printf("%lu\t%.8lf\t%.8lf\n", i + 1, std::exp(nllline/countline), std::exp(nll / counttotal));
+        // If the gold ending got the maximum logprob, add one accuracy point
+        if (ending_logprob_max_idx == hs_data[task_idx].gold_ending_idx) {
+            acc += 1.0;
+        }
+
+        // Print the accumulated accuracy mean x 100
+        printf("%li\t%.8lf\n", task_idx+1, acc/double(task_idx+1)*100.0);
         fflush(stdout);
     }
 
+    delete [] hs_data;
+
     printf("\n");
 }
 
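The selection rule implemented above is the length-normalized scoring behind acc_norm in the lm-evaluation-harness: each ending is scored by its mean per-token log-probability, and the ending with the highest mean wins. A minimal standalone C++ sketch of that decision rule follows, using made-up token probabilities (none of these numbers come from the commit; it is an illustration, not part of the change):

    // Sketch of the acc_norm decision rule: pick the ending with the
    // highest mean per-token log-probability. Probabilities are hypothetical.
    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
        // Hypothetical per-token probabilities for the 4 candidate endings
        std::vector<std::vector<double>> ending_tok_probs = {
            {0.10, 0.20},             // ending 0
            {0.05, 0.30, 0.10},       // ending 1
            {0.20, 0.10},             // ending 2
            {0.25, 0.30, 0.20, 0.15}, // ending 3
        };

        int best = -1;
        double best_logprob = -INFINITY;
        for (int e = 0; e < 4; e++) {
            double logprob = 0.0;
            for (double p : ending_tok_probs[e]) {
                logprob += std::log(p);
            }
            // Length normalization: mean log-probability per ending token
            logprob /= ending_tok_probs[e].size();
            if (logprob > best_logprob) {
                best_logprob = logprob;
                best = e;
            }
        }
        printf("predicted ending: %d (mean logprob %.4f)\n", best, best_logprob);
        return 0;
    }

Dividing by the token count is what makes the comparison fair between endings of different lengths; summed log-probabilities alone would systematically penalize longer endings.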
@@ -240,8 +341,8 @@ int main(int argc, char ** argv) {
             params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
     }
 
-    if (params.perplexity_lines) {
-        perplexity_lines(ctx, params);
+    if (params.hellaswag) {
+        hellaswag_score(ctx, params);
     } else {
         perplexity(ctx, params);
     }
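Assuming the command-line plumbing that accompanies this change (the params.hellaswag flag and params.hellaswag_tasks count, presumably exposed as --hellaswag and --hellaswag-tasks; the exact flag spellings, model path, and datafile name below are assumptions), a run might look like:

    ./perplexity -m models/7B/ggml-model-q4_0.bin -f hellaswag_val_full.txt --hellaswag --hellaswag-tasks 400 -t 8

This prints a running acc_norm percentage after each task, as produced by the printf at the end of the scoring loop.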