@@ -458,7 +458,7 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
458
458
return true ;
459
459
}
460
460
461
- static void hellaswag_compute_logprobs (const float * batch_logits, int n_vocab, std::vector<std::thread>& workers,
461
+ static void compute_logprobs (const float * batch_logits, int n_vocab, std::vector<std::thread>& workers,
462
462
const std::vector<std::pair<size_t , llama_token>>& eval_pairs, std::vector<float >& eval_results) {
463
463
constexpr int k_token_chunk = 4 ;
464
464
if (eval_results.size () != eval_pairs.size ()) {
@@ -700,7 +700,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
700
700
}
701
701
}
702
702
// Then we do the actual calculation
703
- hellaswag_compute_logprobs (batch_logits.data (), n_vocab, workers, eval_pairs, eval_results);
703
+ compute_logprobs (batch_logits.data (), n_vocab, workers, eval_pairs, eval_results);
704
704
705
705
size_t ir = 0 ;
706
706
@@ -906,6 +906,10 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
906
906
std::vector<float > tok_logits (n_vocab);
907
907
std::vector<float > batch_logits (n_vocab*n_ctx);
908
908
909
+ std::vector<std::pair<size_t , llama_token>> eval_pairs;
910
+ std::vector<float > eval_results;
911
+ std::vector<std::thread> workers (std::thread::hardware_concurrency ());
912
+
909
913
int n_correct = 0 ;
910
914
int n_done = 0 ;
911
915
@@ -956,61 +960,52 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
956
960
return ;
957
961
}
958
962
963
+ eval_pairs.clear ();
959
964
for (size_t i = i0; i < i1; ++i) {
960
965
auto & task = data[i];
961
966
962
967
const bool skip_choice =
963
968
task.seq_tokens [0 ].size () - task.common_prefix > k_min_trailing_ctx &&
964
969
task.seq_tokens [1 ].size () - task.common_prefix > k_min_trailing_ctx;
965
970
966
- float score_1st = 0 ;
967
- bool is_nan_1st = false ;
968
971
const auto & n_base1 = skip_choice ? task.n_base1 : task.common_prefix ;
969
972
const int last_1st = task.seq_tokens [0 ].size () - n_base1 > 1 ? 1 : 0 ;
970
973
size_t li = n_base1 - 1 ;
971
974
for (size_t j = n_base1-1 ; j < task.seq_tokens [0 ].size ()-1 -last_1st; ++j) {
972
- std::memcpy (tok_logits.data (), batch_logits.data () + n_vocab*(task.i_batch + li++), n_vocab*sizeof (float ));
973
- const float prob = softmax (tok_logits)[task.seq_tokens [0 ][j+1 ]];
974
- if (std::isnan (prob) || !prob) {
975
- fprintf (stderr, " %s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n " , __func__,
976
- prob, j, (task.first + task.choices [0 ] + task.second ).c_str (), n_base1);
977
- is_nan_1st = true ;
978
- break ;
979
- }
980
- score_1st += std::log (prob);
975
+ eval_pairs.push_back (std::make_pair (task.i_batch + li++, task.seq_tokens [0 ][j+1 ]));
981
976
}
982
- score_1st /= (task.seq_tokens [0 ].size () - n_base1 - last_1st);
983
-
984
- float score_2nd = 0 ;
985
- bool is_nan_2nd = false ;
986
977
const auto & n_base2 = skip_choice ? task.n_base2 : task.common_prefix ;
987
978
const int last_2nd = task.seq_tokens [1 ].size () - n_base2 > 1 ? 1 : 0 ;
988
979
li = task.seq_tokens [0 ].size () - task.common_prefix + n_base2 - 1 ;
989
980
for (size_t j = n_base2-1 ; j < task.seq_tokens [1 ].size ()-1 -last_2nd; ++j) {
990
- std::memcpy (tok_logits.data (), batch_logits.data () + n_vocab*(task.i_batch + li++), n_vocab*sizeof (float ));
991
- const float prob = softmax (tok_logits)[task.seq_tokens [1 ][j+1 ]];
992
- if (std::isnan (prob) || !prob) {
993
- fprintf (stderr, " %s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n " , __func__,
994
- prob, j, (task.first + task.choices [1 ] + task.second ).c_str (), n_base2);
995
- is_nan_2nd = true ;
996
- break ;
997
- }
998
- score_2nd += std::log (prob);
981
+ eval_pairs.push_back (std::make_pair (task.i_batch + li++, task.seq_tokens [1 ][j+1 ]));
999
982
}
1000
- score_2nd /= (task.seq_tokens [1 ].size () - n_base2 - last_2nd);
983
+ }
984
+ compute_logprobs (batch_logits.data (), n_vocab, workers, eval_pairs, eval_results);
985
+
986
+ size_t ir = 0 ;
987
+ for (size_t i = i0; i < i1; ++i) {
988
+ auto & task = data[i];
989
+
990
+ const bool skip_choice =
991
+ task.seq_tokens [0 ].size () - task.common_prefix > k_min_trailing_ctx &&
992
+ task.seq_tokens [1 ].size () - task.common_prefix > k_min_trailing_ctx;
1001
993
1002
- if (is_nan_1st || is_nan_2nd) {
1003
- continue ;
994
+ float score_1st = 0 ;
995
+ const auto & n_base1 = skip_choice ? task.n_base1 : task.common_prefix ;
996
+ const int last_1st = task.seq_tokens [0 ].size () - n_base1 > 1 ? 1 : 0 ;
997
+ for (size_t j = n_base1-1 ; j < task.seq_tokens [0 ].size ()-1 -last_1st; ++j) {
998
+ score_1st += eval_results[ir++];
1004
999
}
1000
+ score_1st /= (task.seq_tokens [0 ].size () - n_base1 - last_1st);
1005
1001
1006
- if (std::isnan (score_1st) || std::isnan (score_2nd)) {
1007
- printf (" ================== NaN score %g, %g) for:\n " , score_1st, score_2nd);
1008
- printf (" Q1: <%s> - %zu tokens\n " , (task.first + task.choices [0 ] + task.second ).c_str (), task.seq_tokens [0 ].size ());
1009
- printf (" Q2: <%s> - %zu tokens\n " , (task.first + task.choices [1 ] + task.second ).c_str (), task.seq_tokens [1 ].size ());
1010
- printf (" B : <%s> - %zu tokens\n " , task.first .c_str (), task.common_prefix );
1011
- printf (" base_1 has %zu tokens, base_2 has %zu tokens, skip_choice = %d\n " , n_base1, n_base2, skip_choice);
1012
- continue ;
1002
+ float score_2nd = 0 ;
1003
+ const auto & n_base2 = skip_choice ? task.n_base2 : task.common_prefix ;
1004
+ const int last_2nd = task.seq_tokens [1 ].size () - n_base2 > 1 ? 1 : 0 ;
1005
+ for (size_t j = n_base2-1 ; j < task.seq_tokens [1 ].size ()-1 -last_2nd; ++j) {
1006
+ score_2nd += eval_results[ir++];
1013
1007
}
1008
+ score_2nd /= (task.seq_tokens [1 ].size () - n_base2 - last_2nd);
1014
1009
1015
1010
int result = score_1st > score_2nd ? 1 : 2 ;
1016
1011
@@ -1019,7 +1014,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
1019
1014
}
1020
1015
++n_done;
1021
1016
1022
- // Print the accumulated accuracy mean x 100
1017
+ // print the accumulated accuracy mean x 100
1023
1018
printf (" %zu\t %.4lf\t %10.6f %10.6f %d %d\n " , i+1 , 100.0 * n_correct/n_done, score_1st, score_2nd, result, task.answer );
1024
1019
fflush (stdout);
1025
1020
}
0 commit comments